@@ -112,13 +112,15 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js
112
112
return result;
113
113
}
114
114
115
- std::vector<std::string> parse_string_array_for_rerank (JNIEnv *env, const jobjectArray string_array, const jsize length) {
115
+ std::vector<std::string> parse_string_array_for_rerank (JNIEnv *env, const jobjectArray string_array,
116
+ const jsize length) {
116
117
std::vector<std::string> result;
117
118
result.reserve (length); // Reserve memory for efficiency
118
119
119
120
for (jsize i = 0 ; i < length; i++) {
120
121
jstring javaString = static_cast <jstring>(env->GetObjectArrayElement (string_array, i));
121
- if (javaString == nullptr ) continue ;
122
+ if (javaString == nullptr )
123
+ continue ;
122
124
123
125
const char *cString = env->GetStringUTFChars (javaString, nullptr );
124
126
if (cString != nullptr ) {
@@ -259,7 +261,6 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
259
261
cc_integer = env->GetMethodID (c_integer, " <init>" , " (I)V" );
260
262
cc_float = env->GetMethodID (c_float, " <init>" , " (F)V" );
261
263
262
-
263
264
if (!(cc_output && cc_hash_map && cc_integer && cc_float)) {
264
265
goto error;
265
266
}
@@ -663,12 +664,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
663
664
env->ThrowNew (c_llama_error, response.c_str ());
664
665
return nullptr ;
665
666
}
666
-
667
+
667
668
if (result->is_stop ()) {
668
669
ctx_server->queue_results .remove_waiting_task_id (id_task);
669
670
}
670
671
671
-
672
672
const auto out_res = result->to_json ();
673
673
674
674
// Extract "embedding" as a vector of vectors (2D array)
@@ -704,100 +704,99 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
704
704
return j_embedding;
705
705
}
706
706
707
- JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank (JNIEnv *env, jobject obj, jstring jprompt, jobjectArray documents) {
707
+ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank (JNIEnv *env, jobject obj, jstring jprompt,
708
+ jobjectArray documents) {
708
709
jlong server_handle = env->GetLongField (obj, f_model_pointer);
709
710
auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
710
711
711
- if (!ctx_server->params_base .reranking || ctx_server->params_base .embedding ) {
712
- env->ThrowNew (c_llama_error,
712
+ if (!ctx_server->params_base .reranking || ctx_server->params_base .embedding ) {
713
+ env->ThrowNew (c_llama_error,
713
714
" This server does not support reranking. Start it with `--reranking` and without `--embedding`" );
714
- return nullptr ;
715
+ return nullptr ;
715
716
}
716
-
717
717
718
718
const std::string prompt = parse_jstring (env, jprompt);
719
719
720
-
721
-
722
720
const auto tokenized_query = tokenize_mixed (ctx_server->vocab , prompt, true , true );
723
-
721
+
724
722
json responses = json::array ();
725
723
bool error = false ;
726
-
727
- std::vector<server_task> tasks;
728
- const jsize argc = env->GetArrayLength (documents);
729
- std::vector<std::string> documentsArray = parse_string_array_for_rerank (env, documents, argc);
730
-
731
- std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server->vocab , documentsArray, true , true );
732
-
733
- tasks.reserve (tokenized_docs.size ());
734
- for (size_t i = 0 ; i < tokenized_docs.size (); i++) {
735
- server_task task = server_task (SERVER_TASK_TYPE_RERANK);
736
- task. id = ctx_server->queue_tasks .get_new_id ();
737
- task. index = i;
738
- task.prompt_tokens = format_rerank (ctx_server->vocab , tokenized_query, tokenized_docs[i]);
739
- tasks.push_back (task);
740
- }
741
- ctx_server->queue_results .add_waiting_tasks (tasks);
742
- ctx_server->queue_tasks .post (tasks);
743
-
744
- // get the result
745
- std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
746
- std::vector<server_task_result_ptr> results (task_ids.size ());
747
-
748
- // Create a new HashMap instance
749
- jobject o_probabilities = env->NewObject (c_hash_map, cc_hash_map);
750
- if (o_probabilities == nullptr ) {
751
- env->ThrowNew (c_llama_error, " Failed to create HashMap object." );
752
- return nullptr ;
753
- }
754
-
755
- for (int i = 0 ; i < (int )task_ids.size (); i++) {
756
- server_task_result_ptr result = ctx_server->queue_results .recv (task_ids);
757
- if (result->is_error ()) {
758
- std::string response = result->to_json ()[" message" ].get <std::string>();
759
- for (const int id_task : task_ids) {
760
- ctx_server->queue_results .remove_waiting_task_id (id_task);
761
- }
762
- env->ThrowNew (c_llama_error, response.c_str ());
763
- return nullptr ;
764
- }
765
-
766
- const auto out_res = result->to_json ();
767
-
768
- if (result->is_stop ()) {
769
- for (const int id_task : task_ids) {
770
- ctx_server->queue_results .remove_waiting_task_id (id_task);
771
- }
772
- }
773
-
774
- int index = out_res[" index" ].get <int >();
775
- float score = out_res[" score" ].get <float >();
776
- std::string tok_str = documentsArray[index];
777
- jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
778
-
779
- jobject jprob = env->NewObject (c_float, cc_float, score);
780
- env->CallObjectMethod (o_probabilities, m_map_put, jtok_str, jprob);
781
- env->DeleteLocalRef (jtok_str);
782
- env->DeleteLocalRef (jprob);
783
- }
724
+
725
+ std::vector<server_task> tasks;
726
+ const jsize argc = env->GetArrayLength (documents);
727
+ std::vector<std::string> documentsArray = parse_string_array_for_rerank (env, documents, argc);
728
+
729
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server->vocab , documentsArray, true , true );
730
+
731
+ tasks.reserve (tokenized_docs.size ());
732
+ for (size_t i = 0 ; i < tokenized_docs.size (); i++) {
733
+ server_task task = server_task (SERVER_TASK_TYPE_RERANK);
734
+ task. id = ctx_server->queue_tasks .get_new_id ();
735
+ task. index = i;
736
+ task.prompt_tokens = format_rerank (ctx_server->vocab , tokenized_query, tokenized_docs[i]);
737
+ tasks.push_back (task);
738
+ }
739
+ ctx_server->queue_results .add_waiting_tasks (tasks);
740
+ ctx_server->queue_tasks .post (tasks);
741
+
742
+ // get the result
743
+ std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
744
+ std::vector<server_task_result_ptr> results (task_ids.size ());
745
+
746
+ // Create a new HashMap instance
747
+ jobject o_probabilities = env->NewObject (c_hash_map, cc_hash_map);
748
+ if (o_probabilities == nullptr ) {
749
+ env->ThrowNew (c_llama_error, " Failed to create HashMap object." );
750
+ return nullptr ;
751
+ }
752
+
753
+ for (int i = 0 ; i < (int )task_ids.size (); i++) {
754
+ server_task_result_ptr result = ctx_server->queue_results .recv (task_ids);
755
+ if (result->is_error ()) {
756
+ std::string response = result->to_json ()[" message" ].get <std::string>();
757
+ for (const int id_task : task_ids) {
758
+ ctx_server->queue_results .remove_waiting_task_id (id_task);
759
+ }
760
+ env->ThrowNew (c_llama_error, response.c_str ());
761
+ return nullptr ;
762
+ }
763
+
764
+ const auto out_res = result->to_json ();
765
+
766
+ if (result->is_stop ()) {
767
+ for (const int id_task : task_ids) {
768
+ ctx_server->queue_results .remove_waiting_task_id (id_task);
769
+ }
770
+ }
771
+
772
+ int index = out_res[" index" ].get <int >();
773
+ float score = out_res[" score" ].get <float >();
774
+ std::string tok_str = documentsArray[index];
775
+ jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
776
+
777
+ jobject jprob = env->NewObject (c_float, cc_float, score);
778
+ env->CallObjectMethod (o_probabilities, m_map_put, jtok_str, jprob);
779
+ env->DeleteLocalRef (jtok_str);
780
+ env->DeleteLocalRef (jprob);
781
+ }
784
782
jbyteArray jbytes = parse_jbytes (env, prompt);
785
- return env->NewObject (c_output, cc_output, jbytes, o_probabilities, true );
786
-
783
+ return env->NewObject (c_output, cc_output, jbytes, o_probabilities, true );
787
784
}
788
785
789
- JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate (JNIEnv *env, jobject obj, jstring jparams){
790
- jlong server_handle = env->GetLongField (obj, f_model_pointer);
786
+ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate (JNIEnv *env, jobject obj, jstring jparams) {
787
+ jlong server_handle = env->GetLongField (obj, f_model_pointer);
791
788
auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
792
789
793
790
std::string c_params = parse_jstring (env, jparams);
794
791
json data = json::parse (c_params);
795
-
796
- json templateData = oaicompat_completion_params_parse (data, ctx_server->params_base .use_jinja , ctx_server->params_base .reasoning_format , ctx_server->chat_templates .get ());
792
+
793
+ json templateData =
794
+ oaicompat_completion_params_parse (data, ctx_server->params_base .use_jinja ,
795
+ ctx_server->params_base .reasoning_format , ctx_server->chat_templates .get ());
797
796
std::string tok_str = templateData.at (" prompt" );
798
- jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
799
-
800
- return jtok_str;
797
+ jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
798
+
799
+ return jtok_str;
801
800
}
802
801
803
802
JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode (JNIEnv *env, jobject obj, jstring jprompt) {
0 commit comments