Skip to content

Commit e673dc0

Browse files
committed
initial commit
0 parents  commit e673dc0

File tree

11 files changed

+555
-0
lines changed

11 files changed

+555
-0
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
documents.json
2+
arxiv-metadata-oai-snapshot.json
3+
secrets.json
4+
5+
meili_data/
6+
qdrant_storage/

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Arxiv QA
2+
3+
Retrieval-augmented generation example that answers questions from Arxiv abstracts and titles.
4+
5+
## Setup
6+
7+
* Copy `secrets-example.json` to `secrets.json` and replace the placeholder with your own key.
8+
* Fetch `arxiv-metadata-oai-snapshot.json`
9+
* `kaggle datasets download -d Cornell-University/arxiv`
10+
* Run `preprocess_dataset.py`
11+
* Input file: `arxiv-metadata-oai-snapshot.json`
12+
* Output file: `documents.json` (a bit smaller)
13+
* `docker compose up -d` to run MeiliSearch and Qdrant
14+
* Then
15+
* `ingest_to_meilisearch.py`
16+
* `ingest_to_qdrant.py`
17+
* You'll want a GPU 😁; use `nvitop` to check that the GPU is actually being used.
18+
* Example performance: g5.xlarge (1x A10G), ~600k abstracts, ~12 minutes
19+
* Finally `query.py` to ask some questions.
20+
21+
# Other tips
22+
23+
* A small search UI for testing Meilisearch keyword lookup is served at `http://localhost:8080/` (the `search-ui` service in `docker-compose.yml`).
24+
* `cli.py` could be useful but at the moment only exposes `meilisearch_index` and `meilisearch_client`

__init__.py

Whitespace-only changes.

cli.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import subprocess
2+
3+
# Statements executed inside the interactive IPython session so that a
# Meilisearch client and the "papers" index are ready to use at the prompt.
command = "\n".join([
    "import meilisearch",
    "meilisearch_client = meilisearch.Client('http://127.0.0.1:7700')",
    'meilisearch_index = meilisearch_client.index("papers")',
    "",  # keeps the trailing newline of the bootstrap snippet
])

ipython_argv = ["ipython", "-i", "-c", command]

try:
    # "-i" keeps IPython open after the "-c" snippet has run.
    subprocess.run(ipython_argv)
except FileNotFoundError:
    # Raised when the "ipython" executable cannot be found.
    print("IPython is not installed. Please install it by running: pip install ipython")

docker-compose.yml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
version: '3'
2+
services:
3+
qdrant:
4+
image: qdrant/qdrant
5+
ports:
6+
- 6333:6333
7+
volumes:
8+
- ./qdrant_storage:/qdrant/storage
9+
10+
meilisearch:
11+
image: getmeili/meilisearch:v1.2
12+
ports:
13+
- 7700:7700
14+
volumes:
15+
- ./meili_data:/meili_data
16+
17+
search-ui:
18+
image: nginx:latest
19+
volumes:
20+
- ./search_ui:/usr/share/nginx/html:ro
21+
ports:
22+
- 8080:80
23+
restart: always

