Skip to content

Commit b6c4f72

Browse files
MrlolDevykolliestanley
authored
Discord bot: Pythia model (LAION-AI#2831)
Changes: - [x] Added pythia model to discord bot - [x] Added a button in the answers that shows the model used - [x] Added a selector when click the model button to change the model --------- Co-authored-by: Yannic Kilcher <yk@users.noreply.github.com> Co-authored-by: Oliver Stanley <olivergestanley@gmail.com>
1 parent dfbafa5 commit b6c4f72

File tree

7 files changed

+257
-62
lines changed

7 files changed

+257
-62
lines changed

discord-bots/oa-bot-js/.env.sample

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ OA_APIURL=OpenAssistant API url
55
INFERENCE_SERVER_API_KEY=
66
INFERENCE_SERVER_HOST=
77
REDIS_PASSWORD=
8-
DEFAULT_MODEL=default model if user doesn't specify one
8+
DEFAULT_MODEL=default model if user does not specify one
9+
HUGGINGFACE_TOKEN=huggingface token

discord-bots/oa-bot-js/package.json

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"license": "Apache-2.0",
1717
"private": true,
1818
"dependencies": {
19+
"@huggingface/inference": "^2.0.0",
1920
"axios": "^1.3.5",
2021
"chalk": "^5.2.0",
2122
"discord.js": "^14.7.1",

discord-bots/oa-bot-js/src/commands/chat.ts

+104-59
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import {
44
ButtonStyle,
55
ButtonBuilder,
66
} from "discord.js";
7-
import { createInferenceClient } from "../modules/inference/client.js";
87
import redis from "../modules/redis.js";
8+
import chatFN from "../modules/chat.js";
99

1010
export default {
1111
disablePing: null,
@@ -23,10 +23,16 @@ export default {
2323
.setName("model")
2424
.setDescription("The model you want to use for the AI.")
2525
.setRequired(false)
26-
.addChoices({
27-
name: "OA_SFT_Llama_30B",
28-
value: "OA_SFT_Llama_30B",
29-
})
26+
.addChoices(
27+
{
28+
name: "OA_SFT_Llama_30B",
29+
value: "OA_SFT_Llama_30B",
30+
},
31+
{
32+
name: "oasst-sft-4-pythia-12b",
33+
value: "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
34+
}
35+
)
3036
)
3137
.addStringOption((option) =>
3238
option
@@ -73,71 +79,110 @@ export default {
7379
model = interaction.options.getString("model");
7480
preset = interaction.options.getString("preset");
7581
}
76-
if (!model)
77-
model = process.env.OPEN_ASSISTANT_DEFAULT_MODEL || "OA_SFT_Llama_30B";
82+
if (!model) {
83+
let userModel = await redis.get(`model_${interaction.user.id}`);
84+
if (userModel) {
85+
model = userModel;
86+
} else {
87+
model = process.env.OPEN_ASSISTANT_DEFAULT_MODEL || "OA_SFT_Llama_30B";
88+
redis.set(`model_${interaction.user.id}`, model);
89+
}
90+
} else {
91+
redis.set(`model_${interaction.user.id}`, model);
92+
}
7893
if (!preset) preset = "k50";
7994
// sleep for 30s
8095

81-
const OA = await createInferenceClient(
82-
interaction.user.username,
83-
interaction.user.id
84-
);
96+
if (model.includes("Llama")) {
97+
try {
98+
let chat = await redis.get(`chat_${interaction.user.id}`);
99+
let chatId = chat ? chat.split("_")[0] : null;
100+
let parentId = chat ? chat.split("_")[1] : null;
101+
let { assistant_message, OA } = await chatFN(
102+
model,
103+
interaction.user,
104+
message,
105+
chatId,
106+
parentId,
107+
presets,
108+
preset
109+
);
110+
await redis.set(
111+
`chat_${interaction.user.id}`,
112+
`${chatId}_${assistant_message.id}`
113+
);
85114

86-
try {
87-
let chat = await redis.get(`chat_${interaction.user.id}`);
88-
let chatId = chat ? chat.split("_")[0] : null;
89-
let parentId = chat ? chat.split("_")[1] : null;
90-
if (!chatId) {
91-
let chat = await OA.create_chat();
92-
chatId = chat.id;
115+
const row = new ActionRowBuilder().addComponents(
116+
new ButtonBuilder()
117+
.setStyle(ButtonStyle.Secondary)
118+
.setLabel(`👍`)
119+
.setCustomId(`vote_${assistant_message.id}_up`),
120+
new ButtonBuilder()
121+
.setStyle(ButtonStyle.Secondary)
122+
.setLabel(`👎`)
123+
.setCustomId(`vote_${assistant_message.id}_down`),
124+
new ButtonBuilder()
125+
.setStyle(ButtonStyle.Secondary)
126+
.setDisabled(false)
127+
.setLabel(
128+
`${model.replaceAll("OpenAssistant/", "").replaceAll("_", "")}`
129+
)
130+
.setCustomId(`model_${assistant_message.id}`)
131+
);
132+
// using events
133+
let events = await OA.stream_events({
134+
chat_id: chatId,
135+
message_id: assistant_message.id,
136+
});
137+
events.on("data", async (c) => {
138+
/* let string = JSON.parse(c);
139+
if (!string.queue_position) {
140+
await commandType.reply(interaction, {
141+
content: `${string} <a:loading:1051419341914132554>`,
142+
components: [],
143+
});
144+
}*/
145+
});
146+
events.on("end", async (c) => {
147+
let msg = await OA.get_message(chatId, assistant_message.id);
148+
await commandType.reply(interaction, {
149+
content: msg.content,
150+
components: [row],
151+
});
152+
});
153+
} catch (err: any) {
154+
console.log(err);
155+
// get details of the error
156+
await commandType.reply(
157+
interaction,
158+
`There was an error while executing this command! ${err.message}`
159+
);
93160
}
94-
let prompter_message = await OA.post_prompter_message({
95-
chat_id: chatId,
96-
content: message,
97-
parent_id: parentId,
98-
});
99-
100-
let assistant_message = await OA.post_assistant_message({
101-
chat_id: chatId,
102-
model_config_name: model,
103-
parent_id: prompter_message.id,
104-
sampling_parameters: presets[preset],
105-
});
106-
await redis.set(
107-
`chat_${interaction.user.id}`,
108-
`${chatId}_${assistant_message.id}`
161+
} else {
162+
let { assistant_message, error } = await chatFN(
163+
model,
164+
interaction.user,
165+
message
109166
);
110-
167+
if (error) {
168+
await commandType.reply(
169+
interaction,
170+
`There was an error while executing this command! ${error}`
171+
);
172+
}
111173
const row = new ActionRowBuilder().addComponents(
112174
new ButtonBuilder()
113175
.setStyle(ButtonStyle.Secondary)
114-
.setLabel(`👍`)
115-
.setCustomId(`vote_${assistant_message.id}_up`),
116-
new ButtonBuilder()
117-
.setStyle(ButtonStyle.Secondary)
118-
.setLabel(`👎`)
119-
.setCustomId(`vote_${assistant_message.id}_down`)
176+
.setDisabled(false)
177+
.setLabel(
178+
`${model.replaceAll("OpenAssistant/", "").replaceAll("_", "")}`
179+
)
180+
.setCustomId(`model_${interaction.user.id}`)
120181
);
121-
// using events
122-
let events = await OA.stream_events({
123-
chat_id: chatId,
124-
message_id: assistant_message.id,
182+
await commandType.reply(interaction, {
183+
content: assistant_message,
184+
components: [row],
125185
});
126-
events.on("data", async (c) => {});
127-
events.on("end", async (c) => {
128-
let msg = await OA.get_message(chatId, assistant_message.id);
129-
await commandType.reply(interaction, {
130-
content: msg.content,
131-
components: [row],
132-
});
133-
});
134-
} catch (err: any) {
135-
console.log(err);
136-
// get details of the error
137-
await commandType.reply(
138-
interaction,
139-
`There was an error while executing this command! ${err.message}`
140-
);
141186
}
142187
},
143188
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import {
2+
SlashCommandBuilder,
3+
ActionRowBuilder,
4+
StringSelectMenuBuilder,
5+
StringSelectMenuOptionBuilder,
6+
} from "discord.js";
7+
import { createInferenceClient } from "../modules/inference/client.js";
8+
import redis from "../modules/redis.js";
9+
10+
export default {
11+
data: {
12+
customId: "model",
13+
description: "Switch to another model.",
14+
},
15+
async execute(interaction, client, userId) {
16+
if (interaction.user.id != userId)
17+
return interaction.reply({
18+
content: "You don't have permission to do this.",
19+
ephemeral: true,
20+
});
21+
// model selector
22+
let row = new ActionRowBuilder().addComponents(
23+
new StringSelectMenuBuilder()
24+
.setCustomId("modelselect")
25+
.setPlaceholder("Select a model")
26+
.setMinValues(1)
27+
.setMaxValues(1)
28+
.addOptions(
29+
new StringSelectMenuOptionBuilder()
30+
.setLabel("OA_SFT_Llama_30B")
31+
.setDescription("Llama (default)")
32+
.setValue("OA_SFT_Llama_30B"),
33+
new StringSelectMenuOptionBuilder()
34+
.setLabel("oasst-sft-4-pythia-12b")
35+
.setDescription("Pythia")
36+
.setValue("OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")
37+
)
38+
);
39+
await interaction.reply({
40+
content: "Select a model.",
41+
components: [row],
42+
ephemeral: true,
43+
});
44+
},
45+
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import {
2+
SlashCommandBuilder,
3+
ActionRowBuilder,
4+
StringSelectMenuBuilder,
5+
StringSelectMenuOptionBuilder,
6+
} from "discord.js";
7+
import { createInferenceClient } from "../modules/inference/client.js";
8+
import redis from "../modules/redis.js";
9+
10+
export default {
11+
data: {
12+
customId: "modelselect",
13+
description: "Switch to another model.",
14+
},
15+
async execute(interaction, client) {
16+
// get selected value
17+
let model = interaction.values[0];
18+
// set model
19+
await interaction.deferReply({
20+
ephemeral: true,
21+
});
22+
redis.set(`model_${interaction.user.id}`, model);
23+
await interaction.editReply({
24+
content: `Model set to ${model}.`,
25+
ephemeral: true,
26+
});
27+
},
28+
};
+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import { createInferenceClient } from "../modules/inference/client.js";
2+
import { HfInference } from "@huggingface/inference";
3+
const hf = new HfInference(process.env.HUGGINGFACE_TOKEN);
4+
5+
export default async function chat(
6+
model,
7+
user,
8+
message,
9+
chatId?,
10+
parentId?,
11+
presets?,
12+
preset?
13+
) {
14+
if (model.includes("Llama")) {
15+
const OA = await createInferenceClient(user.username, user.id);
16+
if (!chatId) {
17+
let chat = await OA.create_chat();
18+
chatId = chat.id;
19+
}
20+
let prompter_message = await OA.post_prompter_message({
21+
chat_id: chatId,
22+
content: message,
23+
parent_id: parentId,
24+
});
25+
26+
let assistant_message = await OA.post_assistant_message({
27+
chat_id: chatId,
28+
model_config_name: model,
29+
parent_id: prompter_message.id,
30+
sampling_parameters: presets[preset],
31+
});
32+
return { assistant_message, OA };
33+
} else {
34+
let result = await huggingface(
35+
model,
36+
`<|prompter|>${message}<|endoftext|>\n<|assistant|>`
37+
);
38+
if (result.error) {
39+
return { error: result.error };
40+
}
41+
return { assistant_message: result.response };
42+
}
43+
}
44+
45+
export async function huggingface(model, input) {
46+
try {
47+
let oldText;
48+
let loop = true;
49+
while (loop) {
50+
let response = await hf.textGeneration({
51+
model: model,
52+
inputs: input,
53+
});
54+
let answer = response.generated_text.split("<|assistant|>")[1];
55+
if (answer == oldText) {
56+
loop = false;
57+
} else {
58+
if (!oldText) {
59+
oldText = answer;
60+
input += answer;
61+
} else {
62+
oldText += answer;
63+
input += answer;
64+
}
65+
}
66+
}
67+
68+
return { response: oldText };
69+
} catch (err: any) {
70+
console.log(err);
71+
return {
72+
error: err.message,
73+
};
74+
}
75+
}

discord-bots/oa-bot-js/src/modules/open-assistant/interactions/init.ts

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export async function initInteraction(interaction, translation, lang) {
1313
.setFooter({ text: `${getLocaleDisplayName(lang)}` })
1414
.setTitle("Open assistant")
1515
.setDescription(`${translation["conversational"]}`)
16-
.setURL("https://open-assistant.io/?ref=turing")
16+
.setURL("https://open-assistant.io/?ref=discordbot")
1717
.setThumbnail("https://open-assistant.io/images/logos/logo.png");
1818

1919
const row = new ActionRowBuilder().addComponents(
@@ -25,7 +25,7 @@ export async function initInteraction(interaction, translation, lang) {
2525
.setLabel(translation.grab_a_task)
2626
.setCustomId(`oa_tasks_n_${interaction.user.id}`)
2727
.setStyle(ButtonStyle.Primary)
28-
.setDisabled(false),
28+
.setDisabled(true),
2929
new ButtonBuilder()
3030
.setLabel("Change language")
3131
.setCustomId(`oa_lang-btn_n_${interaction.user.id}`)

0 commit comments

Comments
 (0)