4
4
ButtonStyle ,
5
5
ButtonBuilder ,
6
6
} from "discord.js" ;
7
- import { createInferenceClient } from "../modules/inference/client.js" ;
8
7
import redis from "../modules/redis.js" ;
8
+ import chatFN from "../modules/chat.js" ;
9
9
10
10
export default {
11
11
disablePing : null ,
@@ -23,10 +23,16 @@ export default {
23
23
. setName ( "model" )
24
24
. setDescription ( "The model you want to use for the AI." )
25
25
. setRequired ( false )
26
- . addChoices ( {
27
- name : "OA_SFT_Llama_30B" ,
28
- value : "OA_SFT_Llama_30B" ,
29
- } )
26
+ . addChoices (
27
+ {
28
+ name : "OA_SFT_Llama_30B" ,
29
+ value : "OA_SFT_Llama_30B" ,
30
+ } ,
31
+ {
32
+ name : "oasst-sft-4-pythia-12b" ,
33
+ value : "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5" ,
34
+ }
35
+ )
30
36
)
31
37
. addStringOption ( ( option ) =>
32
38
option
@@ -73,71 +79,110 @@ export default {
73
79
model = interaction . options . getString ( "model" ) ;
74
80
preset = interaction . options . getString ( "preset" ) ;
75
81
}
76
- if ( ! model )
77
- model = process . env . OPEN_ASSISTANT_DEFAULT_MODEL || "OA_SFT_Llama_30B" ;
82
+ if ( ! model ) {
83
+ let userModel = await redis . get ( `model_${ interaction . user . id } ` ) ;
84
+ if ( userModel ) {
85
+ model = userModel ;
86
+ } else {
87
+ model = process . env . OPEN_ASSISTANT_DEFAULT_MODEL || "OA_SFT_Llama_30B" ;
88
+ redis . set ( `model_${ interaction . user . id } ` , model ) ;
89
+ }
90
+ } else {
91
+ redis . set ( `model_${ interaction . user . id } ` , model ) ;
92
+ }
78
93
if ( ! preset ) preset = "k50" ;
79
94
// sleep for 30s
80
95
81
- const OA = await createInferenceClient (
82
- interaction . user . username ,
83
- interaction . user . id
84
- ) ;
96
+ if ( model . includes ( "Llama" ) ) {
97
+ try {
98
+ let chat = await redis . get ( `chat_${ interaction . user . id } ` ) ;
99
+ let chatId = chat ? chat . split ( "_" ) [ 0 ] : null ;
100
+ let parentId = chat ? chat . split ( "_" ) [ 1 ] : null ;
101
+ let { assistant_message, OA } = await chatFN (
102
+ model ,
103
+ interaction . user ,
104
+ message ,
105
+ chatId ,
106
+ parentId ,
107
+ presets ,
108
+ preset
109
+ ) ;
110
+ await redis . set (
111
+ `chat_${ interaction . user . id } ` ,
112
+ `${ chatId } _${ assistant_message . id } `
113
+ ) ;
85
114
86
- try {
87
- let chat = await redis . get ( `chat_${ interaction . user . id } ` ) ;
88
- let chatId = chat ? chat . split ( "_" ) [ 0 ] : null ;
89
- let parentId = chat ? chat . split ( "_" ) [ 1 ] : null ;
90
- if ( ! chatId ) {
91
- let chat = await OA . create_chat ( ) ;
92
- chatId = chat . id ;
115
+ const row = new ActionRowBuilder ( ) . addComponents (
116
+ new ButtonBuilder ( )
117
+ . setStyle ( ButtonStyle . Secondary )
118
+ . setLabel ( `👍` )
119
+ . setCustomId ( `vote_${ assistant_message . id } _up` ) ,
120
+ new ButtonBuilder ( )
121
+ . setStyle ( ButtonStyle . Secondary )
122
+ . setLabel ( `👎` )
123
+ . setCustomId ( `vote_${ assistant_message . id } _down` ) ,
124
+ new ButtonBuilder ( )
125
+ . setStyle ( ButtonStyle . Secondary )
126
+ . setDisabled ( false )
127
+ . setLabel (
128
+ `${ model . replaceAll ( "OpenAssistant/" , "" ) . replaceAll ( "_" , "" ) } `
129
+ )
130
+ . setCustomId ( `model_${ assistant_message . id } ` )
131
+ ) ;
132
+ // using events
133
+ let events = await OA . stream_events ( {
134
+ chat_id : chatId ,
135
+ message_id : assistant_message . id ,
136
+ } ) ;
137
+ events . on ( "data" , async ( c ) => {
138
+ /* let string = JSON.parse(c);
139
+ if (!string.queue_position) {
140
+ await commandType.reply(interaction, {
141
+ content: `${string} <a:loading:1051419341914132554>`,
142
+ components: [],
143
+ });
144
+ }*/
145
+ } ) ;
146
+ events . on ( "end" , async ( c ) => {
147
+ let msg = await OA . get_message ( chatId , assistant_message . id ) ;
148
+ await commandType . reply ( interaction , {
149
+ content : msg . content ,
150
+ components : [ row ] ,
151
+ } ) ;
152
+ } ) ;
153
+ } catch ( err : any ) {
154
+ console . log ( err ) ;
155
+ // get details of the error
156
+ await commandType . reply (
157
+ interaction ,
158
+ `There was an error while executing this command! ${ err . message } `
159
+ ) ;
93
160
}
94
- let prompter_message = await OA . post_prompter_message ( {
95
- chat_id : chatId ,
96
- content : message ,
97
- parent_id : parentId ,
98
- } ) ;
99
-
100
- let assistant_message = await OA . post_assistant_message ( {
101
- chat_id : chatId ,
102
- model_config_name : model ,
103
- parent_id : prompter_message . id ,
104
- sampling_parameters : presets [ preset ] ,
105
- } ) ;
106
- await redis . set (
107
- `chat_${ interaction . user . id } ` ,
108
- `${ chatId } _${ assistant_message . id } `
161
+ } else {
162
+ let { assistant_message, error } = await chatFN (
163
+ model ,
164
+ interaction . user ,
165
+ message
109
166
) ;
110
-
167
+ if ( error ) {
168
+ await commandType . reply (
169
+ interaction ,
170
+ `There was an error while executing this command! ${ error } `
171
+ ) ;
172
+ }
111
173
const row = new ActionRowBuilder ( ) . addComponents (
112
174
new ButtonBuilder ( )
113
175
. setStyle ( ButtonStyle . Secondary )
114
- . setLabel ( `👍` )
115
- . setCustomId ( `vote_${ assistant_message . id } _up` ) ,
116
- new ButtonBuilder ( )
117
- . setStyle ( ButtonStyle . Secondary )
118
- . setLabel ( `👎` )
119
- . setCustomId ( `vote_${ assistant_message . id } _down` )
176
+ . setDisabled ( false )
177
+ . setLabel (
178
+ `${ model . replaceAll ( "OpenAssistant/" , "" ) . replaceAll ( "_" , "" ) } `
179
+ )
180
+ . setCustomId ( `model_${ interaction . user . id } ` )
120
181
) ;
121
- // using events
122
- let events = await OA . stream_events ( {
123
- chat_id : chatId ,
124
- message_id : assistant_message . id ,
182
+ await commandType . reply ( interaction , {
183
+ content : assistant_message ,
184
+ components : [ row ] ,
125
185
} ) ;
126
- events . on ( "data" , async ( c ) => { } ) ;
127
- events . on ( "end" , async ( c ) => {
128
- let msg = await OA . get_message ( chatId , assistant_message . id ) ;
129
- await commandType . reply ( interaction , {
130
- content : msg . content ,
131
- components : [ row ] ,
132
- } ) ;
133
- } ) ;
134
- } catch ( err : any ) {
135
- console . log ( err ) ;
136
- // get details of the error
137
- await commandType . reply (
138
- interaction ,
139
- `There was an error while executing this command! ${ err . message } `
140
- ) ;
141
186
}
142
187
} ,
143
188
} ;
0 commit comments