-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathrag_demo_pro.py
1978 lines (1684 loc) · 72.7 KB
/
rag_demo_pro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import gradio as gr
from pdfminer.high_level import extract_text_to_fp
from sentence_transformers import SentenceTransformer
# 导入交叉编码器
from sentence_transformers import CrossEncoder
import chromadb
from chromadb.config import Settings
import requests
import json
from io import StringIO
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import socket
import webbrowser
import logging
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import time
from datetime import datetime
import hashlib
import re
from dotenv import load_dotenv
# 导入BM25算法库
from rank_bm25 import BM25Okapi
import numpy as np
import jieba
import threading
from functools import lru_cache
# 加载环境变量
load_dotenv()
SERPAPI_KEY = os.getenv("SERPAPI_KEY") # 在.env中设置 SERPAPI_KEY
SEARCH_ENGINE = "google" # 可根据需要改为其他搜索引擎
# 新增:重排序方法配置(交叉编码器或LLM)
RERANK_METHOD = os.getenv("RERANK_METHOD", "cross_encoder") # "cross_encoder" 或 "llm"
# 新增:SiliconFlow API配置
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "sk-lnflogwkgcgchrztauesjderjgjqmwldwtxwkkfwzcnshbgf")
SILICONFLOW_API_URL = os.getenv("SILICONFLOW_API_URL", "https://api.siliconflow.cn/v1/chat/completions")
# 在文件开头添加超时设置
import requests
requests.adapters.DEFAULT_RETRIES = 3 # 增加重试次数
# 在文件开头添加环境变量设置
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # 禁用oneDNN优化
# 在文件最开头添加代理配置
import os
os.environ['NO_PROXY'] = 'localhost,127.0.0.1' # 新增代理绕过设置
# 初始化组件
EMBED_MODEL = SentenceTransformer('all-MiniLM-L6-v2')
CHROMA_CLIENT = chromadb.PersistentClient(
path="./chroma_db",
settings=chromadb.Settings(anonymized_telemetry=False)
)
COLLECTION = CHROMA_CLIENT.get_or_create_collection("rag_docs")
# 新增:交叉编码器初始化(延迟加载)
cross_encoder = None
cross_encoder_lock = threading.Lock()
def get_cross_encoder():
"""延迟加载交叉编码器模型"""
global cross_encoder
if cross_encoder is None:
with cross_encoder_lock:
if cross_encoder is None:
try:
# 使用多语言交叉编码器,更适合中文
cross_encoder = CrossEncoder('sentence-transformers/distiluse-base-multilingual-cased-v2')
logging.info("交叉编码器加载成功")
except Exception as e:
logging.error(f"加载交叉编码器失败: {str(e)}")
# 设置为None,下次调用会重试
cross_encoder = None
return cross_encoder
# 新增:BM25索引管理
def recursive_retrieval(initial_query, max_iterations=3, enable_web_search=False, model_choice="ollama"):
"""
实现递归检索与迭代查询功能
通过分析当前查询结果,确定是否需要进一步查询
Args:
initial_query: 初始查询
max_iterations: 最大迭代次数
enable_web_search: 是否启用网络搜索
model_choice: 使用的模型选择("ollama"或"siliconflow")
Returns:
包含所有检索内容的列表
"""
query = initial_query
all_contexts = []
all_doc_ids = []
all_metadata = []
for i in range(max_iterations):
logging.info(f"递归检索迭代 {i+1}/{max_iterations},当前查询: {query}")
# 如果启用了联网搜索,先进行网络搜索
web_results = []
if enable_web_search and check_serpapi_key():
try:
web_results = update_web_results(query)
except Exception as e:
logging.error(f"网络搜索错误: {str(e)}")
# 生成查询嵌入
query_embedding = EMBED_MODEL.encode([query]).tolist()
# 语义向量检索
try:
semantic_results = COLLECTION.query(
query_embeddings=query_embedding,
n_results=10,
include=['documents', 'metadatas']
)
except Exception as e:
logging.error(f"向量检索错误: {str(e)}")
semantic_results = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]}
# BM25关键词检索
bm25_results = BM25_MANAGER.search(query, top_k=10)
# 混合检索结果
hybrid_results = hybrid_merge(semantic_results, bm25_results, alpha=0.7)
# 提取结果
doc_ids = []
docs = []
metadata_list = []
if hybrid_results:
for doc_id, result_data in hybrid_results[:10]:
doc_ids.append(doc_id)
docs.append(result_data['content'])
metadata_list.append(result_data['metadata'])
# 重排序结果
if docs:
try:
reranked_results = rerank_results(query, docs, doc_ids, metadata_list, top_k=5)
except Exception as e:
logging.error(f"重排序错误: {str(e)}")
reranked_results = [(doc_id, {'content': doc, 'metadata': meta, 'score': 1.0})
for doc_id, doc, meta in zip(doc_ids, docs, metadata_list)]
else:
reranked_results = []
# 收集当前迭代的结果
current_contexts = []
for doc_id, result_data in reranked_results:
doc = result_data['content']
metadata = result_data['metadata']
# 添加到总结果集
if doc_id not in all_doc_ids: # 避免重复
all_doc_ids.append(doc_id)
all_contexts.append(doc)
all_metadata.append(metadata)
current_contexts.append(doc)
# 如果已经是最后一次迭代,结束循环
if i == max_iterations - 1:
break
# 使用LLM分析是否需要进一步查询
if current_contexts:
# 简单总结当前检索内容
current_summary = "\n".join(current_contexts[:3]) if current_contexts else "未找到相关信息"
next_query_prompt = f"""基于原始问题: {initial_query}
以及已检索信息:
{current_summary}
分析是否需要进一步查询。如果需要,请提供新的查询问题,使用不同角度或更具体的关键词。
如果已经有充分信息,请回复'不需要进一步查询'。
新查询(如果需要):"""
try:
# 根据模型选择使用不同的API
if model_choice == "siliconflow":
# 使用SiliconFlow API
logging.info("使用SiliconFlow API分析是否需要进一步查询")
next_query_result = call_siliconflow_api(next_query_prompt, temperature=0.7, max_tokens=256)
# 去除可能的思维链标记
if "<think>" in next_query_result:
next_query = next_query_result.split("<think>")[0].strip()
else:
next_query = next_query_result
else:
# 使用本地Ollama
logging.info("使用本地Ollama模型分析是否需要进一步查询")
response = session.post(
"http://localhost:11434/api/generate",
json={
"model": "deepseek-r1:1.5b",
"prompt": next_query_prompt,
"stream": False
},
timeout=30
)
next_query = response.json().get("response", "")
if "不需要" in next_query or "不需要进一步查询" in next_query or len(next_query.strip()) < 5:
logging.info("LLM判断不需要进一步查询,结束递归检索")
break
# 使用新查询继续迭代
query = next_query
logging.info(f"生成新查询: {query}")
except Exception as e:
logging.error(f"生成新查询时出错: {str(e)}")
break
else:
# 如果当前迭代没有检索到内容,结束迭代
break
return all_contexts, all_doc_ids, all_metadata
class BM25IndexManager:
def __init__(self):
self.bm25_index = None
self.doc_mapping = {} # 映射BM25索引位置到文档ID
self.tokenized_corpus = []
self.raw_corpus = []
def build_index(self, documents, doc_ids):
"""构建BM25索引"""
self.raw_corpus = documents
self.doc_mapping = {i: doc_id for i, doc_id in enumerate(doc_ids)}
# 对文档进行分词,使用jieba分词器更适合中文
self.tokenized_corpus = []
for doc in documents:
# 对中文文档进行分词
tokens = list(jieba.cut(doc))
self.tokenized_corpus.append(tokens)
# 创建BM25索引
self.bm25_index = BM25Okapi(self.tokenized_corpus)
return True
def search(self, query, top_k=5):
"""使用BM25检索相关文档"""
if not self.bm25_index:
return []
# 对查询进行分词
tokenized_query = list(jieba.cut(query))
# 获取BM25得分
bm25_scores = self.bm25_index.get_scores(tokenized_query)
# 获取得分最高的文档索引
top_indices = np.argsort(bm25_scores)[-top_k:][::-1]
# 返回结果
results = []
for idx in top_indices:
if bm25_scores[idx] > 0: # 只返回有相关性的结果
results.append({
'id': self.doc_mapping[idx],
'score': float(bm25_scores[idx]),
'content': self.raw_corpus[idx]
})
return results
def clear(self):
"""清空索引"""
self.bm25_index = None
self.doc_mapping = {}
self.tokenized_corpus = []
self.raw_corpus = []
# 初始化BM25索引管理器
BM25_MANAGER = BM25IndexManager()
logging.basicConfig(level=logging.INFO)
print("Gradio version:", gr.__version__) # 添加版本输出
# 在初始化组件后添加:
session = requests.Session()
retries = Retry(
total=3,
backoff_factor=0.1,
status_forcelist=[500, 502, 503, 504]
)
session.mount('http://', HTTPAdapter(max_retries=retries))
#########################################
# SerpAPI 网络查询及向量化处理函数
#########################################
def serpapi_search(query: str, num_results: int = 5) -> list:
"""
执行 SerpAPI 搜索,并返回解析后的结构化结果
"""
if not SERPAPI_KEY:
raise ValueError("未设置 SERPAPI_KEY 环境变量。请在.env文件中设置您的 API 密钥。")
try:
params = {
"engine": SEARCH_ENGINE,
"q": query,
"api_key": SERPAPI_KEY,
"num": num_results,
"hl": "zh-CN", # 中文界面
"gl": "cn"
}
response = requests.get("https://serpapi.com/search", params=params, timeout=15)
response.raise_for_status()
search_data = response.json()
return _parse_serpapi_results(search_data)
except Exception as e:
logging.error(f"网络搜索失败: {str(e)}")
return []
def _parse_serpapi_results(data: dict) -> list:
"""解析 SerpAPI 返回的原始数据"""
results = []
if "organic_results" in data:
for item in data["organic_results"]:
result = {
"title": item.get("title"),
"url": item.get("link"),
"snippet": item.get("snippet"),
"timestamp": item.get("date") # 若有时间信息,可选
}
results.append(result)
# 如果有知识图谱信息,也可以添加置顶(可选)
if "knowledge_graph" in data:
kg = data["knowledge_graph"]
results.insert(0, {
"title": kg.get("title"),
"url": kg.get("source", {}).get("link", ""),
"snippet": kg.get("description"),
"source": "knowledge_graph"
})
return results
def update_web_results(query: str, num_results: int = 5) -> list:
"""
基于 SerpAPI 搜索结果,向量化并存储到 ChromaDB
为网络结果添加元数据,ID 格式为 "web_{index}"
"""
results = serpapi_search(query, num_results)
if not results:
return []
# 删除旧的网络搜索结果(使用更健壮的方式)
try:
# 获取所有文档的元数据
collection_data = COLLECTION.get(include=['metadatas'])
if collection_data and 'metadatas' in collection_data:
# 使用集合推导生成要删除的ID列表
web_ids = []
for i, metadata in enumerate(collection_data['metadatas']):
# 如果元数据中的source字段是'web',那么这是一个网络结果
if metadata.get('source') == 'web' and i < len(collection_data['ids']):
web_ids.append(collection_data['ids'][i])
# 删除找到的网络结果
if web_ids:
COLLECTION.delete(ids=web_ids)
logging.info(f"已删除 {len(web_ids)} 条旧的网络搜索结果")
except Exception as e:
logging.warning(f"删除旧的网络搜索结果时出错: {str(e)}")
# 继续执行,不影响新结果添加
# 准备新的网络搜索结果
docs = []
metadatas = []
ids = []
for idx, res in enumerate(results):
text = f"标题:{res.get('title', '')}\n摘要:{res.get('snippet', '')}"
docs.append(text)
meta = {"source": "web", "url": res.get("url", ""), "title": res.get("title")}
meta["content_hash"] = hashlib.md5(text.encode()).hexdigest()[:8]
metadatas.append(meta)
ids.append(f"web_{idx}")
embeddings = EMBED_MODEL.encode(docs)
COLLECTION.add(ids=ids, embeddings=embeddings.tolist(), documents=docs, metadatas=metadatas)
return results
# 检查是否配置了SERPAPI_KEY
def check_serpapi_key():
"""检查是否配置了SERPAPI_KEY"""
return SERPAPI_KEY is not None and SERPAPI_KEY.strip() != ""
# 添加文件处理状态跟踪
class FileProcessor:
def __init__(self):
self.processed_files = {} # 存储已处理文件的状态
def clear_files(self):
"""清空所有文件记录"""
self.processed_files = {}
def add_file(self, file_name):
self.processed_files[file_name] = {
'status': '等待处理',
'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'chunks': 0
}
def update_status(self, file_name, status, chunks=None):
if file_name in self.processed_files:
self.processed_files[file_name]['status'] = status
if chunks is not None:
self.processed_files[file_name]['chunks'] = chunks
def get_file_list(self):
return [
f"📄 {fname} | {info['status']}"
for fname, info in self.processed_files.items()
]
file_processor = FileProcessor()
#########################################
# 矛盾检测函数
#########################################
def detect_conflicts(sources):
"""精准矛盾检测算法"""
key_facts = {}
for item in sources:
facts = extract_facts(item['text'] if 'text' in item else item.get('excerpt', ''))
for fact, value in facts.items():
if fact in key_facts:
if key_facts[fact] != value:
return True
else:
key_facts[fact] = value
return False
def extract_facts(text):
"""从文本提取关键事实(示例逻辑)"""
facts = {}
# 提取数值型事实
numbers = re.findall(r'\b\d{4}年|\b\d+%', text)
if numbers:
facts['关键数值'] = numbers
# 提取技术术语
if "产业图谱" in text:
facts['技术方法'] = list(set(re.findall(r'[A-Za-z]+模型|[A-Z]{2,}算法', text)))
return facts
def evaluate_source_credibility(source):
"""评估来源可信度"""
credibility_scores = {
"gov.cn": 0.9,
"edu.cn": 0.85,
"weixin": 0.7,
"zhihu": 0.6,
"baidu": 0.5
}
url = source.get('url', '')
if not url:
return 0.5 # 默认中等可信度
domain_match = re.search(r'//([^/]+)', url)
if not domain_match:
return 0.5
domain = domain_match.group(1)
# 检查是否匹配任何已知域名
for known_domain, score in credibility_scores.items():
if known_domain in domain:
return score
return 0.5 # 默认中等可信度
def extract_text(filepath):
"""改进的PDF文本提取方法"""
output = StringIO()
with open(filepath, 'rb') as file:
extract_text_to_fp(file, output)
return output.getvalue()
def process_multiple_pdfs(files, progress=gr.Progress()):
"""处理多个PDF文件"""
if not files:
return "请选择要上传的PDF文件", []
try:
# 清空向量数据库
progress(0.1, desc="清理历史数据...")
try:
# 直接获取所有ID,不使用include参数
# 由于ChromaDB的限制,我们只能获取所有数据,并从中提取ID
existing_data = COLLECTION.get()
if existing_data and 'ids' in existing_data and existing_data['ids']:
COLLECTION.delete(ids=existing_data['ids'])
logging.info(f"成功清理 {len(existing_data['ids'])} 条历史向量数据")
else:
logging.info("没有找到历史向量数据需要清理")
# 清空BM25索引
BM25_MANAGER.clear()
except Exception as e:
logging.error(f"清理历史数据时出错: {str(e)}")
return f"清理历史数据失败: {str(e)}", []
# 清空文件处理状态
file_processor.clear_files()
total_files = len(files)
processed_results = []
total_chunks = 0
for idx, file in enumerate(files, 1):
try:
file_name = os.path.basename(file.name)
progress((idx-1)/total_files, desc=f"处理文件 {idx}/{total_files}: {file_name}")
# 添加文件到处理器
file_processor.add_file(file_name)
# 处理单个文件
text = extract_text(file.name)
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=400,
chunk_overlap=40,
separators=["\n\n", "\n", "。", ",", ";", ":", " ", ""] # 按自然语言结构分割
)
chunks = text_splitter.split_text(text)
if not chunks:
raise ValueError("文档内容为空或无法提取文本")
# 生成文档唯一标识符
doc_id = f"doc_{int(time.time())}_{idx}"
# 生成嵌入
embeddings = EMBED_MODEL.encode(chunks)
# 存储向量,添加文档源信息
ids = [f"{doc_id}_chunk_{i}" for i in range(len(chunks))]
metadatas = [{"source": file_name, "doc_id": doc_id} for _ in chunks]
COLLECTION.add(
ids=ids,
embeddings=embeddings.tolist(),
documents=chunks,
metadatas=metadatas
)
# 更新处理状态
total_chunks += len(chunks)
file_processor.update_status(file_name, "处理完成", len(chunks))
processed_results.append(f"✅ {file_name}: 成功处理 {len(chunks)} 个文本块")
except Exception as e:
error_msg = str(e)
logging.error(f"处理文件 {file_name} 时出错: {error_msg}")
file_processor.update_status(file_name, f"处理失败: {error_msg}")
processed_results.append(f"❌ {file_name}: 处理失败 - {error_msg}")
# 添加总结信息
summary = f"\n总计处理 {total_files} 个文件,{total_chunks} 个文本块"
processed_results.append(summary)
# 更新BM25索引
progress(0.95, desc="构建BM25检索索引...")
update_bm25_index()
# 获取更新后的文件列表
file_list = file_processor.get_file_list()
return "\n".join(processed_results), file_list
except Exception as e:
error_msg = str(e)
logging.error(f"整体处理过程出错: {error_msg}")
return f"处理过程出错: {error_msg}", []
# 新增:交叉编码器重排序函数
def rerank_with_cross_encoder(query, docs, doc_ids, metadata_list, top_k=5):
"""
使用交叉编码器对检索结果进行重排序
参数:
query: 查询字符串
docs: 文档内容列表
doc_ids: 文档ID列表
metadata_list: 元数据列表
top_k: 返回结果数量
返回:
重排序后的结果列表 [(doc_id, {'content': doc, 'metadata': metadata, 'score': score}), ...]
"""
if not docs:
return []
encoder = get_cross_encoder()
if encoder is None:
logging.warning("交叉编码器不可用,跳过重排序")
# 返回原始顺序(按索引排序)
return [(doc_id, {'content': doc, 'metadata': meta, 'score': 1.0 - idx/len(docs)})
for idx, (doc_id, doc, meta) in enumerate(zip(doc_ids, docs, metadata_list))]
# 准备交叉编码器输入
cross_inputs = [[query, doc] for doc in docs]
try:
# 计算相关性得分
scores = encoder.predict(cross_inputs)
# 组合结果
results = [
(doc_id, {
'content': doc,
'metadata': meta,
'score': float(score) # 确保是Python原生类型
})
for doc_id, doc, meta, score in zip(doc_ids, docs, metadata_list, scores)
]
# 按得分排序
results = sorted(results, key=lambda x: x[1]['score'], reverse=True)
# 返回前K个结果
return results[:top_k]
except Exception as e:
logging.error(f"交叉编码器重排序失败: {str(e)}")
# 出错时返回原始顺序
return [(doc_id, {'content': doc, 'metadata': meta, 'score': 1.0 - idx/len(docs)})
for idx, (doc_id, doc, meta) in enumerate(zip(doc_ids, docs, metadata_list))]
# 新增:LLM相关性评分函数
@lru_cache(maxsize=32)
def get_llm_relevance_score(query, doc):
"""
使用LLM对查询和文档的相关性进行评分(带缓存)
参数:
query: 查询字符串
doc: 文档内容
返回:
相关性得分 (0-10)
"""
try:
# 构建评分提示词
prompt = f"""给定以下查询和文档片段,评估它们的相关性。
评分标准:0分表示完全不相关,10分表示高度相关。
只需返回一个0-10之间的整数分数,不要有任何其他解释。
查询: {query}
文档片段: {doc}
相关性分数(0-10):"""
# 调用本地LLM
response = session.post(
"http://localhost:11434/api/generate",
json={
"model": "deepseek-r1:1.5b", # 使用较小模型进行评分
"prompt": prompt,
"stream": False
},
timeout=30
)
# 提取得分
result = response.json().get("response", "").strip()
# 尝试解析为数字
try:
score = float(result)
# 确保分数在0-10范围内
score = max(0, min(10, score))
return score
except ValueError:
# 如果无法解析为数字,尝试从文本中提取数字
match = re.search(r'\b([0-9]|10)\b', result)
if match:
return float(match.group(1))
else:
# 默认返回中等相关性
return 5.0
except Exception as e:
logging.error(f"LLM评分失败: {str(e)}")
# 默认返回中等相关性
return 5.0
def rerank_with_llm(query, docs, doc_ids, metadata_list, top_k=5):
"""
使用LLM对检索结果进行重排序
参数:
query: 查询字符串
docs: 文档内容列表
doc_ids: 文档ID列表
metadata_list: 元数据列表
top_k: 返回结果数量
返回:
重排序后的结果列表
"""
if not docs:
return []
results = []
# 对每个文档进行评分
for doc_id, doc, meta in zip(doc_ids, docs, metadata_list):
# 获取LLM评分
score = get_llm_relevance_score(query, doc)
# 添加到结果列表
results.append((doc_id, {
'content': doc,
'metadata': meta,
'score': score / 10.0 # 归一化到0-1
}))
# 按得分排序
results = sorted(results, key=lambda x: x[1]['score'], reverse=True)
# 返回前K个结果
return results[:top_k]
# 新增:通用重排序函数
def rerank_results(query, docs, doc_ids, metadata_list, method=None, top_k=5):
"""
对检索结果进行重排序
参数:
query: 查询字符串
docs: 文档内容列表
doc_ids: 文档ID列表
metadata_list: 元数据列表
method: 重排序方法 ("cross_encoder", "llm" 或 None)
top_k: 返回结果数量
返回:
重排序后的结果
"""
# 如果未指定方法,使用全局配置
if method is None:
method = RERANK_METHOD
# 根据方法选择重排序函数
if method == "llm":
return rerank_with_llm(query, docs, doc_ids, metadata_list, top_k)
elif method == "cross_encoder":
return rerank_with_cross_encoder(query, docs, doc_ids, metadata_list, top_k)
else:
# 默认不进行重排序,按原始顺序返回
return [(doc_id, {'content': doc, 'metadata': meta, 'score': 1.0 - idx/len(docs)})
for idx, (doc_id, doc, meta) in enumerate(zip(doc_ids, docs, metadata_list))]
def stream_answer(question, enable_web_search=False, model_choice="ollama", progress=gr.Progress()):
"""改进的流式问答处理流程,支持联网搜索、混合检索和重排序,以及多种模型选择"""
try:
# 检查向量数据库是否为空
try:
collection_data = COLLECTION.get(include=["documents"])
if not collection_data or not collection_data.get("documents") or len(collection_data.get("documents", [])) == 0:
if not enable_web_search:
yield "⚠️ 知识库为空,请先上传文档。", "遇到错误"
return
else:
logging.warning("知识库为空,将仅使用网络搜索结果")
except Exception as e:
if not enable_web_search:
yield f"⚠️ 检查知识库时出错: {str(e)},请确保已上传文档。", "遇到错误"
return
logging.error(f"检查知识库时出错: {str(e)}")
progress(0.3, desc="执行递归检索...")
# 使用递归检索获取更全面的答案上下文
all_contexts, all_doc_ids, all_metadata = recursive_retrieval(
initial_query=question,
max_iterations=3,
enable_web_search=enable_web_search,
model_choice=model_choice
)
# 组合上下文,包含来源信息
context_with_sources = []
sources_for_conflict_detection = []
# 使用检索到的结果构建上下文
for doc, doc_id, metadata in zip(all_contexts, all_doc_ids, all_metadata):
source_type = metadata.get('source', '本地文档')
source_item = {
'text': doc,
'type': source_type
}
if source_type == 'web':
url = metadata.get('url', '未知URL')
title = metadata.get('title', '未知标题')
context_with_sources.append(f"[网络来源: {title}] (URL: {url})\n{doc}")
source_item['url'] = url
source_item['title'] = title
else:
source = metadata.get('source', '未知来源')
context_with_sources.append(f"[本地文档: {source}]\n{doc}")
source_item['source'] = source
sources_for_conflict_detection.append(source_item)
# 检测矛盾
conflict_detected = detect_conflicts(sources_for_conflict_detection)
# 获取可信源
if conflict_detected:
credible_sources = [s for s in sources_for_conflict_detection
if s['type'] == 'web' and evaluate_source_credibility(s) > 0.7]
context = "\n\n".join(context_with_sources)
# 添加时间敏感检测
time_sensitive = any(word in question for word in ["最新", "今年", "当前", "最近", "刚刚"])
# 改进提示词模板,提高回答质量
prompt_template = """作为一个专业的问答助手,你需要基于以下{context_type}回答用户问题。
提供的参考内容:
{context}
用户问题:{question}
请遵循以下回答原则:
1. 仅基于提供的参考内容回答问题,不要使用你自己的知识
2. 如果参考内容中没有足够信息,请坦诚告知你无法回答
3. 回答应该全面、准确、有条理,并使用适当的段落和结构
4. 请用中文回答
5. 在回答末尾标注信息来源{time_instruction}{conflict_instruction}
请现在开始回答:"""
prompt = prompt_template.format(
context_type="本地文档和网络搜索结果" if enable_web_search else "本地文档",
context=context,
question=question,
time_instruction=",优先使用最新的信息" if time_sensitive and enable_web_search else "",
conflict_instruction=",并明确指出不同来源的差异" if conflict_detected else ""
)
progress(0.7, desc="生成回答...")
full_answer = ""
# 根据模型选择使用不同的API
if model_choice == "siliconflow":
# 对于SiliconFlow API,不支持流式响应,所以一次性获取
progress(0.8, desc="通过SiliconFlow API生成回答...")
full_answer = call_siliconflow_api(prompt, temperature=0.7, max_tokens=1536)
# 处理思维链
if "<think>" in full_answer and "</think>" in full_answer:
processed_answer = process_thinking_content(full_answer)
else:
processed_answer = full_answer
yield processed_answer, "完成!"
else:
# 使用本地Ollama模型的流式响应
response = session.post(
"http://localhost:11434/api/generate",
json={
"model": "deepseek-r1:1.5b",
"prompt": prompt,
"stream": True
},
timeout=120,
stream=True
)
for line in response.iter_lines():
if line:
chunk = json.loads(line.decode()).get("response", "")
full_answer += chunk
# 检查是否有完整的思维链标签可以处理
if "<think>" in full_answer and "</think>" in full_answer:
# 需要确保完整收集一个思维链片段后再显示
processed_answer = process_thinking_content(full_answer)
else:
processed_answer = full_answer
yield processed_answer, "生成回答中..."
# 处理最终输出,确保应用思维链处理
final_answer = process_thinking_content(full_answer)
yield final_answer, "完成!"
except Exception as e:
yield f"系统错误: {str(e)}", "遇到错误"
def query_answer(question, enable_web_search=False, model_choice="ollama", progress=gr.Progress()):
"""问答处理流程,支持联网搜索、混合检索和重排序,以及多种模型选择"""
try:
logging.info(f"收到问题:{question},联网状态:{enable_web_search},模型选择:{model_choice}")
# 检查向量数据库是否为空
try:
collection_data = COLLECTION.get(include=["documents"])
if not collection_data or not collection_data.get("documents") or len(collection_data.get("documents", [])) == 0:
if not enable_web_search:
return "⚠️ 知识库为空,请先上传文档。"
else:
logging.warning("知识库为空,将仅使用网络搜索结果")
except Exception as e:
if not enable_web_search:
return f"⚠️ 检查知识库时出错: {str(e)},请确保已上传文档。"
logging.error(f"检查知识库时出错: {str(e)}")
progress(0.3, desc="执行递归检索...")
# 使用递归检索获取更全面的答案上下文
all_contexts, all_doc_ids, all_metadata = recursive_retrieval(
initial_query=question,
max_iterations=3,
enable_web_search=enable_web_search,
model_choice=model_choice
)
# 组合上下文,包含来源信息
context_with_sources = []
sources_for_conflict_detection = []
# 使用检索到的结果构建上下文
for doc, doc_id, metadata in zip(all_contexts, all_doc_ids, all_metadata):
source_type = metadata.get('source', '本地文档')
source_item = {
'text': doc,
'type': source_type
}
if source_type == 'web':
url = metadata.get('url', '未知URL')
title = metadata.get('title', '未知标题')
context_with_sources.append(f"[网络来源: {title}] (URL: {url})\n{doc}")
source_item['url'] = url
source_item['title'] = title
else:
source = metadata.get('source', '未知来源')
context_with_sources.append(f"[本地文档: {source}]\n{doc}")
source_item['source'] = source
sources_for_conflict_detection.append(source_item)
# 检测矛盾
conflict_detected = detect_conflicts(sources_for_conflict_detection)
# 获取可信源
if conflict_detected:
credible_sources = [s for s in sources_for_conflict_detection
if s['type'] == 'web' and evaluate_source_credibility(s) > 0.7]
context = "\n\n".join(context_with_sources)
# 添加时间敏感检测
time_sensitive = any(word in question for word in ["最新", "今年", "当前", "最近", "刚刚"])
# 改进提示词模板,提高回答质量
prompt_template = """作为一个专业的问答助手,你需要基于以下{context_type}回答用户问题。
提供的参考内容:
{context}
用户问题:{question}
请遵循以下回答原则:
1. 仅基于提供的参考内容回答问题,不要使用你自己的知识
2. 如果参考内容中没有足够信息,请坦诚告知你无法回答
3. 回答应该全面、准确、有条理,并使用适当的段落和结构
4. 请用中文回答
5. 在回答末尾标注信息来源{time_instruction}{conflict_instruction}
请现在开始回答:"""
prompt = prompt_template.format(
context_type="本地文档和网络搜索结果" if enable_web_search else "本地文档",
context=context,
question=question,
time_instruction=",优先使用最新的信息" if time_sensitive and enable_web_search else "",
conflict_instruction=",并明确指出不同来源的差异" if conflict_detected else ""
)
progress(0.8, desc="生成回答...")
# 根据模型选择使用不同的API
if model_choice == "siliconflow":
# 使用SiliconFlow API
result = call_siliconflow_api(prompt, temperature=0.7, max_tokens=1536)