Skip to content

Commit 6e95f61

Browse files
committed
reformat c++ code
1 parent 3d28a98 commit 6e95f61

File tree

1 file changed

+79
-80
lines changed

1 file changed

+79
-80
lines changed

src/main/cpp/jllama.cpp

Lines changed: 79 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ char **parse_string_array(JNIEnv *env, const jobjectArray string_array, const js
112112
return result;
113113
}
114114

115-
std::vector<std::string> parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array, const jsize length) {
115+
std::vector<std::string> parse_string_array_for_rerank(JNIEnv *env, const jobjectArray string_array,
116+
const jsize length) {
116117
std::vector<std::string> result;
117118
result.reserve(length); // Reserve memory for efficiency
118119

119120
for (jsize i = 0; i < length; i++) {
120121
jstring javaString = static_cast<jstring>(env->GetObjectArrayElement(string_array, i));
121-
if (javaString == nullptr) continue;
122+
if (javaString == nullptr)
123+
continue;
122124

123125
const char *cString = env->GetStringUTFChars(javaString, nullptr);
124126
if (cString != nullptr) {
@@ -259,7 +261,6 @@ JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *reserved) {
259261
cc_integer = env->GetMethodID(c_integer, "<init>", "(I)V");
260262
cc_float = env->GetMethodID(c_float, "<init>", "(F)V");
261263

262-
263264
if (!(cc_output && cc_hash_map && cc_integer && cc_float)) {
264265
goto error;
265266
}
@@ -663,12 +664,11 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
663664
env->ThrowNew(c_llama_error, response.c_str());
664665
return nullptr;
665666
}
666-
667+
667668
if (result->is_stop()) {
668669
ctx_server->queue_results.remove_waiting_task_id(id_task);
669670
}
670671

671-
672672
const auto out_res = result->to_json();
673673

674674
// Extract "embedding" as a vector of vectors (2D array)
@@ -704,100 +704,99 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
704704
return j_embedding;
705705
}
706706

707-
JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt, jobjectArray documents) {
707+
JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank(JNIEnv *env, jobject obj, jstring jprompt,
708+
jobjectArray documents) {
708709
jlong server_handle = env->GetLongField(obj, f_model_pointer);
709710
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
710711

711-
if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
712-
env->ThrowNew(c_llama_error,
712+
if (!ctx_server->params_base.reranking || ctx_server->params_base.embedding) {
713+
env->ThrowNew(c_llama_error,
713714
"This server does not support reranking. Start it with `--reranking` and without `--embedding`");
714-
return nullptr;
715+
return nullptr;
715716
}
716-
717717

718718
const std::string prompt = parse_jstring(env, jprompt);
719719

720-
721-
722720
const auto tokenized_query = tokenize_mixed(ctx_server->vocab, prompt, true, true);
723-
721+
724722
json responses = json::array();
725723
bool error = false;
726-
727-
std::vector<server_task> tasks;
728-
const jsize argc = env->GetArrayLength(documents);
729-
std::vector<std::string> documentsArray = parse_string_array_for_rerank(env, documents, argc);
730-
731-
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true);
732-
733-
tasks.reserve(tokenized_docs.size());
734-
for (size_t i = 0; i < tokenized_docs.size(); i++) {
735-
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
736-
task.id = ctx_server->queue_tasks.get_new_id();
737-
task.index = i;
738-
task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
739-
tasks.push_back(task);
740-
}
741-
ctx_server->queue_results.add_waiting_tasks(tasks);
742-
ctx_server->queue_tasks.post(tasks);
743-
744-
// get the result
745-
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
746-
std::vector<server_task_result_ptr> results(task_ids.size());
747-
748-
// Create a new HashMap instance
749-
jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
750-
if (o_probabilities == nullptr) {
751-
env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
752-
return nullptr;
753-
}
754-
755-
for (int i = 0; i < (int)task_ids.size(); i++) {
756-
server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
757-
if (result->is_error()) {
758-
std::string response = result->to_json()["message"].get<std::string>();
759-
for (const int id_task : task_ids) {
760-
ctx_server->queue_results.remove_waiting_task_id(id_task);
761-
}
762-
env->ThrowNew(c_llama_error, response.c_str());
763-
return nullptr;
764-
}
765-
766-
const auto out_res = result->to_json();
767-
768-
if (result->is_stop()) {
769-
for (const int id_task : task_ids) {
770-
ctx_server->queue_results.remove_waiting_task_id(id_task);
771-
}
772-
}
773-
774-
int index = out_res["index"].get<int>();
775-
float score = out_res["score"].get<float>();
776-
std::string tok_str = documentsArray[index];
777-
jstring jtok_str = env->NewStringUTF(tok_str.c_str());
778-
779-
jobject jprob = env->NewObject(c_float, cc_float, score);
780-
env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
781-
env->DeleteLocalRef(jtok_str);
782-
env->DeleteLocalRef(jprob);
783-
}
724+
725+
std::vector<server_task> tasks;
726+
const jsize argc = env->GetArrayLength(documents);
727+
std::vector<std::string> documentsArray = parse_string_array_for_rerank(env, documents, argc);
728+
729+
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server->vocab, documentsArray, true, true);
730+
731+
tasks.reserve(tokenized_docs.size());
732+
for (size_t i = 0; i < tokenized_docs.size(); i++) {
733+
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
734+
task.id = ctx_server->queue_tasks.get_new_id();
735+
task.index = i;
736+
task.prompt_tokens = format_rerank(ctx_server->vocab, tokenized_query, tokenized_docs[i]);
737+
tasks.push_back(task);
738+
}
739+
ctx_server->queue_results.add_waiting_tasks(tasks);
740+
ctx_server->queue_tasks.post(tasks);
741+
742+
// get the result
743+
std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
744+
std::vector<server_task_result_ptr> results(task_ids.size());
745+
746+
// Create a new HashMap instance
747+
jobject o_probabilities = env->NewObject(c_hash_map, cc_hash_map);
748+
if (o_probabilities == nullptr) {
749+
env->ThrowNew(c_llama_error, "Failed to create HashMap object.");
750+
return nullptr;
751+
}
752+
753+
for (int i = 0; i < (int)task_ids.size(); i++) {
754+
server_task_result_ptr result = ctx_server->queue_results.recv(task_ids);
755+
if (result->is_error()) {
756+
std::string response = result->to_json()["message"].get<std::string>();
757+
for (const int id_task : task_ids) {
758+
ctx_server->queue_results.remove_waiting_task_id(id_task);
759+
}
760+
env->ThrowNew(c_llama_error, response.c_str());
761+
return nullptr;
762+
}
763+
764+
const auto out_res = result->to_json();
765+
766+
if (result->is_stop()) {
767+
for (const int id_task : task_ids) {
768+
ctx_server->queue_results.remove_waiting_task_id(id_task);
769+
}
770+
}
771+
772+
int index = out_res["index"].get<int>();
773+
float score = out_res["score"].get<float>();
774+
std::string tok_str = documentsArray[index];
775+
jstring jtok_str = env->NewStringUTF(tok_str.c_str());
776+
777+
jobject jprob = env->NewObject(c_float, cc_float, score);
778+
env->CallObjectMethod(o_probabilities, m_map_put, jtok_str, jprob);
779+
env->DeleteLocalRef(jtok_str);
780+
env->DeleteLocalRef(jprob);
781+
}
784782
jbyteArray jbytes = parse_jbytes(env, prompt);
785-
return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
786-
783+
return env->NewObject(c_output, cc_output, jbytes, o_probabilities, true);
787784
}
788785

789-
JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams){
790-
jlong server_handle = env->GetLongField(obj, f_model_pointer);
786+
JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate(JNIEnv *env, jobject obj, jstring jparams) {
787+
jlong server_handle = env->GetLongField(obj, f_model_pointer);
791788
auto *ctx_server = reinterpret_cast<server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
792789

793790
std::string c_params = parse_jstring(env, jparams);
794791
json data = json::parse(c_params);
795-
796-
json templateData = oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja, ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
792+
793+
json templateData =
794+
oaicompat_completion_params_parse(data, ctx_server->params_base.use_jinja,
795+
ctx_server->params_base.reasoning_format, ctx_server->chat_templates.get());
797796
std::string tok_str = templateData.at("prompt");
798-
jstring jtok_str = env->NewStringUTF(tok_str.c_str());
799-
800-
return jtok_str;
797+
jstring jtok_str = env->NewStringUTF(tok_str.c_str());
798+
799+
return jtok_str;
801800
}
802801

803802
JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode(JNIEnv *env, jobject obj, jstring jprompt) {

0 commit comments

Comments
 (0)