@@ -112,28 +112,6 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js
112
112
return result;
113
113
}
114
114
115
- std::vector<std::string> parse_string_array_for_rerank (JNIEnv *env, const jobjectArray string_array,
116
- const jsize length) {
117
- std::vector<std::string> result;
118
- result.reserve (length); // Reserve memory for efficiency
119
-
120
- for (jsize i = 0 ; i < length; i++) {
121
- jstring javaString = static_cast <jstring>(env->GetObjectArrayElement (string_array, i));
122
- if (javaString == nullptr )
123
- continue ;
124
-
125
- const char *cString = env->GetStringUTFChars (javaString, nullptr );
126
- if (cString != nullptr ) {
127
- result.emplace_back (cString); // Add to vector
128
- env->ReleaseStringUTFChars (javaString, cString);
129
- }
130
-
131
- env->DeleteLocalRef (javaString); // Avoid memory leaks
132
- }
133
-
134
- return result;
135
- }
136
-
137
115
void free_string_array (char **array, jsize length) {
138
116
if (array != nullptr ) {
139
117
for (jsize i = 0 ; i < length; i++) {
@@ -720,17 +698,18 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
720
698
const auto tokenized_query = tokenize_mixed (ctx_server->vocab , prompt, true , true );
721
699
722
700
json responses = json::array ();
723
- bool error = false ;
724
701
725
702
std::vector<server_task> tasks;
726
- const jsize argc = env->GetArrayLength (documents);
727
- std::vector<std::string> documentsArray = parse_string_array_for_rerank (env, documents, argc);
703
+ const jsize amount_documents = env->GetArrayLength (documents);
704
+ auto *document_array = parse_string_array (env, documents, amount_documents);
705
+ auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
706
+ free_string_array (document_array, amount_documents);
728
707
729
- std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server->vocab , documentsArray , true , true );
708
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server->vocab , document_vector , true , true );
730
709
731
710
tasks.reserve (tokenized_docs.size ());
732
- for (size_t i = 0 ; i < tokenized_docs.size (); i++) {
733
- server_task task = server_task (SERVER_TASK_TYPE_RERANK);
711
+ for (int i = 0 ; i < tokenized_docs.size (); i++) {
712
+ auto task = server_task (SERVER_TASK_TYPE_RERANK);
734
713
task.id = ctx_server->queue_tasks .get_new_id ();
735
714
task.index = i;
736
715
task.prompt_tokens = format_rerank (ctx_server->vocab , tokenized_query, tokenized_docs[i]);
@@ -753,7 +732,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
753
732
for (int i = 0 ; i < (int )task_ids.size (); i++) {
754
733
server_task_result_ptr result = ctx_server->queue_results .recv (task_ids);
755
734
if (result->is_error ()) {
756
- std::string response = result->to_json ()[" message" ].get <std::string>();
735
+ auto response = result->to_json ()[" message" ].get <std::string>();
757
736
for (const int id_task : task_ids) {
758
737
ctx_server->queue_results .remove_waiting_task_id (id_task);
759
738
}
@@ -771,7 +750,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
771
750
772
751
int index = out_res[" index" ].get <int >();
773
752
float score = out_res[" score" ].get <float >();
774
- std::string tok_str = documentsArray [index];
753
+ std::string tok_str = document_vector [index];
775
754
jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
776
755
777
756
jobject jprob = env->NewObject (c_float, cc_float, score);
0 commit comments