Skip to content

Commit 3d28a98

Browse files
author
Vaijanath Rao
committed
adding support for messages.
1 parent 335875c commit 3d28a98

File tree

6 files changed

+92
-3
lines changed

6 files changed

+92
-3
lines changed

pom.xml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
<groupId>de.kherud</groupId>
77
<artifactId>llama</artifactId>
8-
<version>4.0.1</version>
8+
<version>4.0.0</version>
99
<packaging>jar</packaging>
1010

1111
<name>${project.groupId}:${project.artifactId}</name>
@@ -65,7 +65,11 @@
6565
<version>24.1.0</version>
6666
<scope>compile</scope>
6767
</dependency>
68-
68+
<dependency>
69+
<groupId>com.fasterxml.jackson.core</groupId>
70+
<artifactId>jackson-databind</artifactId>
71+
<version>2.16.0</version> <!-- Pinned version; check for a newer jackson-databind release when upgrading -->
72+
</dependency>
6973
</dependencies>
7074

7175
<build>

src/main/cpp/jllama.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,20 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
786786

787787
}
788788

789+
/**
 * JNI entry point for LlamaModel.applyTemplate(String).
 *
 * Parses `jparams` as JSON, runs it through oaicompat_completion_params_parse
 * together with the server context's jinja flag, reasoning format and chat
 * templates, and returns the "prompt" entry of the result as a Java string.
 *
 * @param env     JNI environment of the calling thread
 * @param obj     the LlamaModel instance; its model-pointer field holds the server_context
 * @param jparams request parameters serialized as a JSON string
 * @return the "prompt" string, or nullptr with a pending Java exception on failure
 */
JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
    jlong server_handle = env->GetLongField(obj, f_model_pointer);
    auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

    try {
        std::string c_params = parse_jstring(env, jparams);
        json data = json::parse(c_params);

        json templateData = oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
                                                              ctx_server->params_base.reasoning_format,
                                                              ctx_server->chat_templates.get());
        std::string tok_str = templateData.at("prompt");

        return env->NewStringUTF(tok_str.c_str());
    } catch (const std::exception &e) {
        // A C++ exception must not unwind through the JNI boundary: the JVM cannot
        // handle it and the process would abort. Malformed JSON or a missing
        // "prompt" entry is therefore reported as a Java exception instead.
        // NOTE(review): if this file has a cached project exception class, prefer it here.
        jclass exception_class = env->FindClass("java/lang/IllegalArgumentException");
        if (exception_class != nullptr) {
            env->ThrowNew(exception_class, e.what());
        }
        return nullptr;
    }
}
802+
789803
JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {
790804
jlong server_handle = env->GetLongField(obj, f_model_pointer);
791805
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)

src/main/cpp/jllama.h

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/main/java/de/kherud/llama/InferenceParameters.java

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
package de.kherud.llama;
22

33
import java.util.Collection;
4+
import java.util.List;
45
import java.util.Map;
56

7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.databind.node.ArrayNode;
9+
import com.fasterxml.jackson.databind.node.ObjectNode;
10+
611
import de.kherud.llama.args.MiroStat;
712
import de.kherud.llama.args.Sampler;
813

@@ -12,6 +17,9 @@
1217
* {@link LlamaModel#complete(InferenceParameters)}.
1318
*/
1419
public final class InferenceParameters extends JsonParameters {
20+
21+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); // Reusable ObjectMapper
22+
1523

1624
private static final String PARAM_PROMPT = "prompt";
1725
private static final String PARAM_INPUT_PREFIX = "input_prefix";
@@ -47,6 +55,7 @@ public final class InferenceParameters extends JsonParameters {
4755
private static final String PARAM_STREAM = "stream";
4856
private static final String PARAM_USE_CHAT_TEMPLATE = "use_chat_template";
4957
private static final String PARAM_USE_JINJA = "use_jinja";
58+
private static final String PARAM_MESSAGES = "messages";
5059

5160
public InferenceParameters(String prompt) {
5261
// we always need a prompt
@@ -493,7 +502,41 @@ public InferenceParameters setUseChatTemplate(boolean useChatTemplate) {
493502
return this;
494503
}
495504

496-
505+
/**
506+
* Set the messages for chat-based inference.
507+
* - Allows **only one** system message.
508+
* - Allows **one or more** user/assistant messages.
509+
*/
510+
public InferenceParameters setMessages(String systemMessage, List<Pair<String, String>> messages) {
511+
ArrayNode messagesArray = OBJECT_MAPPER.createArrayNode();
512+
513+
// Add system message (if provided)
514+
if (systemMessage != null && !systemMessage.isEmpty()) {
515+
ObjectNode systemObj = OBJECT_MAPPER.createObjectNode();
516+
systemObj.put("role", "system");
517+
systemObj.put("content", systemMessage);
518+
messagesArray.add(systemObj);
519+
}
520+
521+
// Add user/assistant messages
522+
for (Pair<String, String> message : messages) {
523+
String role = message.getKey();
524+
String content = message.getValue();
525+
526+
if (!role.equals("user") && !role.equals("assistant")) {
527+
throw new IllegalArgumentException("Invalid role: " + role + ". Role must be 'user' or 'assistant'.");
528+
}
529+
530+
ObjectNode messageObj = OBJECT_MAPPER.createObjectNode();
531+
messageObj.put("role", role);
532+
messageObj.put("content", content);
533+
messagesArray.add(messageObj);
534+
}
535+
536+
// Convert ArrayNode to a JSON string and store it in parameters
537+
parameters.put(PARAM_MESSAGES, messagesArray.toString());
538+
return this;
539+
}
497540

498541

499542

src/main/java/de/kherud/llama/LlamaModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,9 @@ public List<Pair<String, Float>> rerank(boolean reRank, String query, String ...
163163
}
164164

165165
public native LlamaOutput rerank(String query, String... documents);
166+
167+
public String applyTemplate(InferenceParameters parameters) {
168+
return applyTemplate(parameters.toString());
169+
}
170+
public native String applyTemplate(String parametersJson);
166171
}

src/test/java/de/kherud/llama/LlamaModelTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,4 +316,20 @@ public void testJsonSchemaToGrammar() {
316316
String actualGrammar = LlamaModel.jsonSchemaToGrammar(schema);
317317
Assert.assertEquals(expectedGrammar, actualGrammar);
318318
}
319+
320+
@Test
321+
public void testTemplate() {
322+
323+
List<Pair<String, String>> userMessages = new ArrayList<>();
324+
userMessages.add(new Pair<>("user", "What is the best book?"));
325+
userMessages.add(new Pair<>("assistant", "It depends on your interests. Do you like fiction or non-fiction?"));
326+
327+
InferenceParameters params = new InferenceParameters("A book recommendation system.")
328+
.setMessages("Book", userMessages)
329+
.setTemperature(0.95f)
330+
.setStopStrings("\"\"\"")
331+
.setNPredict(nPredict)
332+
.setSeed(42);
333+
Assert.assertEquals(model.applyTemplate(params), "<|im_start|>system\nBook<|im_end|>\n<|im_start|>user\nWhat is the best book?<|im_end|>\n<|im_start|>assistant\nIt depends on your interests. Do you like fiction or non-fiction?<|im_end|>\n<|im_start|>assistant\n");
334+
}
319335
}

0 commit comments

Comments
 (0)