Skip to content

Commit 986bddf

Browse files
committed
re-use parse_string_array for re-ranking
1 parent 6e95f61 commit 986bddf

File tree

1 file changed

+9
-30
lines changed

1 file changed

+9
-30
lines changed

src/main/cpp/jllama.cpp

Lines changed: 9 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -112,28 +112,6 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js
112112
return result;
113113
}
114114

115-
std::vector<std::string> parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array,
116-
const jsize length) {
117-
std::vector<std::string> result;
118-
result.reserve(length); // Reserve memory for efficiency
119-
120-
for (jsize i = 0; i < length; i++) {
121-
jstring javaString = static_cast<jstring>(env->GetObjectArrayElement(string_array, i));
122-
if (javaString == nullptr)
123-
continue;
124-
125-
const char *cString = env->GetStringUTFChars(javaString, nullptr);
126-
if (cString != nullptr) {
127-
result.emplace_back(cString); // Add to vector
128-
env->ReleaseStringUTFChars(javaString, cString);
129-
}
130-
131-
env->DeleteLocalRef(javaString); // Avoid memory leaks
132-
}
133-
134-
return result;
135-
}
136-
137115
void free_string_array(char **array, jsize length) {
138116
if (array != nullptr) {
139117
for (jsize i = 0; i < length; i++) {
@@ -720,17 +698,18 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
720698
const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);
721699

722700
json responses = json::array();
723-
bool error = false;
724701

725702
std::vector<server_task> tasks;
726-
const jsize argc = env->GetArrayLength(documents);
727-
std::vector<std::string> documentsArray = parse_string_array_for_rerank(env, documents, argc);
703+
const jsize amount_documents = env->GetArrayLength(documents);
704+
auto *document_array = parse_string_array(env, documents, amount_documents);
705+
auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
706+
free_string_array(document_array, amount_documents);
728707

729-
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true);
708+
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, document_vector, true, true);
730709

731710
tasks.reserve(tokenized_docs.size());
732-
for (size_t i = 0; i < tokenized_docs.size(); i++) {
733-
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
711+
for (int i = 0; i < tokenized_docs.size(); i++) {
712+
auto task = server_task(SERVER_TASK_TYPE_RERANK);
734713
task.id = ctx_server->queue_tasks.get_new_id();
735714
task.index = i;
736715
task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
@@ -753,7 +732,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
753732
for (int i = 0; i < (int)task_ids.size(); i++) {
754733
server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
755734
if (result->is_error()) {
756-
std::string response = result->to_json()["message"].get<std::string>();
735+
auto response = result->to_json()["message"].get<std::string>();
757736
for (const int id_task : task_ids) {
758737
ctx_server->queue_results.remove_waiting_task_id(id_task);
759738
}
@@ -771,7 +750,7 @@ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jo
771750

772751
int index = out_res["index"].get<int>();
773752
float score = out_res["score"].get<float>();
774-
std::string tok_str = documentsArray[index];
753+
std::string tok_str = document_vector[index];
775754
jstring jtok_str = env->NewStringUTF(tok_str.c_str());
776755

777756
jobject jprob = env->NewObject(c_float, cc_float, score);

0 commit comments

Comments
 (0)