Skip to content

Commit b9a640e

Browse files
committed
move from AutonomousUseCase to SimpleStrategy
1 parent 825a795 commit b9a640e

File tree

6 files changed

+79
-256
lines changed

6 files changed

+79
-256
lines changed

src/hackingBuddyGPT/strategies.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from hackingBuddyGPT.usecases.base import UseCase
88
from hackingBuddyGPT.utils import llm_util
99
from hackingBuddyGPT.utils.histories import HistoryCmdOnly, HistoryFull, HistoryNone
10+
from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib
1011
from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection
1112
from hackingBuddyGPT.utils.logging import log_conversation, Logger, log_param, log_section
1213
from hackingBuddyGPT.utils.capability_manager import CapabilityManager
@@ -140,3 +141,59 @@ def check_success(self, cmd:str, result:str) -> bool:
140141

141142
def postprocess_commands(self, cmd:str) -> List[str]:
142143
return [cmd]
144+
145+
@dataclass
146+
class SimpleStrategy(UseCase, abc.ABC):
147+
max_turns: int = 10
148+
149+
llm: OpenAILib = None
150+
151+
log: Logger = log_param
152+
153+
_got_root: bool = False
154+
155+
_capabilities: CapabilityManager = None
156+
157+
def init(self):
158+
super().init()
159+
self._capabilities = CapabilityManager(self.log)
160+
161+
@abc.abstractmethod
162+
def perform_round(self, turn: int):
163+
pass
164+
165+
def before_run(self):
166+
pass
167+
168+
def after_run(self):
169+
pass
170+
171+
def run(self, configuration):
172+
self.configuration = configuration
173+
self.log.start_run(self.get_name(), self.serialize_configuration(configuration))
174+
175+
self.before_run()
176+
177+
turn = 1
178+
try:
179+
while turn <= self.max_turns and not self._got_root:
180+
with self.log.section(f"round {turn}"):
181+
self.log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}")
182+
183+
self._got_root = self.perform_round(turn)
184+
185+
turn += 1
186+
187+
self.after_run()
188+
189+
# write the final result to the database and console
190+
if self._got_root:
191+
self.log.run_was_success()
192+
else:
193+
self.log.run_was_failure("maximum turn number reached")
194+
195+
return self._got_root
196+
except Exception:
197+
import traceback
198+
self.log.run_was_failure("exception occurred", details=f":\n\n{traceback.format_exc()}")
199+
raise

src/hackingBuddyGPT/usecases/agents.py

Lines changed: 0 additions & 124 deletions
This file was deleted.

src/hackingBuddyGPT/usecases/base.py

Lines changed: 2 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import abc
22
import json
3-
import argparse
43
from dataclasses import dataclass
54

65
from hackingBuddyGPT.utils.logging import Logger, log_param
7-
from typing import Dict, Type, TypeVar, Generic
6+
from typing import Dict, Type
87

9-
from hackingBuddyGPT.utils.configurable import Transparent, configurable
8+
from hackingBuddyGPT.utils.configurable import configurable
109