ingest_to_meilisearch.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
This module reads all the docs and ingests them into MeiliSearch.
3+
"""
4+
import json
5+
import meilisearch
6+
from tqdm import tqdm
7+
8+
def yield_docs():
    """Lazily yield the documents stored in ``documents.json``.

    The file is read line by line and every line is decoded with
    ``json.loads``; the file stays open until the generator is exhausted.
    """
    with open("documents.json", "r") as docs_file:
        yield from map(json.loads, docs_file)
12+
13+
import re

# Matches every character that is not allowed in a Meilisearch document id
# (ids may only contain alphanumerics, hyphens and underscores).
_INVALID_ID_CHARS = re.compile(r"[^A-Za-z0-9_-]")

# Load every document into memory so they can be batched below.
docs = list(yield_docs())

client = meilisearch.Client('http://127.0.0.1:7700')

index = client.index("papers")

# Make the ids acceptable to Meilisearch. arXiv ids contain dots
# ("0704.0001") and old-style ids also contain a slash ("cs/9301111");
# both are replaced with "-". Replacing only the dot would leave the
# slash-style ids invalid.
for doc in docs:
    doc["id"] = _INVALID_ID_CHARS.sub("-", doc["id"])

# Send the documents in fixed-size batches.
batch_size = 100
chunked_docs = [docs[i:i + batch_size] for i in range(0, len(docs), batch_size)]

for doc_chunk in tqdm(chunked_docs, desc="Indexing documents"):
    index.add_documents(doc_chunk, primary_key="id")

ingest_to_qdrant.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
2+
3+
from qdrant_client import models, QdrantClient
4+
import hashlib
5+
from concurrent.futures import ProcessPoolExecutor
6+
import json
7+
from sentence_transformers import SentenceTransformer
8+
from tqdm import tqdm
9+
import os
10+
11+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
12+
13+
def upload_records_process(documents_chunk):
    """Upload one chunk of embedded documents to the "papers" collection.

    Intended to run in a worker process, so it opens its own ``QdrantClient``.

    Args:
        documents_chunk: iterable of document dicts. Each dict must have an
            "id" string and a "vector" entry holding its embedding.
    """
    qdrant = QdrantClient()

    records = [
        models.Record(
            # The point id is the MD5 hex digest of the document id.
            id=hashlib.md5(doc["id"].encode()).hexdigest(),
            vector=doc["vector"],
            # Keep the embedding out of the payload: it is already stored as
            # the record's vector, and duplicating it bloats every payload.
            payload={key: value for key, value in doc.items() if key != "vector"},
        )
        for doc in documents_chunk
    ]
    qdrant.upload_records("papers", records)
23+
24+
25+
# Only run the pipeline when executed as a script. ProcessPoolExecutor requires
# the __main__ module to be importable by its workers; without this guard a
# worker that re-imports the module would reload the encoder and recreate the
# collection.
if __name__ == "__main__":
    print("Loading encoder...")
    encoder = SentenceTransformer('all-MiniLM-L6-v2', device='cuda')

    print("Opening documents file...")

    # documents.json holds one JSON document per line.
    with open("documents.json", "r") as fp:
        documents_list = [json.loads(line) for line in fp]

    print(f"Indexing {len(documents_list)} documents...")

    batch_size = 4096
    documents_list_chunked = [
        documents_list[i:i + batch_size]
        for i in range(0, len(documents_list), batch_size)
    ]

    qdrant = QdrantClient()
    qdrant.recreate_collection(
        collection_name="papers",
        vectors_config=models.VectorParams(
            # Vector size is defined by the model in use.
            size=encoder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE,
        ),
    )

    # We want to upload the documents in parallel with continuing
    # to encode the next batch of documents. If we don't do this,
    # then we have a lot of GPU idle time while docs are being
    # uploaded to Qdrant.
    with ProcessPoolExecutor(max_workers=3) as upload_executor:
        upload_futures = []

        for documents_chunk in tqdm(documents_list_chunked, desc="Processing document chunks"):
            abstracts = encoder.encode([doc["abstract"] for doc in documents_chunk])
            for doc, embedding in zip(documents_chunk, abstracts):
                doc["vector"] = embedding.tolist()

            upload_futures.append(
                upload_executor.submit(upload_records_process, documents_chunk)
            )

        # Wait for the uploads to finish. Calling result() re-raises any
        # exception from a worker, so a failed upload is not silently dropped.
        for future in upload_futures:
            future.result()

preprocess_dataset.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import json
2+
from typing import Generator
3+
4+
def get_dataset_generator(path: str) -> Generator:
    """Yield every row of the line-delimited JSON file at ``path``.

    Args:
        path: location of a file containing one JSON value per line.

    Yields:
        The object decoded from each line, in file order.
    """
    with open(path, "r") as dataset_file:
        yield from (json.loads(raw_line) for raw_line in dataset_file)
9+
10+
11+
def filter_generator(g: Generator, filter_fn):
    """Lazily yield the items of ``g`` for which ``filter_fn(item)`` is truthy.

    Args:
        g: iterable to draw items from.
        filter_fn: predicate called once per item.
    """
    yield from (candidate for candidate in g if filter_fn(candidate))
15+
16+
def stop_after(g, num_items):
    """Yield at most ``num_items`` items from ``g``, then stop.

    Exactly the yielded items are taken from ``g``: no extra element is
    pulled from the underlying iterator once the limit is reached, so a
    caller can keep consuming ``g`` afterwards without losing an item.

    Args:
        g: iterable to draw items from.
        num_items: maximum number of items to yield. ``0`` yields nothing;
            a value that is never reached (e.g. negative) yields everything.
    """
    if num_items == 0:
        # Return before touching the iterator at all.
        return
    for count, item in enumerate(g, start=1):
        yield item
        if count == num_items:
            # Stop right after the last wanted item instead of fetching the
            # next one just to discover the limit was hit.
            return
21+
22+
def clean_document(doc):
    """Build the reduced document kept for indexing.

    Only "id", "title", "abstract", "categories" and "update_date" are
    retained. Newlines in the title become spaces and the categories string
    is split on single spaces into a list.

    Args:
        doc: mapping holding at least the five fields listed above.

    Returns:
        A new dict with those five keys.
    """
    doc_id = doc["id"]
    single_line_title = doc["title"].replace("\n", " ")
    abstract = doc["abstract"]
    category_list = doc["categories"].split(" ")
    update_date = doc["update_date"]
    return dict(
        id=doc_id,
        title=single_line_title,
        abstract=abstract,
        categories=category_list,
        update_date=update_date,
    )
30+
31+
# Reuse an existing documents.json when there is one; otherwise build it from
# the raw arXiv snapshot and write it out.
documents_list = []
try:
    with open("documents.json", "r") as fp:
        documents_list.extend(json.loads(line) for line in fp)
except FileNotFoundError:

    def filter_relevant(doc):
        """Return True when any category of the cleaned doc starts with "cs."."""
        return any(category.startswith("cs.") for category in doc["categories"])

    dataset_generator = get_dataset_generator(
        path="arxiv-metadata-oai-snapshot.json"
    )

    # Clean first, then filter: filter_relevant needs the categories list
    # that clean_document produces.
    documents = (
        cleaned
        for cleaned in map(clean_document, dataset_generator)
        if filter_relevant(cleaned)
    )

    print("Generating in-memory documents structure")
    documents_list = list(documents)

    print(f"Writing {len(documents_list)} documents...")
    with open("documents.json", "w") as fp:
        # One JSON document per line.
        fp.writelines(json.dumps(doc) + "\n" for doc in documents_list)

print("Document examples:")
for doc in documents_list[:3]:
    print("[{}] {} ({})".format(doc["update_date"], doc["title"], doc["categories"]))
62+

0 commit comments

Comments
 (0)