answer.py
# Copyright (c) 2022 Cohere Inc. and its affiliates.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
#
# You may obtain a copy of the License in the LICENSE file at the top
# level of this repository.

import numpy as np

from qa.model import get_sample_answer
from qa.search import embedding_search, get_results_paragraphs_multi_process
from qa.util import pretty_print


def trim_stop_sequences(s, stop_sequences):
    """Remove stop sequences found at the end of returned generated text."""
    for stop_sequence in stop_sequences:
        if s.endswith(stop_sequence):
            return s[:-len(stop_sequence)]
    return s

def answer(question, context, co, model, chat_history=""):
    """Answer a question given some context."""
    if 'command' in model:
        prompt = (
            f'read the paragraph below and answer the question, if the question cannot be answered based on the context alone, write "sorry i had trouble answering this question, based on the information i found\n'
            f"\n"
            f"Context:\n"
            f"{ context }\n"
            f"\n"
            f"Question: { question }\n"
            "Answer:")
        stop_sequences = []
    else:
        prompt = ("This is an example of question answering based on a text passage:\n "
                  f"Context:-{context}\nQuestion:\n-{question}\nAnswer:\n-")
        if chat_history:
            prompt = ("This is an example of factual question answering chat bot. It "
                      "takes the text context and answers related questions:\n "
                      f"Context:-{context}\nChat Log\n{chat_history}\nbot:")
        stop_sequences = ["\n"]

    num_generations = 4

    # Keep only the last 1900 tokens of the prompt so it fits within the model's context window.
    prompt = "".join(co.tokenize(text=prompt).token_strings[-1900:])

    prediction = co.generate(model=model,
                             prompt=prompt,
                             max_tokens=100,
                             temperature=0.3,
                             stop_sequences=stop_sequences,
                             num_generations=num_generations,
                             return_likelihoods="GENERATION")

    # Pair each generation with its likelihood, drop whitespace-only generations,
    # and return the most likely one.
    generations = [[
        trim_stop_sequences(prediction.generations[i].text.strip(), stop_sequences),
        prediction.generations[i].likelihood
    ] for i in range(num_generations)]
    generations = list(filter(lambda x: not x[0].isspace(), generations))
    response = generations[np.argmax([g[1] for g in generations])][0]
    return response.strip()

def answer_with_search(question,
                       co,
                       serp_api_token,
                       chat_history="",
                       model='command-xlarge-20221108',
                       embedding_model="multilingual-22-12",
                       url=None,
                       n_paragraphs=1,
                       verbosity=0):
    """Generate a completion grounded in web search results."""
    # Fetch candidate paragraphs (and their source URLs) from a web search.
    paragraphs, paragraph_sources = get_results_paragraphs_multi_process(question, serp_api_token, url=url)
    if not paragraphs:
        return ("", "", "")

    # Generate a sample answer and use it as the query for embedding-based ranking of the paragraphs.
    sample_answer = get_sample_answer(question, co)
    results = embedding_search(paragraphs, paragraph_sources, sample_answer, co, model=embedding_model)

    if verbosity > 1:
        pprint_results = "\n".join([r[0] for r in results])
        pretty_print("OKGREEN", f"all search result context: {pprint_results}")

    # Keep the last n_paragraphs results as context (embedding_search is expected to return
    # results ordered from least to most relevant).
    results = results[-n_paragraphs:]
    context = "\n".join([r[0] for r in results])

    if verbosity:
        pretty_print("OKCYAN", "relevant result context: " + context)

    response = answer(question, context, co, chat_history=chat_history, model=model)
    return (response, [r[1] for r in results], [r[0] for r in results])
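

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): shows how answer_with_search
    # might be called end to end. The environment variable names (COHERE_API_KEY, SERP_API_TOKEN)
    # and the example question are assumptions for illustration, not part of this repository.
    import os

    import cohere

    co = cohere.Client(os.environ["COHERE_API_KEY"])
    serp_api_token = os.environ["SERP_API_TOKEN"]

    question = "What year was the Eiffel Tower completed?"
    response, sources, contexts = answer_with_search(question,
                                                     co,
                                                     serp_api_token,
                                                     n_paragraphs=2,
                                                     verbosity=1)
    pretty_print("OKGREEN", f"answer: {response}")
    pretty_print("OKCYAN", f"sources: {sources}")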