1110
@dataclass
1211
class UseCase(abc.ABC):
@@ -49,98 +48,8 @@ def get_name(self) -> str:
4948
"""
5049
pass
5150

52-
53-
# this runs the main loop for a bounded amount of turns or until root was achieved
54-
@dataclass
55-
class AutonomousUseCase(UseCase, abc.ABC):
56-
max_turns: int = 10
57-
58-
_got_root: bool = False
59-
60-
@abc.abstractmethod
61-
def perform_round(self, turn: int):
62-
pass
63-
64-
def before_run(self):
65-
pass
66-
67-
def after_run(self):
68-
pass
69-
70-
def run(self, configuration):
71-
self.configuration = configuration
72-
self.log.start_run(self.get_name(), self.serialize_configuration(configuration))
73-
74-
self.before_run()
75-
76-
turn = 1
77-
try:
78-
while turn <= self.max_turns and not self._got_root:
79-
with self.log.section(f"round {turn}"):
80-
self.log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}")
81-
82-
self._got_root = self.perform_round(turn)
83-
84-
turn += 1
85-
86-
self.after_run()
87-
88-
# write the final result to the database and console
89-
if self._got_root:
90-
self.log.run_was_success()
91-
else:
92-
self.log.run_was_failure("maximum turn number reached")
93-
94-
return self._got_root
95-
except Exception:
96-
import traceback
97-
self.log.run_was_failure("exception occurred", details=f":\n\n{traceback.format_exc()}")
98-
raise
99-
100-
10151
use_cases: Dict[str, configurable] = dict()
10252

103-
104-
T = TypeVar("T", bound=type)
105-
106-
107-
class AutonomousAgentUseCase(AutonomousUseCase, Generic[T]):
108-
agent: T = None
109-
110-
def perform_round(self, turn: int):
111-
raise ValueError("Do not use AutonomousAgentUseCase without supplying an agent type as generic")
112-
113-
def get_name(self) -> str:
114-
raise ValueError("Do not use AutonomousAgentUseCase without supplying an agent type as generic")
115-
116-
@classmethod
117-
def __class_getitem__(cls, item):
118-
item = dataclass(item)
119-
120-
class AutonomousAgentUseCase(AutonomousUseCase):
121-
agent: Transparent(item) = None
122-
123-
def init(self):
124-
super().init()
125-
self.agent.init()
126-
127-
def get_name(self) -> str:
128-
return self.__class__.__name__
129-
130-
def before_run(self):
131-
return self.agent.before_run()
132-
133-
def after_run(self):
134-
return self.agent.after_run()
135-
136-
def perform_round(self, turn: int):
137-
return self.agent.perform_round(turn)
138-
139-
constructed_class = dataclass(AutonomousAgentUseCase)
140-
141-
return constructed_class
142-
143-
14453
def use_case(description):
14554
def inner(cls):
14655
cls = dataclass(cls)

src/hackingBuddyGPT/usecases/web/with_explanation.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
11
from dataclasses import field
2-
from typing import List, Any, Union, Dict, Iterable, Optional
3-
2+
from typing import List, Any, Union, Iterable, Optional
43
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
54
from openai.types.chat.chat_completion_chunk import ChoiceDelta
65

7-
from hackingBuddyGPT.capabilities import Capability
86
from hackingBuddyGPT.capabilities.http_request import HTTPRequest
97
from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag
10-
from hackingBuddyGPT.usecases.agents import Agent
11-
from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case
8+
from hackingBuddyGPT.strategies import SimpleStrategy
9+
from hackingBuddyGPT.usecases.base import use_case
1210
from hackingBuddyGPT.utils import LLMResult, tool_message
1311
from hackingBuddyGPT.utils.configurable import parameter
14-
from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib
1512

1613
Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]]
1714
Context = Any
1815

19-
20-
class WebTestingWithExplanation(Agent):
21-
llm: OpenAILib
16+
@use_case("Minimal implementation of a web testing use case while allowing the llm to 'talk'")
17+
class WebTestingWithExplanation(SimpleStrategy):
2218
host: str = parameter(desc="The host to test", default="http://localhost")
2319
flag_format_description: str = parameter(
2420
desc="Description of the flag provided to the LLM",
@@ -35,14 +31,16 @@ class WebTestingWithExplanation(Agent):
3531

3632
_prompt_history: Prompt = field(default_factory=list)
3733
_context: Context = field(default_factory=lambda: {"notes": list()})
38-
_capabilities: Dict[str, Capability] = field(default_factory=dict)
3934
_all_flags_found: bool = False
4035

4136
def init(self):
4237
super().init()
4338
self._context["host"] = self.host
44-
self.add_capability(SubmitFlag(self.flag_format_description, set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), success_function=self.all_flags_found))
45-
self.add_capability(HTTPRequest(self.host))
39+
self._capabilities.add_capability(SubmitFlag(self.flag_format_description, set(self.flag_template.format(flag=flag) for flag in self.flags.split(",")), success_function=self.all_flags_found))
40+
self._capabilities.add_capability(HTTPRequest(self.host))
41+
42+
def get_name(self) -> str:
43+
return self.__class__.__name__
4644

4745
def before_run(self):
4846
system_message = (
@@ -64,7 +62,7 @@ def all_flags_found(self):
6462
def perform_round(self, turn: int):
6563
prompt = self._prompt_history # TODO: in the future, this should do some context truncation
6664

67-
result_stream: Iterable[Union[ChoiceDelta, LLMResult]] = self.llm.stream_response(prompt, self.log.console, capabilities=self._capabilities, get_individual_updates=True)
65+
result_stream: Iterable[Union[ChoiceDelta, LLMResult]] = self.llm.stream_response(prompt, self.log.console, capabilities=self._capabilities._capabilities, get_individual_updates=True)
6866
result: Optional[LLMResult] = None
6967
stream_output = self.log.stream_message("assistant") # TODO: do not hardcode the role
7068
for delta in result_stream:
@@ -83,12 +81,7 @@ def perform_round(self, turn: int):
8381

8482
if message.tool_calls is not None:
8583
for tool_call in message.tool_calls:
86-
tool_result = self.run_capability_json(message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments)
84+
tool_result = self._capabilities.run_capability_json(message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments)
8785
self._prompt_history.append(tool_message(tool_result, tool_call.id))
8886

8987
return self._all_flags_found
90-
91-
92-
@use_case("Minimal implementation of a web testing use case while allowing the llm to 'talk'")
93-
class WebTestingWithExplanationUseCase(AutonomousAgentUseCase[WebTestingWithExplanation]):
94-
pass

0 commit comments

Comments
 (0)