diff --git a/.gitignore b/.gitignore index 5b8b06cb..04fa677a 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,10 @@ scripts/mac_ansible_hosts.ini scripts/mac_ansible_id_rsa scripts/mac_ansible_id_rsa.pub .aider* + +src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_spec/ +src/hackingBuddyGPT/usecases/web_api_testing/documentation/reports/ +src/hackingBuddyGPT/usecases/web_api_testing/retrieve_spotify_token.py +config/my_configs/* +config/configs/* +config/configs/ \ No newline at end of file diff --git a/README.md b/README.md index b3f828a0..5cfed5e5 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,11 @@ HackingBuddyGPT helps security researchers use LLMs to discover new attack vectors and save the world (or earn bug bounties) in 50 lines of code or less. In the long run, we hope to make the world a safer place by empowering security professionals to get more hacking done by using AI. The more testing they can do, the safer all of us will get. +**🆕 New Feature**: hackingBuddyGPT now supports both SSH connections to remote targets and local shell execution for easier testing and development! + +**⚠️ WARNING**: This software will execute commands on live environments. When using local shell mode, commands will be executed on your local system, which could potentially lead to data loss, system modification, or security vulnerabilities. Always use appropriate precautions and consider using isolated environments or virtual machines for testing. + + We aim to become **THE go-to framework for security researchers** and pen-testers interested in using LLMs or LLM-based autonomous agents for security testing. To aid their experiments, we also offer re-usable [linux priv-esc benchmarks](https://github.com/ipa-lab/benchmark-privesc-linux) and publish all our findings as open-access reports. If you want to use hackingBuddyGPT and need help selecting the best LLM for your tasks, [we have a paper comparing multiple LLMs](https://arxiv.org/abs/2310.11409). @@ -68,10 +73,11 @@ the use of LLMs for web penetration-testing and web api testing. | Name | Description | Screenshot | |------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | [minimal](https://docs.hackingbuddy.ai/docs/dev-guide/dev-quickstart) | A minimal 50 LoC Linux Priv-Esc example. This is the usecase from [Build your own Agent/Usecase](#build-your-own-agentusecase) | ![A very minimal run](https://docs.hackingbuddy.ai/run_archive/2024-04-29_minimal.png) | -| [linux-privesc](https://docs.hackingbuddy.ai/docs/usecases/linux-priv-esc) | Given an SSH-connection for a low-privilege user, task the LLM to become the root user. This would be a typical Linux privilege escalation attack. 
We published two academic papers about this: [paper #1](https://arxiv.org/abs/2308.00121) and [paper #2](https://arxiv.org/abs/2310.11409) | ![Example wintermute run](https://docs.hackingbuddy.ai/run_archive/2024-04-06_linux.png) |
+| [linux-privesc](https://docs.hackingbuddy.ai/docs/usecases/linux-priv-esc) | Given a connection (SSH or local shell) for a low-privilege user, task the LLM to become the root user. This would be a typical Linux privilege escalation attack. We published two academic papers about this: [paper #1](https://arxiv.org/abs/2308.00121) and [paper #2](https://arxiv.org/abs/2310.11409) | ![Example wintermute run](https://docs.hackingbuddy.ai/run_archive/2024-04-06_linux.png) |
 | [web-pentest (WIP)](https://docs.hackingbuddy.ai/docs/usecases/web) | Directly hack a webpage. Currently in heavy development and pre-alpha stage. | ![Test Run for a simple Blog Page](https://docs.hackingbuddy.ai/run_archive/2024-05-03_web.png) |
 | [web-api-pentest (WIP)](https://docs.hackingbuddy.ai/docs/usecases/web-api) | Directly test a REST API. Currently in heavy development and pre-alpha stage. (Documentation and testing of REST API.) | Documentation:![web_api_documentation.png](https://docs.hackingbuddy.ai/run_archive/2024-05-15_web-api_documentation.png) Testing:![web_api_testing.png](https://docs.hackingbuddy.ai/run_archive/2024-05-15_web-api.png) |
-| [extended linux-privesc](https://docs.hackingbuddy.ai/docs/usecases/extended-linux-privesc) | This usecases extends linux-privesc with additional features such as retrieval augmented generation (RAG) or chain-of-thought (CoT) | ![Extended Linux Privilege Escalation Run](https://docs.hackingbuddy.ai/run_archive/2025-4-14_extended_privesc_usecase_1.png) ![Extended Linux Privilege Escalation Run](https://docs.hackingbuddy.ai/run_archive/2025-4-14_extended_privesc_usecase_1.png) |
+| [extended linux-privesc](https://docs.hackingbuddy.ai/docs/usecases/extended-linux-privesc) | This usecase extends linux-privesc with additional features such as retrieval augmented generation (RAG) or chain-of-thought (CoT). | ![Extended Linux Privilege Escalation Run](https://docs.hackingbuddy.ai/run_archive/2025-4-14_extended_privesc_usecase_1.png) ![Extended Linux Privilege Escalation Run](https://docs.hackingbuddy.ai/run_archive/2025-4-14_extended_privesc_usecase_2.png) |
+
 ## Build your own Agent/Usecase
 
 So you want to create your own LLM hacking agent? We've got you covered and taken care of the tedious groundwork.
@@ -79,7 +85,7 @@ So you want to create your own LLM hacking agent? We've got you covered and take
 Create a new usecase and implement `perform_round` containing all system/LLM interactions. We provide multiple helper and base classes so that a new experiment can be implemented in a few dozen lines of code. Tedious tasks, such as connecting to the LLM, logging, etc. are taken care of by our framework. Check our [developer quickstart quide](https://docs.hackingbuddy.ai/docs/dev-guide/dev-quickstart) for more information.
 
-The following would create a new (minimal) linux privilege-escalation agent. Through using our infrastructure, this already uses configurable LLM-connections (e.g., for testing OpenAI or locally run LLMs), logs trace data to a local sqlite database for each run, implements a round limit (after which the agent will stop if root has not been achieved until then) and can connect to a linux target over SSH for fully-autonomous command execution (as well as password guessing). 
+The following would create a new (minimal) linux privilege-escalation agent. Through using our infrastructure, this already uses configurable LLM-connections (e.g., for testing OpenAI or locally run LLMs), logs trace data to a local sqlite database for each run, implements a round limit (after which the agent will stop if root has not been achieved until then) and can connect to a target system either locally or over SSH for fully-autonomous command execution (as well as password guessing). ~~~ python template_dir = pathlib.Path(__file__).parent @@ -155,7 +161,9 @@ We try to keep our python dependencies as light as possible. This should allow f 1. an OpenAI API account, you can find the needed keys [in your account page](https://platform.openai.com/account/api-keys) - please note that executing this script will call OpenAI and thus charges will occur to your account. Please keep track of those. -2. a potential target that is accessible over SSH. You can either use a deliberately vulnerable machine such as [Lin.Security.1](https://www.vulnhub.com/entry/) or a security benchmark such as our [linux priv-esc benchmark](https://github.com/ipa-lab/benchmark-privesc-linux). +2. a target environment to test against. You have two options: + - **Local Shell**: Use your local system (useful for testing and development) + - **SSH Target**: A remote machine accessible over SSH. You can use a deliberately vulnerable machine such as [Lin.Security.1](https://www.vulnhub.com/entry/) or a security benchmark such as our [linux priv-esc benchmark](https://github.com/ipa-lab/benchmark-privesc-linux). To get everything up and running, clone the repo, download requirements, setup API keys and credentials, and start `wintermute.py`: @@ -229,11 +237,45 @@ usage: src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc [--help] [--config con --conn.port='2222' (default from .env file, alternatives: 22 from builtin) ``` -### Provide a Target Machine over SSH +### Connection Options: Local Shell vs SSH + +hackingBuddyGPT now supports two connection modes: + +#### Local Shell Mode +Use your local system for testing and development. This is useful for quick experimentation without needing a separate target machine. + +**Setup Steps:** +1. First, create a new tmux session with a specific name: + ```bash + $ tmux new-session -s + ``` + +2. Once you have the tmux shell running, use hackingBuddyGPT to interact with it: + ```bash + # Local shell with tmux session + $ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --conn=local_shell --conn.tmux_session= + ``` + +**Example:** +```bash +# Step 1: Create tmux session named "hacking_session" +$ tmux new-session -s hacking_session + +# Step 2: In another terminal, run hackingBuddyGPT +$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --conn=local_shell --conn.tmux_session=hacking_session +``` + +#### SSH Mode +Connect to a remote target machine over SSH. This is the traditional mode for testing against vulnerable VMs. + +```bash +# SSH connection (note the updated format with --conn=ssh) +$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --conn=ssh --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 +``` -The next important part is having a machine that we can run our agent against. In our case, the target machine will be situated at `192.168.122.151`. +When using SSH mode, the target machine should be situated at your specified IP address (e.g., `192.168.122.151` in the example above). 
-We are using vulnerable Linux systems running in Virtual Machines for this. Never run this against real systems. +We are using vulnerable Linux systems running in Virtual Machines for SSH testing. Never run this against real production systems. > 💡 **We also provide vulnerable machines!** > @@ -277,9 +319,13 @@ Finally we can run hackingBuddyGPT against our provided test VM. Enjoy! With that out of the way, let's look at an example hackingBuddyGPT run. Each run is structured in rounds. At the start of each round, hackingBuddyGPT asks a LLM for the next command to execute (e.g., `whoami`) for the first round. It then executes that command on the virtual machine, prints its output and starts a new round (in which it also includes the output of prior rounds) until it reaches step number 10 or becomes root: ```bash -# start wintermute, i.e., attack the configured virtual machine -$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --llm.api_key=sk...ChangeMeToYourOpenAiApiKey --llm.model=gpt-4-turbo --llm.context_size=8192 --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1 +# Example 1: Using local shell with tmux session +# First create the tmux session: tmux new-session -s hacking_session +# Then run hackingBuddyGPT: +$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --llm.api_key=sk...ChangeMeToYourOpenAiApiKey --llm.model=gpt-4-turbo --llm.context_size=8192 --conn=local_shell --conn.tmux_session=hacking_session +# Example 2: Using SSH connection (updated format) +$ python src/hackingBuddyGPT/cli/wintermute.py LinuxPrivesc --llm.api_key=sk...ChangeMeToYourOpenAiApiKey --llm.model=gpt-4-turbo --llm.context_size=8192 --conn=ssh --conn.host=192.168.122.151 --conn.username=lowpriv --conn.password=trustno1 --conn.hostname=test1 # install dependencies for testing if you want to run the tests $ pip install '.[testing]' diff --git a/pyproject.toml b/pyproject.toml index 516876ee..d0961ad1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ description = "Helping Ethical Hackers use LLMs in 50 lines of code" readme = "README.md" keywords = ["hacking", "pen-testing", "LLM", "AI", "agent"] requires-python = ">=3.10" -version = "0.4.0" +version = "0.5.0" license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", @@ -45,11 +45,15 @@ dependencies = [ 'uvicorn[standard] == 0.30.6', 'dataclasses_json == 0.6.7', 'websockets == 13.1', - 'langchain-community', - 'langchain-openai', + 'pandas', + 'faker', + 'fpdf', + 'langchain_core', + 'langchain_community', + 'langchain_chroma', + 'langchain_openai', 'markdown', 'chromadb', - 'langchain-chroma', ] [project.urls] @@ -69,7 +73,7 @@ where = ["src"] pythonpath = "src" addopts = ["--import-mode=importlib"] [project.optional-dependencies] -testing = ['pytest', 'pytest-mock'] +testing = ['pytest', 'pytest-mock', 'pandas', 'faker', 'langchain_core'] dev = [ 'ruff', ] diff --git a/src/hackingBuddyGPT/capabilities/http_request.py b/src/hackingBuddyGPT/capabilities/http_request.py index e1eb7622..874cf253 100644 --- a/src/hackingBuddyGPT/capabilities/http_request.py +++ b/src/hackingBuddyGPT/capabilities/http_request.py @@ -45,18 +45,11 @@ def __call__( body_is_base64: Optional[bool] = False, headers: Optional[Dict[str, str]] = None, ) -> str: + if body is not None and body_is_base64: body = base64.b64decode(body).decode() - if self.host[-1] != "/": + if self.host[-1] != "/" and not path.startswith("/"): path = "/" + path - resp = self._client.request( - method, 
-            self.host + path,
-            params=query,
-            data=body,
-            headers=headers,
-            allow_redirects=self.follow_redirects,
-        )
         try:
             resp = self._client.request(
                 method,
diff --git a/src/hackingBuddyGPT/capabilities/local_shell.py b/src/hackingBuddyGPT/capabilities/local_shell.py
new file mode 100644
index 00000000..4e90754a
--- /dev/null
+++ b/src/hackingBuddyGPT/capabilities/local_shell.py
@@ -0,0 +1,37 @@
+import re
+from dataclasses import dataclass
+from typing import Tuple
+
+from hackingBuddyGPT.capabilities import Capability
+from hackingBuddyGPT.utils.local_shell import LocalShellConnection
+
+
+@dataclass
+class LocalShellCapability(Capability):
+    conn: LocalShellConnection
+
+    def describe(self) -> str:
+        return "give a command to be executed and I will respond with the terminal output when running this command on the shell via tmux. The given command must not require user interaction. Do not use quotation marks in front and after your command."
+
+    def get_name(self):
+        return "local_exec"
+
+    def _got_root(self, output: str) -> bool:
+        """Check if we got root access based on the command output."""
+        if not output.strip():
+            return False
+
+        lines = output.strip().split('\n')
+        last_line = lines[-1] if lines else ''
+
+        # Check for common root indicators
+        return (
+            "root" in output.lower() or
+            last_line.strip().endswith("#") or
+            "root@" in last_line or
+            last_line.strip() == "#"
+        )
+
+    def __call__(self, cmd: str) -> Tuple[str, bool]:
+        out, _, _ = self.conn.run(cmd)  # conn.run() returns a 3-tuple; only the raw output is needed here
+        return out, self._got_root(out)
\ No newline at end of file
diff --git a/src/hackingBuddyGPT/capabilities/parsed_information.py b/src/hackingBuddyGPT/capabilities/parsed_information.py
new file mode 100644
index 00000000..ece638e2
--- /dev/null
+++ b/src/hackingBuddyGPT/capabilities/parsed_information.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass, field
+from typing import Dict, Any, List, Tuple
+from hackingBuddyGPT.capabilities import Capability
+
+
+@dataclass
+class ParsedInformation(Capability):
+    status_code: str
+    reason_phrase: Dict[str, Any] = field(default_factory=dict)
+    headers: Dict[str, Any] = field(default_factory=dict)
+    response_body: Dict[str, Any] = field(default_factory=dict)
+    registry: List[Tuple[str, str, str, str]] = field(default_factory=list)
+
+    def describe(self) -> str:
+        """
+        Returns a description of the parsed HTTP response information.
+ """ + return f"Parsed information for {self.status_code}, reason_phrase: {self.reason_phrase}, headers: {self.headers}, response_body: {self.response_body} " + def __call__(self, status_code: str, reason_phrase: str, headers: str, response_body:str) -> dict: + self.registry.append((status_code, response_body, headers,response_body)) + + return {"status_code": status_code, "reason_phrase": reason_phrase, "headers": headers, "response_body": response_body} diff --git a/src/hackingBuddyGPT/capabilities/python_test_case.py b/src/hackingBuddyGPT/capabilities/python_test_case.py new file mode 100644 index 00000000..f6b2dc8e --- /dev/null +++ b/src/hackingBuddyGPT/capabilities/python_test_case.py @@ -0,0 +1,22 @@ + +from hackingBuddyGPT.capabilities import Capability + + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple + +@dataclass +class PythonTestCase(Capability): + description: str + input: Dict[str, Any] = field(default_factory=dict) + expected_output: Dict[str, Any] = field(default_factory=dict) + registry: List[Tuple[str, dict, dict]] = field(default_factory=list) + + def describe(self) -> str: + """ + Returns a description of the test case. + """ + return f"Test Case: {self.description}\nInput: {self.input}\nExpected Output: {self.expected_output}" + def __call__(self, description: str, input: dict, expected_output: dict) -> dict: + self.registry.append((description, input, expected_output)) + return {"description": description, "input": input, "expected_output": expected_output} diff --git a/src/hackingBuddyGPT/usecases/privesc/linux.py b/src/hackingBuddyGPT/usecases/privesc/linux.py old mode 100644 new mode 100755 index 7b9228e6..38a2d755 --- a/src/hackingBuddyGPT/usecases/privesc/linux.py +++ b/src/hackingBuddyGPT/usecases/privesc/linux.py @@ -1,18 +1,24 @@ from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential +from hackingBuddyGPT.capabilities.local_shell import LocalShellCapability from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case from hackingBuddyGPT.utils import SSHConnection - +from hackingBuddyGPT.utils.local_shell import LocalShellConnection +from typing import Union from .common import Privesc class LinuxPrivesc(Privesc): - conn: SSHConnection = None + conn: Union[SSHConnection, LocalShellConnection] = None system: str = "linux" def init(self): super().init() - self.add_capability(SSHRunCommand(conn=self.conn), default=True) - self.add_capability(SSHTestCredential(conn=self.conn)) + if isinstance(self.conn, LocalShellConnection): + self.add_capability(LocalShellCapability(conn=self.conn), default=True) + self.add_capability(SSHTestCredential(conn=self.conn)) + else: + self.add_capability(SSHRunCommand(conn=self.conn), default=True) + self.add_capability(SSHTestCredential(conn=self.conn)) @use_case("Linux Privilege Escalation") diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py index bae1cbfc..42edb2bd 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/__init__.py @@ -1,2 +1,5 @@ from .simple_openapi_documentation import SimpleWebAPIDocumentation from .simple_web_api_testing import SimpleWebAPITesting +from . import response_processing +from . import documentation +from . 
import testing diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py index 3e9d7059..25482ad0 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/openapi_specification_handler.py @@ -1,15 +1,13 @@ import os +import re from collections import defaultdict from datetime import datetime - -import pydantic_core import yaml -from rich.panel import Panel - from hackingBuddyGPT.capabilities.yamlFile import YAMLFile +from hackingBuddyGPT.usecases.web_api_testing.documentation.pattern_matcher import PatternMatcher +from hackingBuddyGPT.utils.prompt_generation.information import PromptStrategy from hackingBuddyGPT.usecases.web_api_testing.response_processing import ResponseHandler from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler -from hackingBuddyGPT.utils import tool_message class OpenAPISpecificationHandler(object): @@ -28,39 +26,83 @@ class OpenAPISpecificationHandler(object): _capabilities (dict): A dictionary to store capabilities related to YAML file handling. """ - def __init__(self, llm_handler: LLMHandler, response_handler: ResponseHandler): + def __init__(self, llm_handler: LLMHandler, response_handler: ResponseHandler, strategy: PromptStrategy, url: str, + description: str, name: str) -> None: """ Initializes the handler with a template OpenAPI specification. Args: llm_handler (object): An instance of the LLM handler for interacting with the LLM. response_handler (object): An instance of the response handler for processing API responses. + strategy (PromptStrategy): An instance of the PromptStrategy class. """ + self.unsuccessful_methods = {} self.response_handler = response_handler self.schemas = {} + self.query_params = {} self.endpoint_methods = {} - self.filename = f"openapi_spec_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml" + self.endpoint_examples = {} + date = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + self.filename = f"{name}_spec.yaml" self.openapi_spec = { "openapi": "3.0.0", "info": { - "title": "Generated API Documentation", + "title": f"Generated API Documentation {name}", "version": "1.0", - "description": "Automatically generated description of the API.", + "description": f"{description} + \nUrl:{url}", }, - "servers": [{"url": "https://jsonplaceholder.typicode.com"}], + "servers": [{"url": f"{url}"}], # https://jsonplaceholder.typicode.com "endpoints": {}, "components": {"schemas": {}}, } self.llm_handler = llm_handler current_path = os.path.dirname(os.path.abspath(__file__)) - self.file_path = os.path.join(current_path, "openapi_spec") + + self.file_path = os.path.join(current_path, "openapi_spec", str(strategy).split(".")[1].lower(), name.lower(), date) + os.makedirs(self.file_path, exist_ok=True) self.file = os.path.join(self.file_path, self.filename) + self._capabilities = {"yaml": YAMLFile()} + self.unsuccessful_paths = [] + + self.pattern_matcher = PatternMatcher() def is_partial_match(self, element, string_list): - return any(element in string or string in element for string in string_list) + """ + Checks if the given path `element` partially matches any path in `string_list`, + treating path parameters (e.g., `{id}`) as wildcards. + + A partial match is defined as: + - Having the same number of path segments. + - Matching all static segments (segments not wrapped in `{}`). 
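+
+        For example (illustrative paths), "/users/123" is a partial match for "/users/{id}"
+        (same number of segments, all static segments equal), while "/users/123/comments" is not.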
+ + This is useful when comparing generalized OpenAPI paths with actual request paths. - def update_openapi_spec(self, resp, result): + Args: + element (str): The path to check for partial matches (e.g., "/users/123"). + string_list (List[str]): A list of known paths (e.g., ["/users/{id}", "/posts/{postId}"]). + + Returns: + bool: True if a partial match is found, False otherwise. + """ + element_parts = element.split("/") + + for string in string_list: + string_parts = string.split("/") + if len(element_parts) != len(string_parts): + continue # Skip if structure differs + + for e_part, s_part in zip(element_parts, string_parts): + if s_part.startswith("{") and s_part.endswith("}"): + continue # Skip placeholders + if e_part != s_part: + break # No match + else: + return True # All static parts matched + + return False + + def update_openapi_spec(self, resp, result, prompt_engineer): """ Updates the OpenAPI specification based on the API response provided. @@ -69,54 +111,124 @@ def update_openapi_spec(self, resp, result): result (str): The result of the API call. """ request = resp.action + status_code, status_message = self.extract_status_code_and_message(result) if request.__class__.__name__ == "RecordNote": # TODO: check why isinstance does not work - self.check_openapi_spec(resp) - elif request.__class__.__name__ == "HTTPRequest": + # self.check_openapi_spec(resp) + return list(self.openapi_spec["endpoints"].keys()) + + if request.__class__.__name__ == "HTTPRequest": path = request.path method = request.method - print(f"method: {method}") - # Ensure that path and method are not None and method has no numeric characters - # Ensure path and method are valid and method has no numeric characters - if path and method: - endpoint_methods = self.endpoint_methods - endpoints = self.openapi_spec["endpoints"] - x = path.split("/")[1] - - # Initialize the path if not already present - if path not in endpoints and x != "": - endpoints[path] = {} - if "1" not in path: - endpoint_methods[path] = [] - - # Update the method description within the path - example, reference, self.openapi_spec = self.response_handler.parse_http_response_to_openapi_example( - self.openapi_spec, result, path, method - ) - self.schemas = self.openapi_spec["components"]["schemas"] - - if example or reference: - endpoints[path][method.lower()] = { - "summary": f"{method} operation on {path}", - "responses": { - "200": { - "description": "Successful response", - "content": {"application/json": {"schema": {"$ref": reference}, "examples": example}}, - } - }, - } + path = self.replace_id_with_placeholder(path, prompt_engineer) + if not path or not method or path == "/" or not path.startswith("/"): + return list(self.openapi_spec["endpoints"].keys()) + + # replace specific values with generic values for doc + path = self.pattern_matcher.replace_according_to_pattern(path) + + if path in self.unsuccessful_paths: + return list(self.openapi_spec["endpoints"].keys()) - if "1" not in path and x != "": - endpoint_methods[path].append(method) - elif self.is_partial_match(x, endpoints.keys()): - path = f"/{x}" - print(f"endpoint methods = {endpoint_methods}") - print(f"new path:{path}") - endpoint_methods[path].append(method) + endpoint_methods = self.endpoint_methods + endpoints = self.openapi_spec["endpoints"] - endpoint_methods[path] = list(set(endpoint_methods[path])) + # Extract the main part of the path for checking partial matches + path_parts = path.split("/") + main_path = path if len(path_parts) > 1 else "" + + # Initialize the 
path if it's not present and is valid + if status_code.startswith("20"): + if path not in endpoints and "?" not in path: + endpoints[path] = {} + endpoint_methods[path] = [] + self.endpoint_examples[path] = {} + + unsuccessful_status_codes = ["400", "404", "500"] + + if path in endpoints and (status_code in unsuccessful_status_codes): + + self.unsuccessful_paths.append(path) + if path not in self.unsuccessful_methods: + self.unsuccessful_methods[path] = [] + self.unsuccessful_methods[path].append(method) + return list(self.openapi_spec["endpoints"].keys()) + + # Parse the response into OpenAPI example and reference + example, reference, self.openapi_spec = self.response_handler.parse_http_response_to_openapi_example( + self.openapi_spec, result, path, method + ) + + self.schemas = self.openapi_spec["components"]["schemas"] + + # Check if the path exists in the dictionary and the method is not already defined for this path + if path in endpoints and method.lower() not in endpoints[path]: + # Create a new dictionary for this method if it doesn't exist + endpoints[path][method.lower()] = { + "summary": f"{method} operation on {path}", + "responses": { + f"{status_code}": { + "description": status_message, + "content": {} + } + } + } + + if path in endpoint_methods: + endpoint_methods[path] = [] + + # Update endpoint methods for the path + if path not in endpoint_methods: + endpoint_methods[path] = [] + endpoint_methods[path].append(method) + + # Ensure uniqueness of methods for each path + endpoint_methods[path] = list(set(endpoint_methods[path])) + + # Check if there's a need to add or update the 'content' based on the conditions provided + if example or reference or status_message == "No Content" and not path.__contains__("?"): + if isinstance(example, list): + example = example[0] + # Ensure the path and method exists and has the 'responses' structure + if (path in endpoints and method.lower() in endpoints[path]): + if "responses" in endpoints[path][method.lower()].keys() and f"{status_code}" in \ + endpoints[path][method.lower()]["responses"]: + # Get the response content dictionary + response_content = endpoints[path][method.lower()]["responses"][f"{status_code}"]["content"] + + # Assign a new structure to 'content' under the specific status code + response_content["application/json"] = { + "schema": {"$ref": reference}, + "examples": example + } + + self.endpoint_examples[path] = example + + # Add query parameters to the OpenAPI path item object + if path.__contains__('?'): + query_params_dict = self.pattern_matcher.extract_query_params(path) + new_path = path.split("?")[0] + if query_params_dict != {}: + if path not in endpoints.keys(): + endpoints[new_path] = {} + if method.lower() not in endpoints[new_path]: + endpoints[new_path][method.lower()] = {} + endpoints[new_path][method.lower()].setdefault('parameters', []) + for param, value in query_params_dict.items(): + param_entry = { + "name": param, + "in": "query", + "required": True, # Change this as needed + "schema": { + "type": self.get_type(value) # Adjust the type based on actual data type + } + } + endpoints[new_path][method.lower()]['parameters'].append(param_entry) + if path not in self.query_params.keys(): + self.query_params[new_path] = [] + self.query_params[new_path].append(param) - return list(endpoints.keys()) + return list(self.openapi_spec["endpoints"].keys()) def write_openapi_to_yaml(self): """ @@ -142,25 +254,32 @@ def write_openapi_to_yaml(self): except Exception as e: raise Exception(f"Error writing YAML file: 
{e}") from e - def check_openapi_spec(self, note): + def _update_documentation(self, response, result, result_str, prompt_engineer): """ - Uses OpenAI's GPT model to generate a complete OpenAPI specification based on a natural language description. + Updates the OpenAPI documentation based on a new API response and result string. - Args: - note (object): The note object containing the description of the API. - """ - description = self.response_handler.extract_description(note) - from hackingBuddyGPT.usecases.web_api_testing.utils.documentation.parsing.yaml_assistant import ( - YamlFileAssistant, - ) + This method performs the following: + - Updates the OpenAPI specification using the latest API response. + - Writes the updated OpenAPI spec to a YAML file if new endpoints are discovered. + - Updates the schemas used by the `prompt_engineer`. + - Constructs a mapping of HTTP methods to endpoints and stores it in the prompt helper. - yaml_file_assistant = YamlFileAssistant(self.file_path, self.llm_handler) - yaml_file_assistant.run(description) + Args: + response (Any): The original API response object, possibly including metadata or status. + result (Any): The raw result of executing the API call, potentially including headers and body. + result_str (str): A string representation of the HTTP response, including status line and body. + prompt_engineer (PromptEngineer): An instance of the prompt engineer responsible for generating prompts and managing discovered schema information. - def _update_documentation(self, response, result, prompt_engineer): - prompt_engineer.prompt_helper.found_endpoints = self.update_openapi_spec(response, result) - self.write_openapi_to_yaml() - prompt_engineer.prompt_helper.schemas = self.schemas + Returns: + PromptEngineer: The updated prompt engineer with any new endpoint or schema information applied. + + """ + if result_str is None: + return prompt_engineer + endpoints = self.update_openapi_spec(response, result, prompt_engineer) + if prompt_engineer.prompt_helper.found_endpoints != endpoints and endpoints != [] and len(endpoints) != 1: + self.write_openapi_to_yaml() + prompt_engineer.prompt_helper.schemas = self.schemas http_methods_dict = defaultdict(list) for endpoint, methods in self.endpoint_methods.items(): @@ -171,28 +290,151 @@ def _update_documentation(self, response, result, prompt_engineer): prompt_engineer.prompt_helper.endpoint_methods = self.endpoint_methods return prompt_engineer - def document_response(self, completion, response, log, prompt_history, prompt_engineer): - message = completion.choices[0].message - tool_call_id = message.tool_calls[0].id - command = pydantic_core.to_json(response).decode() + def document_response(self, result, response, result_str, prompt_history, prompt_engineer): + """ + Processes an API response and updates the OpenAPI documentation if the response is valid. + + This method filters out invalid or placeholder responses using a set of known flags. + If the response appears valid, it triggers the `_update_documentation()` method + to update the OpenAPI spec and associated prompt engineering logic. - log.console.print(Panel(command, title="assistant")) - prompt_history.append(message) + Args: + result (Any): The raw execution result, typically the HTTP response body or object. + response (Any): The full API response object, potentially containing metadata or headers. + result_str (str): A string representation of the HTTP response for validation and parsing. 
+ prompt_history (Any): The accumulated history of prompt interactions. + prompt_engineer (PromptEngineer): Instance responsible for generating and managing prompts. - with log.console.status("[bold green]Executing that command..."): - result = response.execute() - log.console.print(Panel(result[:30], title="tool")) - result_str = self.response_handler.parse_http_status_line(result) - prompt_history.append(tool_message(result_str, tool_call_id)) + Returns: + Tuple[Any, PromptEngineer]: A tuple containing the unchanged `prompt_history` and + the (potentially updated) `prompt_engineer`. + """ - invalid_flags = {"recorded", "Not a valid HTTP method", "404", "Client Error: Not Found"} - if result_str not in invalid_flags or any(flag in result_str for flag in invalid_flags): - prompt_engineer = self._update_documentation(response, result, prompt_engineer) + invalid_flags = {"recorded"} + if result_str not in invalid_flags or any(flag in result_str for flag in invalid_flags): + prompt_engineer = self._update_documentation(response, result, result_str, prompt_engineer) - return log, prompt_history, prompt_engineer + return prompt_history, prompt_engineer def found_all_endpoints(self): + """ + Determines whether a sufficient number of API endpoints have been discovered. + + Currently, this uses a simple heuristic: if the number of endpoint-method pairs + is at least 10, it is assumed that all relevant endpoints have been found. + + Returns: + bool: True if at least 10 endpoint-method entries exist, False otherwise. + """ if len(self.endpoint_methods.items()) < 10: return False else: return True + + def get_type(self, value): + """ + Determines the data type of a given string value. + + Checks whether the input string represents an integer, a floating-point number (double), + or should be treated as a generic string. + + Args: + value (str): The value to inspect. + + Returns: + str: One of "integer", "double", or "string" depending on the detected type. + """ + + def is_double(s): + # Matches numbers like -123.456, +7.890, and excludes integers + return re.fullmatch(r"[+-]?(\d+\.\d*|\.\d+)([eE][+-]?\d+)?", s) is not None + + if value.isdigit(): + return "integer" + elif is_double(value): + return "double" + else: + return "string" + + def extract_status_code_and_message(self, result): + """ + Extracts the HTTP status code and status message from a response string. + + Args: + result (str): A string containing the full HTTP response or just the status line. + + Returns: + Tuple[Optional[str], Optional[str]]: A tuple containing the HTTP status code and message. + Returns (None, None) if the pattern is not matched. + """ + if not isinstance(result, str): + result = str(result) + match = re.search(r"^HTTP/\d\.\d\s+(\d+)\s+(.*)", result, re.MULTILINE) + if match: + status_code = match.group(1) + status_message = match.group(2).strip() + return status_code, status_message + else: + return None, None + + def replace_crypto_with_id(self, path): + """ + Replaces any known cryptocurrency name in a URL path with a placeholder `{id}`. + + Useful for generalizing dynamic paths when generating or matching OpenAPI specs. + + Args: + path (str): The URL path to process. + + Returns: + str: The path with any matching cryptocurrency name replaced by `{id}`. 
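+
+            For example (illustrative path): "/bitcoin/markets" becomes "/{id}/markets".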
+ """ + + # Default list of cryptos to detect + cryptos = ["bitcoin", "ethereum", "litecoin", "dogecoin", + "cardano", "solana"] + + # Convert to lowercase for the match, but preserve the original path for reconstruction if you prefer + lower_path = path.lower() + + for crypto in cryptos: + if crypto in lower_path: + # Example approach: split by '/' and replace the segment that matches crypto + parts = path.split('/') + replaced_any = False + for i, segment in enumerate(parts): + if segment.lower() == crypto: + parts[i] = "{id}" + if segment.lower() == crypto: + parts[i] = "{id}" + replaced_any = True + if replaced_any: + return "/".join(parts) + + return path + + def replace_id_with_placeholder(self, path, prompt_engineer): + """ + Replaces numeric IDs in the URL path with a placeholder `{id}` for generalization. + + This function is used to abstract hardcoded numeric values (e.g., `/users/1`) in the path + to a parameterized form (e.g., `/users/{id}`), which is helpful when building or inferring + OpenAPI specifications. + + Behavior varies slightly depending on the current step tracked by the `prompt_engineer`. + + Args: + path (str): The URL path to process. + prompt_engineer (PromptEngineer): An object containing context about the current prompt + and parsing state, specifically the current step of API exploration. + + Returns: + str: The updated path with numeric IDs replaced by `{id}`. + """ + if "1" in path: + path = path.replace("1", "{id}") + if prompt_engineer.prompt_helper.current_step == 2: + parts = [part.strip() for part in path.split("/") if part.strip()] + if len(parts) > 1: + path = parts[0] + "/{id}" + return path diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py index 3f1156f5..0f23465d 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_converter.py @@ -84,15 +84,78 @@ def json_to_yaml(self, json_filepath): """ return self.convert_file(json_filepath, "yaml", "json", "yaml") + def extract_openapi_info(self, openapi_spec_file, output_path=""): + """ + Extracts relevant information from an OpenAPI specification and writes it to a JSON file. + + Args: + openapi_spec (dict): The OpenAPI specification loaded as a dictionary. + output_file_path (str): Path to save the extracted information in JSON format. + + Returns: + dict: The extracted information saved in JSON format. 
+ """ + openapi_spec = json.load(open(openapi_spec_file)) + + # Extract the API description and host URL + description = openapi_spec.get("info", {}).get("description", "No description provided.") + host = openapi_spec.get("servers", [{}])[0].get("url", "No host URL provided.") + + # Extract correct endpoints and query parameters + correct_endpoints = [] + query_params = {} + + for path, path_item in openapi_spec.get("paths", {}).items(): + correct_endpoints.append(path) + # Collect query parameters for each endpoint + endpoint_query_params = [] + for method, operation in path_item.items(): + if isinstance(operation, dict): + if "parameters" in operation.keys(): + parameters = operation.get("parameters", []) + for param in parameters: + if param.get("in") == "query": + endpoint_query_params.append(param.get("name")) + + if endpoint_query_params: + query_params[path] = endpoint_query_params + + # Create the final output structure + extracted_info = { + "token": "your_api_token_here", + "host": host, + "description": description, + "correct_endpoints": correct_endpoints, + "query_params": query_params + } + filename = os.path.basename(openapi_spec_file) + filename = filename.replace("_oas", "_config") + base_name, _ = os.path.splitext(filename) + output_filename = f"{base_name}.json" + output_path = os.path.join(output_path, output_filename) + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Write to JSON file + with open(output_path, 'w') as json_file: + json.dump(extracted_info, json_file, indent=2) + print(f'output path:{output_path}') + + return extracted_info + # Usage example if __name__ == "__main__": - yaml_input = "/home/diana/Desktop/masterthesis/hackingBuddyGPT/src/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/openapi_spec_2024-06-13_17-16-25.yaml" + # yaml_input = "src/hackingBuddyGPT/usecases/web_api_testing/configs/test_config.json/hard/coincap_oas.json" converter = OpenAPISpecificationConverter("converted_files") - # Convert YAML to JSON - json_file = converter.yaml_to_json(yaml_input) - - # Convert JSON to YAML - if json_file: - converter.json_to_yaml(json_file) + ## Convert YAML to JSON + # json_file = converter.yaml_to_json(yaml_input) + # + ## Convert JSON to YAML + # if json_file: + # converter.json_to_yaml(json_file) + + openapi_path = "/home/diana/Desktop/masterthesis/00/hackingBuddyGPT/tests/test_files/oas/fakeapi_oas.json" + converter.extract_openapi_info(openapi_path, + output_path="/home/diana/Desktop/masterthesis/00/hackingBuddyGPT/tests/test_files") diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py index 815cb0c5..bd36718d 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/parsing/openapi_parser.py @@ -1,3 +1,6 @@ +import json +import os +from pathlib import Path from typing import Dict, List, Union import yaml @@ -20,17 +23,22 @@ def __init__(self, filepath: str): filepath (str): The path to the OpenAPI specification YAML file. 
""" self.filepath: str = filepath - self.api_data: Dict[str, Union[Dict, List]] = self.load_yaml() + self.api_data: Dict[str, Union[Dict, List]] = self.load_file(filepath=self.find_oas(filepath=filepath)) + self.oas_path = self.find_oas(filepath) - def load_yaml(self) -> Dict[str, Union[Dict, List]]: + def load_file(self, filepath="") -> Dict[str, Union[Dict, List]]: """ Loads YAML data from the specified file. Returns: Dict[str, Union[Dict, List]]: The parsed data from the YAML file. """ - with open(self.filepath, "r") as file: - return yaml.safe_load(file) + if filepath.endswith(".yaml"): + with open(self.filepath, "r") as file: + return yaml.safe_load(file) + else: + with open(filepath, 'r', encoding='utf-8') as file: + return json.load(file) def _get_servers(self) -> List[str]: """ @@ -41,7 +49,7 @@ def _get_servers(self) -> List[str]: """ return [server["url"] for server in self.api_data.get("servers", [])] - def get_paths(self) -> Dict[str, Dict[str, Dict]]: + def get_endpoints(self) -> Dict[str, Dict[str, Dict]]: """ Retrieves all API paths and their methods from the OpenAPI specification. @@ -64,7 +72,9 @@ def _get_operations(self, path: str) -> Dict[str, Dict]: Returns: Dict[str, Dict]: A dictionary with methods as keys and operation details as values. """ - return self.api_data["paths"].get(path, {}) + valid_methods = {"get", "post", "put", "delete", "patch", "head", "options", "trace"} + path_item = self.api_data.get("paths", {}).get(path, {}) + return {method: details for method, details in path_item.items() if method.lower() in valid_methods} def _print_api_details(self) -> None: """ @@ -81,3 +91,366 @@ def _print_api_details(self) -> None: print(f" Operation: {operation.upper()}") print(f" Summary: {details.get('summary')}") print(f" Description: {details['responses']['200']['description']}") + + def find_oas(self, filepath) -> str: + """ + + Gets the OpenAPI specification for the config + Args: + filepath (str): The config path + + Returns: + str: The OAS file path + """ + current_file_path = os.path.dirname(filepath) + + file_name = Path(filepath).name.split("_config")[0] + oas_file_path = os.path.join(current_file_path, "oas", file_name + "_oas.json") + return oas_file_path + + def get_schemas(self) -> Dict[str, Dict]: + """ + Retrieve schemas from OpenAPI JSON data. + + + Returns: + Dict[str, Dict]: A dictionary with schemas + """ + + components = self.api_data.get('components', {}) + schemas = components.get('schemas', {}) + return schemas + + def get_protected_endpoints(self) -> List: + """ + Retrieves protected endpoints from api data. + + + Returns: + List: A list of protected endpoints + """ + protected = [] + for path, operations in self.api_data['paths'].items(): + for operation, details in operations.items(): + if 'security' in details: + protected.append(f"{operation.upper()} {path}") + return protected + + def get_refresh_endpoints(self): + """ + Retrieves refresh endpoints from api data. + + + Returns: + List: A list of refresh endpoints + """ + refresh_endpoints = [] + for path, operations in self.api_data['paths'].items(): + if 'refresh' in path.lower(): + refresh_endpoints.extend([f"{op.upper()} {path}" for op in operations]) + return refresh_endpoints + + def get_schema_for_endpoint(self, path, method): + """ + Retrieve the schema for a specific endpoint method. + + Args: + path (str): The endpoint path. + method (str): The HTTP method (e.g., 'get', 'post'). + + Returns: + dict: The schema for the requestBody, or None if not available. 
+ """ + method_details = self.api_data.get("paths", {}).get(path, {}).get(method.lower(), {}) + request_body = method_details.get("requestBody", {}) + + # Safely get the schema + content = request_body.get("content", {}) + application_json = content.get("application/json", {}) + schema = application_json.get("schema", None) + schema_ref = None + + if schema and isinstance(schema, dict): + schema_ref = schema.get("$ref", None) + + schemas = self.get_schemas() + correct_schema = None + if schema_ref is not None: + ref_list = schema_ref.split("/") + for schema in schemas: + if schema in ref_list: + correct_schema = schemas.get(schema) + return correct_schema + + return None + + def classify_endpoints(self, name=""): + """ + Classifies API endpoints into various security and functionality categories based on heuristics + such as URL patterns, HTTP methods, response codes, descriptions, and security settings. + + This method processes all endpoints defined in `self.api_data['paths']` and assigns them + into predefined classes including public, protected, resource-intensive, login, and others. + Classifications are based on path structure, method, presence of authentication/authorization, + keywords in the path or description, and response status codes. + + Args: + name (str, optional): An optional string (e.g., test name or profile name) that can + influence the classification of certain endpoints (e.g., skipping account creation + classification for specific OWASP test cases). Defaults to "". + + Returns: + dict: A dictionary containing classified endpoints under the following keys: + - 'resource_intensive_endpoint': Endpoints involving batch uploads or processing. + - 'public_endpoint': Endpoints accessible without authentication. + - 'secure_action_endpoint': Endpoints performing sensitive operations. + - 'role_access_endpoint': Endpoints involving user/admin roles. + - 'sensitive_data_endpoint': Endpoints returning sensitive/confidential data. + - 'sensitive_action_endpoint': Endpoints performing critical modifications. + - 'protected_endpoint': Endpoints requiring authentication. + - 'refresh_endpoint': Endpoints related to token/session refreshing. + - 'login_endpoint': Endpoints used for user login or sign-in. + - 'authentication_endpoint': Endpoints dealing with authentication or token handling. + - 'account_creation': Endpoints related to user account creation. + - 'unclassified_endpoint': Endpoints that do not match any specific classification. 
+ """ + classifications = { + 'resource_intensive_endpoint': [], + 'public_endpoint': [], + 'secure_action_endpoint': [], + 'role_access_endpoint': [], + 'sensitive_data_endpoint': [], + 'sensitive_action_endpoint': [], + 'protected_endpoint': [], + 'refresh_endpoint': [], + 'login_endpoint': [], + 'authentication_endpoint': [], + 'unclassified_endpoint': [], + 'account_creation': [] + } + + for path, path_item in self.api_data['paths'].items(): + for method, operation in path_item.items(): + schema = self.get_schema_for_endpoint(path, method) + if method == 'get' and schema == None and "parameters" in operation.keys() and len( + operation.get("parameters", [])) > 0: + schema = operation.get("parameters")[0] + classified = False + parameters = operation.get("parameters", []) + description = operation.get('description', '').lower() + security = operation.get('security', {}) + responses = operation.get("responses", {}) + unauthorized_description = responses.get("401", {}).get("description", "").lower() + forbidden_description = responses.get("403", {}).get("description", "").lower() + too_many_requests_description = responses.get("429", {}).get("description", "").lower() + + if "dashboard" in path: + classifications['unclassified_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + continue + + # Protected endpoints: Paths mentioning "user" or "admin" explicitly + # Check if the path mentions "user" or "admin" and doesn't include "api" + path_condition = ( + any(keyword in path for keyword in ["user", "admin"]) + and not any(keyword in path for keyword in ["api"]) + ) + + # Check if any parameter's value equals "Authorization-Token" + parameter_condition = any( + param.get("name") == "Authorization-Token" for param in parameters + ) + + auth_condition = 'Unauthorized' in unauthorized_description or "forbidden" in forbidden_description + + # Combined condition with `security` (adjust based on actual schema requirements) + if (path_condition or parameter_condition or auth_condition) or security: + classifications['protected_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Public endpoint: No '401 Unauthorized' response or description doesn't mention 'unauthorized' + if ('Unauthorized' not in unauthorized_description + or "forbidden" not in forbidden_description + or "too many requests" not in too_many_requests_description + and not security): + classifications['public_endpoint'].append( + { + "method": method.upper(), + "path": path, + "schema": schema} + ) + classified = True + + # User creation endpoint + if any(keyword in path.lower() for keyword in + ['user', 'users', 'signup']) and not "login" in path or any( + word in description for word in ['create a user']): + + if path.lower().endswith("user") and name.startswith("OWASP"): + continue + if not any(keyword in path.lower() for keyword in + ['pictures', 'verify-email-token', 'change-email', "reset", "verify", "videos", + "mechanic"]): + if method.upper() == "POST" and not "data-export" in path: + classifications["account_creation"].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Secure action endpoints: Identified by roles or protected access + if any(keyword in path.lower() for keyword in ["user", "admin"]): + classifications['role_access_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Sensitive data or action 
endpoints: Based on description + if any(word in description for word in ['sensitive', 'confidential']): + classifications['sensitive_data_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + if any(word in description for word in ['delete', 'modify', 'change']): + classifications['sensitive_action_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Resource-intensive endpoints + if any(word in description for word in ['upload', 'batch', 'heavy', 'intensive']): + classifications['resource_intensive_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Rate-limited endpoints + if '429' in responses and 'too many requests' in responses['429'].get('description', '').lower(): + classifications['resource_intensive_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Refresh endpoints + if 'refresh' in path.lower() or 'refresh' in description: + classifications['refresh_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Login endpoints + if any(keyword in path.lower() for keyword in ['login', 'signin', 'sign-in']): + if method.upper() == "POST": + classifications['login_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + classified = True + + # Authentication-related endpoints + if any(keyword in path.lower() or keyword in description for keyword in + ['auth', 'authenticate', 'token', 'register']): + classifications['authentication_endpoint'].append( + { + "method": method.upper(), + "path": path, + "schema": schema} + ) + classified = True + + # Unclassified endpoints + if not classified: + if isinstance(method, dict): + for method, path in classifications.items(): # Iterate over dictionary items + # Now we can use .upper() on the 'method' string + classifications['unclassified_endpoint'].append({ + "method": method.upper(), + "path": path, + "schema": schema}) + else: + classifications['unclassified_endpoint'].append( + { + "method": method.upper(), + "path": path, + "schema": schema}) + + return classifications + + def categorize_endpoints(self, endpoints, query: dict): + """ + Categorizes a list of API endpoints based on their path structure. + + This method inspects the number of path segments in each endpoint to determine + its type (e.g., root-level, instance-level, subresource, etc.). It uses basic + heuristics, such as the presence of the keyword "id" and the number of segments + after splitting the path by slashes. + + Args: + endpoints (list): A list of API endpoint strings (e.g., ['/users', '/users/{id}']). + query (dict): A dictionary of query parameters (typically used in GET requests). + The values of this dictionary are included in the result under the 'query' key. + + Returns: + dict: A dictionary categorizing the endpoints into the following types: + - 'root_level': Endpoints with a single path segment (e.g., '/users'). + - 'instance_level': Endpoints that include one path parameter like 'id' (e.g., '/users/{id}'). + - 'subresource': Endpoints with two segments that don't include 'id' (e.g., '/users/profile'). + - 'related_resource': Endpoints with three segments including 'id' (e.g., '/users/{id}/orders'). + - 'multi-level_resource': Endpoints with more than two segments not matched by the above. 
+ - 'query': The values from the input query dictionary.""" + root_level = [] + single_parameter = [] + subresource = [] + related_resource = [] + multi_level_resource = [] + + for endpoint in endpoints: + # Split the endpoint by '/' and filter out empty strings + parts = [part for part in endpoint.split('/') if part] + + # Determine the category based on the structure + if len(parts) == 1: + root_level.append(endpoint) + elif len(parts) == 2: + if "id" in endpoint: + single_parameter.append(endpoint) + else: + subresource.append(endpoint) + elif len(parts) == 3: + if "id" in endpoint: + related_resource.append(endpoint) + else: + multi_level_resource.append(endpoint) + else: + multi_level_resource.append(endpoint) + + return { + "root_level": root_level, + "instance_level": single_parameter, + "subresource": subresource, + "query": query.values(), + "related_resource": related_resource, + "multi-level_resource": multi_level_resource, + } + + +if __name__ == "__main__": # Usage + parser = OpenAPISpecificationParser( + "/config/hard/reqres_config.json") + + endpoint_classes = parser.classify_endpoints() + for category, endpoints in endpoint_classes.items(): + print(f"{category}: {endpoints}") diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/pattern_matcher.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/pattern_matcher.py new file mode 100644 index 00000000..b9c33cbd --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/pattern_matcher.py @@ -0,0 +1,122 @@ +import re + +class PatternMatcher: + """ + A utility class for matching and manipulating URL paths using regular expressions. + + This class supports: + - Detecting specific patterns in URL paths (e.g., numeric IDs, nested resources). + - Replacing numeric IDs and query parameters with placeholders. + - Extracting query parameters into a dictionary. + """ + + def __init__(self): + """ + Initialize the PatternMatcher with predefined regex patterns. + """ + self.patterns = { + 'id': re.compile(r"/\d+"), # Matches numeric segments in paths like "/123" + 'query_params': re.compile(r"(\?|\&)([^=]+)=([^&]+)"), # Matches key=value pairs in query strings + 'numeric_resource': re.compile(r"/\w+/\d+$"), # Matches paths like "/resource/123" + 'nested_resource': re.compile(r"/\w+/\w+/\d+$") # Matches paths like "/category/resource/123" + } + + def matches_any_pattern(self, path): + """ + Check if the input path matches any of the defined regex patterns. + + Args: + path (str): The URL path to evaluate. + + Returns: + bool: True if any pattern matches; False otherwise. + """ + for name, pattern in self.patterns.items(): + if pattern.search(path): + return True + return False + + def replace_parameters(self, path, param_placeholder="{{{param}}}"): + """ + Replace numeric path segments and query parameter values with placeholders. + + Args: + path (str): The URL path to process. + param_placeholder (str): A template string for parameter placeholders (not currently used). + + Returns: + str: The transformed path with placeholders. + """ + for pattern_name, pattern in self.patterns.items(): + if 'id' in pattern_name: + # Replace numeric path segments with "/{id}" + return pattern.sub(r"/{id}", path) + + if 'query_params' in pattern_name: + # Replace query parameter values with placeholders + def replacement_logic(match): + delimiter = match.group(1) # ? 
or & + param_name = match.group(2) + param_value = match.group(3) + + # Replace value with a lowercase placeholder + new_value = f"{{{param_name.lower()}}}" + return f"{delimiter}{param_name}={new_value}" + + return pattern.sub(replacement_logic, path) + + return path + + def replace_according_to_pattern(self, path): + """ + Apply replacement logic if the path matches known patterns. + Also replaces hardcoded "/1" with "/{id}" as a fallback. + + Args: + path (str): The URL path to transform. + + Returns: + str: The transformed path. + """ + if self.matches_any_pattern(path): + return self.replace_parameters(path) + + # Fallback transformation + if "/1" in path: + path = path.replace("/1", "/{id}") + return path + + def extract_query_params(self, path): + """ + Extract query parameters from a URL into a dictionary. + + Args: + path (str): The URL containing query parameters. + + Returns: + dict: A dictionary of parameter names and values. + """ + params = {} + matches = self.patterns['query_params'].findall(path) + for _, param, value in matches: + params[param] = value + return params + + +if __name__ == "__main__": + # Example usage + matcher = PatternMatcher() + example_path = "/resource/456?param1=10&Param2=text&NumValue=123456" + example_nested_path = "/category/resource/789?detail=42&Info2=moreText" + + # Replace parameters in paths + modified_path = matcher.replace_parameters(example_path) + modified_nested_path = matcher.replace_parameters(example_nested_path) + + print(modified_path) + print(modified_nested_path) + print(f'Original path: {example_path}') + print(f'Extracted parameters: {matcher.extract_query_params(example_path)}') + + print(f'Original nested path: {example_nested_path}') + print(f'Extracted parameters: {matcher.extract_query_params(example_nested_path)}') diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py index 6c10f88d..e747ac09 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/documentation/report_handler.py @@ -1,67 +1,262 @@ +import json import os +import re +import textwrap import uuid from datetime import datetime from enum import Enum from typing import List - +from fpdf import FPDF class ReportHandler: """ - A handler for creating and managing report files that document operations and data. + A handler for creating and managing reports during automated web API testing. + + This class creates both text and PDF reports documenting tested endpoints, analysis results, + and any vulnerabilities discovered based on HTTP responses. Attributes: - file_path (str): The path to the directory where report files are stored. - report_name (str): The full path to the current report file being written to. - report (file): The file object for the report, opened for writing data. + file_path (str): Path to the directory where general reports are stored. + vul_file_path (str): Path to the directory for vulnerability-specific reports. + report_name (str): Full path to the current report text file. + vul_report_name (str): Full path to the vulnerability report text file. + pdf (FPDF): An FPDF object used to generate a PDF version of the report. + vulnerabilities_counter (int): Counter tracking the number of vulnerabilities found. 
""" - def __init__(self): + def __init__(self, config): """ - Initializes the ReportHandler by setting up the file path for reports, - creating the directory if it does not exist, and preparing a new report file. + Initializes the ReportHandler, prepares report and vulnerability file paths, and creates + necessary directories and files. + + Args: + config (dict): Configuration dictionary containing metadata like the test name. """ - current_path: str = os.path.dirname(os.path.abspath(__file__)) - self.file_path: str = os.path.join(current_path, "reports") + current_path = os.path.dirname(os.path.abspath(__file__)) + self.file_path = os.path.join(current_path, "reports", config.get("name")) + self.vul_file_path = os.path.join(current_path, "vulnerabilities", config.get("name")) - if not os.path.exists(self.file_path): - os.mkdir(self.file_path) + os.makedirs(self.file_path, exist_ok=True) + os.makedirs(self.vul_file_path, exist_ok=True) - self.report_name: str = os.path.join( + self.report_name = os.path.join( self.file_path, f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt" ) + self.vul_report_name = os.path.join( + self.vul_file_path, f"vul_report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt" + ) + + self.vulnerabilities_counter = 0 + + # Initialize PDF + self.pdf = FPDF() + self.pdf.set_auto_page_break(auto=True, margin=15) + self.pdf.add_page() + self.pdf.set_font("Arial", size=12) + self.pdf.set_font("Arial", 'B', 16) + self.pdf.cell(200, 10, "Vulnerability Report", ln=True, align='C') + self.pdf_path = os.path.join( + self.vul_file_path, f"vul_report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pdf" + ) + self.y = 10 # start position + + try: self.report = open(self.report_name, "x") + self.vul_report = open(self.vul_report_name, "x") except FileExistsError: - # Retry with a different name using a UUID to ensure uniqueness self.report_name = os.path.join( self.file_path, f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}_{uuid.uuid4().hex}.txt", ) self.report = open(self.report_name, "x") + def write_line(self, text): + self.pdf.text(10, self.y, text) + self.y += 10 # move down for next line + self.pdf.output(self.pdf_path) + def write_endpoint_to_report(self, endpoint: str) -> None: """ - Writes an endpoint string to the report file. + Writes a single endpoint string to both the text and PDF reports. Args: - endpoint (str): The endpoint information to be recorded in the report. + endpoint (str): The tested endpoint. """ with open(self.report_name, "a") as report: report.write(f"{endpoint}\n") + self.pdf.set_font("Arial", size=12) + self.pdf.multi_cell(0, 10, f"Endpoint: {endpoint}") + def write_analysis_to_report(self, analysis: List[str], purpose: Enum) -> None: """ - Writes an analysis result and its purpose to the report file. + Writes analysis data to the text and PDF reports, grouped by purpose. Args: - analysis (List[str]): The analysis data to be recorded. - purpose (Enum): An enumeration that describes the purpose of the analysis. + analysis (List[str]): List of strings with analysis output. + purpose (Enum): Enum representing the analysis type or purpose. 
""" - with open(self.report_name, "a") as report: - report.write(f"{purpose.name}:\n") + try: + with open(self.report_name, 'r') as report: + content = report.read() + except FileNotFoundError: + content = "" + + if purpose.name not in content: + with open(self.report_name, 'a') as report: + report.write('-' * 90 + '\n') + report.write(f'{purpose.name}:\n') + + with open(self.report_name, 'a') as report: for item in analysis: - for line in item.split("\n"): - if "note recorded" in line: - continue - else: - report.write(line + "\n") + filtered_lines = [line for line in item.split("\n") if "note recorded" not in line] + report.write("\n".join(filtered_lines) + "\n") + + self.pdf.set_font("Arial", 'B', 12) + self.pdf.text(10, self.pdf.get_y() + 10, f"Purpose: {purpose.name}") + self.pdf.set_font("Arial", size=10) + + for item in analysis: + filtered_lines = [line for line in item.split("\n") if "note recorded" not in line] + wrapped_text = [textwrap.fill(line, width=80) for line in filtered_lines if line.strip()] + y_position = self.pdf.get_y() + 5 + for line in wrapped_text: + self.pdf.text(10, y_position, line) + y_position += 5 + self.pdf.set_y(y_position + 5) + + def save_report(self) -> None: + """ + Saves the PDF version of the report to the file system. + """ + report_name = os.path.join( + self.file_path, f"report_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pdf" + ) + self.pdf.output(report_name) + + def write_vulnerability_to_report(self, test_step, test_over_step, raw_response, current_substep): + """ + Analyzes an HTTP response and logs whether a vulnerability was detected. + + Args: + test_step (dict): Metadata about the current test step, including expected codes and messages. + raw_response (str): Full raw HTTP response string. + current_substep (str): Label or identifier for the current test substep. 
+        """
+        match = re.search(r"^HTTP/\d\.\d\s+(\d+)(?:\s+(.*))?", raw_response, re.MULTILINE)
+        if match:
+            status_code = match.group(1).strip()
+            status_message = (match.group(2) or "").strip()
+            full_status_line = f"{status_code} {status_message}".strip()
+        else:
+            status_code = None
+            full_status_line = ""
+
+        test_case_purpose = test_step.get('purpose', "Unnamed Test Case")
+        test_case_name = test_over_step.get("phase_title", "").split("Phase: ")[-1]
+        step = test_step.get('step', "No step")
+        expected = test_step.get('expected_response_code', "No expected result")
+        # Example response headers from a web server
+        response_headers = {
+            'Server': 'Apache/2.4.1',
+            'Strict-Transport-Security': 'max-age=31536000; includeSubDomains',
+            'X-Content-Type-Options': 'nosniff',
+            'Content-Security-Policy': "default-src 'self'",
+            'X-Frame-Options': 'DENY',
+            'Set-Cookie': 'sessionid=123456; HttpOnly; Secure'
+        }
+
+        # Define the security configurations we expect
+        expected_configurations = {
+            'Strict-Transport-Security': lambda value: "max-age" in value,
+            'X-Content-Type-Options': lambda value: value.lower() == 'nosniff',
+            'Content-Security-Policy': lambda value: "default-src 'self'" in value,
+            'X-Frame-Options': lambda value: value.lower() == 'deny',
+            'Set-Cookie': lambda value: 'httponly' in value.lower() and 'secure' in value.lower()
+        }
+
+        # Split the raw response once so the header and body checks below always have both parts available
+        if '\r\n\r\n' in raw_response:
+            headers, body = raw_response.split('\r\n\r\n', 1)
+        else:
+            headers, body = raw_response, ""
+
+        if "only one id" in test_step.get("security", ""):
+            parsed_body = json.loads(body)
+            if len(parsed_body) > 1:
+                self.vulnerabilities_counter += 1
+                report_line = (
+                    f"Test Purpose: {test_case_purpose}\n"
+                    f"Test Name: {test_case_name}\n"
+                    f"Step: {step}\n"
+                    f"Expected Result: Only one id returned\n"
+                    f"Actual Result: More than one id returned\n"
+                    f"Number of found vulnerabilities: {self.vulnerabilities_counter}\n\n"
+                )
+                with open(self.vul_report_name, "a", encoding="utf-8") as f:
+                    f.write(report_line)
+
+        elif "Access-Control-Allow-Origin: *" in headers or "Access-Control-Allow-Credentials: true" in headers:
+            report_line = (
+                f"Test Purpose: {test_case_purpose}\n"
+                f"Test Name: {test_case_name}\n"
+                f"Step: {step}\n"
+                f"Expected Result: All debug options disabled, no default credentials, correct permission settings applied\n"
+                f"Actual Result: Debug mode enabled, default admin account active, incorrect file permissions\n"
+                f"Number of found vulnerabilities: {self.vulnerabilities_counter}\n\n"
+            )
+
+            with open(self.vul_report_name, "a", encoding="utf-8") as f:
+                f.write(report_line)
+
+            # Check the response headers for security misconfigurations
+            for header, is_config_correct in expected_configurations.items():
+                actual_value = response_headers.get(header, '')
+                if not actual_value or not is_config_correct(actual_value):
+                    report_line = (
+                        f"Test Purpose: {test_case_purpose}\n"
+                        f"Test Name: {test_case_name}\n"
+                        f"Step: {step}\n"
+                        f"Expected Result: All debug options disabled, no default credentials, correct permission settings applied\n"
+                        f"Actual Result: Debug mode enabled, default admin account active, incorrect file permissions\n"
+                        f"Number of found vulnerabilities: {self.vulnerabilities_counter}\n\n"
+                    )
+
+                    with open(self.vul_report_name, "a", encoding="utf-8") as f:
+                        f.write(report_line)
+        elif "message" in body or "conversion_params" in body:
+            report_line = (
+                f"Test Purpose: {test_case_purpose}\n"
+                f"Test Name: {test_case_name}\n"
+                f"Step: {step}\n"
+                f"Expected Result: Only necessary information should be returned.\n"
+                f"Actual Result: Too much information was logged.\n"
+                f"Number of found vulnerabilities: 
{self.vulnerabilities_counter}\n\n" + ) + + with open(self.vul_report_name, "a", encoding="utf-8") as f: + f.write(report_line) + + expected_codes = test_step.get('expected_response_code', []) + conditions = test_step.get('conditions', {}) + successful_msg = conditions.get('if_successful', "No Vulnerability found.") + unsuccessful_msg = conditions.get('if_unsuccessful', "Vulnerability found.") + + success = any( + str(status_code).strip() == str(expected.split()[0]).strip() + and expected.split()[0].strip().isdigit() + for expected in expected_codes if expected.strip() + ) + + + if not success: + self.vulnerabilities_counter += 1 + report_line = ( + f"Test Purpose: {test_case_purpose}\n" + f"Test Name:{test_case_name}\n" + f"Step: {step}\n" + f"Expected Result: {expected}\n" + f"Actual Result: {status_code}\n" + f"{unsuccessful_msg}\n" + f"Number of found vulnerabilities: {self.vulnerabilities_counter}\n\n" + ) + with open(self.vul_report_name, "a", encoding="utf-8") as f: + f.write(report_line) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py deleted file mode 100644 index ce5874f9..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/pentesting_information.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import Dict, List - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptPurpose, -) - - -class PenTestingInformation: - """ - A class to provide penetration testing steps and analysis based on different purposes - such as authentication, input validation, error handling, etc. - - Attributes: - schemas (dict): A dictionary containing valid schemas for input validation tests. - explore_steps (dict): A dictionary mapping each PromptPurpose to a list of exploration steps. - exploit_steps (dict): A dictionary mapping each PromptPurpose to a list of exploitation steps. - """ - - def __init__(self, schemas: Dict[str, List[str]] = None) -> object: - """ - Initializes the PenTestingInformation with optional schemas. - - Args: - schemas (dict, optional): A dictionary containing valid schemas for input validation tests. Default is None. - """ - self.schemas = schemas if schemas is not None else {} - self.explore_steps = self.init_steps - - @property - def init_steps(self) -> Dict[PromptPurpose, List[str]]: - """ - Provides initial penetration testing steps for various purposes. - - Returns: - dict: A dictionary where each key is a PromptPurpose and each value is a list of steps. - """ - return { - PromptPurpose.AUTHENTICATION_AUTHORIZATION: [ - "Verify if there exists an authentication mechanism, the effectiveness of the authentication mechanism, " - "that only authorized users or applications can access the APIs, and if the API enforces proper authorization " - "checks to prevent unauthorized access to resources." - ], - PromptPurpose.INPUT_VALIDATION: [ - f"Examine how the API handles various types of user input by crafting and sending a POST request with valid data. " - f"Valid schemas are {self.schemas}.", - f"Examine how the API handles various types of user input by crafting and sending a POST request with invalid data. " - f"Valid schemas are {self.schemas}.", - f"Examine how the API handles various types of user input by crafting and sending a POST request with edge data. 
" - f"Valid schemas are {self.schemas}.", - f"Omit required fields to see if the API correctly identifies them as missing by sending a POST request. " - f"Valid schemas are {self.schemas}.", - f"Include special characters or injection attempts (e.g., SQL Injection, XSS) to see if the API sanitizes inputs properly, " - f"by sending a POST request. Valid schemas are {self.schemas}.", - f"Send data that doesn’t match the expected format by a POST request. Valid data formats are {self.schemas}.", - "Check for proper error handling, response codes, and sanitization.", - "Attempt to exploit common vulnerabilities by injecting malicious inputs, such as SQL injection, NoSQL injection, " - "cross-site scripting, and other injection attacks. Evaluate whether the API properly validates, escapes, and sanitizes " - "all user-supplied data, ensuring no unexpected behavior or security vulnerabilities are exposed.", - ], - PromptPurpose.ERROR_HANDLING_INFORMATION_LEAKAGE: [ - "Check how the API handles errors and if there are detailed error messages.", - "Look for vulnerabilities and information leakage.", - ], - PromptPurpose.SESSION_MANAGEMENT: [ - "Check if the API uses session management.", - "Look at the session handling mechanism for vulnerabilities such as session fixation, session hijacking, or session timeout settings.", - ], - PromptPurpose.CROSS_SITE_SCRIPTING: [ - "Look for vulnerabilities that could enable malicious scripts to be injected into API responses." - ], - PromptPurpose.CROSS_SITE_FORGERY: [ - "Look for vulnerabilities that could enable unauthorized actions to be performed on behalf of authenticated users." - ], - PromptPurpose.BUSINESS_LOGIC_VULNERABILITIES: [ - "Examine the API's business logic and identify flaws that can be exploited for unauthorized access, manipulation, or data exposure." - ], - PromptPurpose.RATE_LIMITING_THROTTLING: [ - "Check if the API has adequate rate-limiting and throttling controls to prevent abuse and denial-of-service attacks." - ], - PromptPurpose.SECURITY_MISCONFIGURATIONS: [ - "Check the API's configuration settings and determine if they expose sensitive information or create security weaknesses." - ], - PromptPurpose.LOGGING_MONITORING: [ - "Examine the logging and monitoring capabilities of the API and check if security incidents are detected and responded to promptly." - ], - } - - def analyse_steps(self, response: str = "") -> Dict[PromptPurpose, List[str]]: - """ - Provides prompts for analysis based on the provided response for various purposes using an LLM. - - Args: - response (str, optional): The HTTP response to analyze. Default is an empty string. - - Returns: - dict: A dictionary where each key is a PromptPurpose and each value is a list of prompts. - """ - return { - PromptPurpose.PARSING: [ - f""" Please parse this response and extract the following details in JSON format: {{ - "Status Code": "", - "Reason Phrase": "", - "Headers": , - "Response Body": - from this response: {response} - - }}""" - ], - PromptPurpose.ANALYSIS: [ - f"Given the following parsed HTTP response:\n{response}\n" - "Please analyze this response to determine:\n" - "1. Whether the status code is appropriate for this type of request.\n" - "2. If the headers indicate proper security and rate-limiting practices.\n" - "3. Whether the response body is correctly handled." 
- ], - PromptPurpose.DOCUMENTATION: [ - f"Based on the analysis provided, document the findings of this API response validation:\n{response}" - ], - PromptPurpose.REPORTING: [ - f"Based on the documented findings : {response}. Suggest any improvements or issues that should be reported to the API developers." - ], - } diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py deleted file mode 100644 index 54e3aea7..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_engineer.py +++ /dev/null @@ -1,149 +0,0 @@ -from instructor.retry import InstructorRetryException - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptContext, - PromptStrategy, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import ( - PromptGenerationHelper, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning import ( - InContextLearningPrompt, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning import ( - ChainOfThoughtPrompt, - TreeOfThoughtPrompt, -) -from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt -from hackingBuddyGPT.utils import tool_message - - -class PromptEngineer: - """Prompt engineer that creates prompts of different types.""" - - def __init__( - self, - strategy: PromptStrategy = None, - history: Prompt = None, - handlers=(), - context: PromptContext = None, - rest_api: str = "", - schemas: dict = None, - ): - """ - Initializes the PromptEngineer with a specific strategy and handlers for LLM and responses. - - Args: - strategy (PromptStrategy): The prompt engineering strategy to use. - history (dict, optional): The history of chats. Defaults to None. - handlers (tuple): The LLM handler and response handler. - context (PromptContext): The context for which prompts are generated. - rest_api (str, optional): The REST API endpoint. - schemas (dict, optional): Schemas relevant for the context. - """ - self.strategy = strategy - self.rest_api = rest_api - self.llm_handler, self.response_handler = handlers - self.prompt_helper = PromptGenerationHelper(response_handler=self.response_handler, schemas=schemas or {}) - self.context = context - self.turn = 0 - self._prompt_history = history or [] - - self.strategies = { - PromptStrategy.CHAIN_OF_THOUGHT: ChainOfThoughtPrompt( - context=self.context, prompt_helper=self.prompt_helper - ), - PromptStrategy.TREE_OF_THOUGHT: TreeOfThoughtPrompt( - context=self.context, prompt_helper=self.prompt_helper, rest_api=self.rest_api - ), - PromptStrategy.IN_CONTEXT: InContextLearningPrompt( - context=self.context, - prompt_helper=self.prompt_helper, - context_information={self.turn: {"content": "initial_prompt"}}, - ), - } - - self.purpose = None - - def generate_prompt(self, turn: int, move_type="explore", hint=""): - """ - Generates a prompt based on the specified strategy and gets a response. - - Args: - turn (int): The current round or step in the process. - move_type (str, optional): The type of move for the strategy. Defaults to "explore". - hint (str, optional): An optional hint to guide the prompt generation. Defaults to "". - - Returns: - list: Updated prompt history after generating the prompt and receiving a response. - - Raises: - ValueError: If an invalid prompt strategy is specified. 
- """ - prompt_func = self.strategies.get(self.strategy) - if not prompt_func: - raise ValueError("Invalid prompt strategy") - - is_good = False - self.turn = turn - while not is_good: - try: - prompt = prompt_func.generate_prompt( - move_type=move_type, hint=hint, previous_prompt=self._prompt_history, turn=0 - ) - self.purpose = prompt_func.purpose - is_good = self.evaluate_response(prompt, "") - except InstructorRetryException: - hint = f"invalid prompt: {prompt}" - - self._prompt_history.append({"role": "system", "content": prompt}) - self.previous_prompt = prompt - self.turn += 1 - return self._prompt_history - - def evaluate_response(self, prompt, response_text): - """ - Evaluates the response to determine if it is acceptable. - - Args: - prompt (str): The generated prompt. - response_text (str): The response text to evaluate. - - Returns: - bool: True if the response is acceptable, otherwise False. - """ - # TODO: Implement a proper evaluation mechanism - return True - - def get_purpose(self): - """Returns the purpose of the current prompt strategy.""" - return self.purpose - - def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: - """ - Helper function to process each analysis step with the LLM. - - Args: - step (str): The current step to process. - prompt_history (list): The history of prompts and responses. - - Returns: - tuple: Updated prompt history and the result of the step processing. - """ - print(f"Processing step: {step}") - prompt_history.append({"role": "system", "content": step}) - - # Call the LLM and handle the response - self.prompt_helper.check_prompt(prompt_history, step) - response, completion = self.llm_handler.call_llm(prompt_history) - message = completion.choices[0].message - prompt_history.append(message) - tool_call_id = message.tool_calls[0].id - - try: - result = response.execute() - except Exception as e: - result = f"Error executing tool call: {str(e)}" - prompt_history.append(tool_message(str(result), tool_call_id)) - - return prompt_history, result diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py deleted file mode 100644 index 040ef6bd..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompt_generation_helper.py +++ /dev/null @@ -1,139 +0,0 @@ -import re - -import nltk - -from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler - - -class PromptGenerationHelper(object): - """ - A helper class for managing and generating prompts, tracking endpoints, and ensuring consistency in HTTP actions. - - Attributes: - response_handler (object): Handles responses for prompts. - found_endpoints (list): A list of discovered endpoints. - endpoint_methods (dict): A dictionary mapping endpoints to their HTTP methods. - endpoint_found_methods (dict): A dictionary mapping HTTP methods to endpoints. - schemas (dict): A dictionary of schemas used for constructing HTTP requests. - """ - - def __init__(self, response_handler: ResponseHandler = None, schemas: dict = None): - """ - Initializes the PromptAssistant with a response handler and downloads necessary NLTK models. - - Args: - response_handler (object): The response handler used for managing responses. 
- schemas(tuple): Schemas used - """ - if schemas is None: - schemas = {} - - self.response_handler = response_handler - self.found_endpoints = ["/"] - self.endpoint_methods = {} - self.endpoint_found_methods = {} - self.schemas = schemas - - # Download NLTK models if not already installed - nltk.download("punkt") - nltk.download("stopwords") - - def get_endpoints_needing_help(self): - """ - Identifies endpoints that need additional HTTP methods and returns guidance for the first missing method. - - Returns: - list: A list containing guidance for the first missing method of the first endpoint that needs help. - """ - endpoints_needing_help = [] - endpoints_and_needed_methods = {} - http_methods_set = {"GET", "POST", "PUT", "DELETE"} - - for endpoint, methods in self.endpoint_methods.items(): - missing_methods = http_methods_set - set(methods) - if len(methods) < 4: - endpoints_needing_help.append(endpoint) - endpoints_and_needed_methods[endpoint] = list(missing_methods) - - if endpoints_needing_help: - first_endpoint = endpoints_needing_help[0] - needed_method = endpoints_and_needed_methods[first_endpoint][0] - return [ - f"For endpoint {first_endpoint}, find this missing method: {needed_method}. If all HTTP methods have already been found for an endpoint, do not include this endpoint in your search." - ] - return [] - - def get_http_action_template(self, method): - """ - Constructs a consistent HTTP action description based on the provided method. - - Args: - method (str): The HTTP method to construct the action description for. - - Returns: - str: The constructed HTTP action description. - """ - if method in ["POST", "PUT"]: - return f"Create HTTPRequests of type {method} considering the found schemas: {self.schemas} and understand the responses. Ensure that they are correct requests." - else: - return f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests." - - def get_initial_steps(self, common_steps): - """ - Provides the initial steps for identifying available endpoints and documenting their details. - - Args: - common_steps (list): A list of common steps to be included. - - Returns: - list: A list of initial steps combined with common steps. - """ - return [ - f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}", - "Note down the response structures, status codes, and headers for each endpoint.", - "For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses.", - ] + common_steps - - def token_count(self, text): - """ - Counts the number of word tokens in the provided text using NLTK's tokenizer. - - Args: - text (str): The input text to tokenize and count. - - Returns: - int: The number of tokens in the input text. - """ - tokens = re.findall(r"\b\w+\b", text) - words = [token.strip("'") for token in tokens if token.strip("'").isalnum()] - return len(words) - - def check_prompt(self, previous_prompt: list, steps: str, max_tokens: int = 900) -> str: - """ - Validates and shortens the prompt if necessary to ensure it does not exceed the maximum token count. - - Args: - previous_prompt (list): The previous prompt content. - steps (str): A list of steps to be included in the new prompt. - max_tokens (int, optional): The maximum number of tokens allowed. Defaults to 900. 
- - Returns: - str: The validated and possibly shortened prompt. - """ - - def validate_prompt(prompt): - if self.token_count(prompt) <= max_tokens: - return prompt - shortened_prompt = self.response_handler.get_response_for_prompt("Shorten this prompt: " + prompt) - if self.token_count(shortened_prompt) <= max_tokens: - return shortened_prompt - return "Prompt is still too long after summarization." - - if not all(step in previous_prompt for step in steps): - if isinstance(steps, list): - potential_prompt = "\n".join(str(element) for element in steps) - else: - potential_prompt = str(steps) + "\n" - return validate_prompt(potential_prompt) - - return validate_prompt(previous_prompt) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py deleted file mode 100644 index af753d5c..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/basic_prompt.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional - -# from hackingBuddyGPT.usecases.web_api_testing.prompt_generation import PromptGenerationHelper -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import ( - PenTestingInformation, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PlanningType, - PromptContext, - PromptStrategy, -) - - -class BasicPrompt(ABC): - """ - Abstract base class for generating prompts based on different strategies and contexts. - - This class serves as a blueprint for creating specific prompt generators that operate under different strategies, - such as chain-of-thought or simple prompt generation strategies, tailored to different contexts like documentation - or pentesting. - - Attributes: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - strategy (PromptStrategy): The strategy used for prompt generation. - pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. - """ - - def __init__( - self, - context: PromptContext = None, - planning_type: PlanningType = None, - prompt_helper=None, - strategy: PromptStrategy = None, - ): - """ - Initializes the BasicPrompt with a specific context, prompt helper, and strategy. - - Args: - context (PromptContext): The context in which prompts are generated. - planning_type (PlanningType): The type of planning. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - strategy (PromptStrategy): The strategy used for prompt generation. - """ - self.context = context - self.planning_type = planning_type - self.prompt_helper = prompt_helper - self.strategy = strategy - self.pentesting_information: Optional[PenTestingInformation] = None - - if self.context == PromptContext.PENTESTING: - self.pentesting_information = PenTestingInformation(schemas=prompt_helper.schemas) - - @abstractmethod - def generate_prompt( - self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] - ) -> str: - """ - Abstract method to generate a prompt. - - This method must be implemented by subclasses to generate a prompt based on the given move type, optional hint, and previous prompt. - - Args: - move_type (str): The type of move to generate. 
- hint (Optional[str]): An optional hint to guide the prompt generation. - previous_prompt (Optional[str]): The previous prompt content based on the conversation history. - turn (Optional[int]): The current turn - - Returns: - str: The generated prompt. - """ - pass diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py deleted file mode 100644 index f5772683..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/in_context_learning_prompt.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Dict, Optional - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptContext, - PromptPurpose, - PromptStrategy, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.state_learning.state_planning_prompt import ( - StatePlanningPrompt, -) - - -class InContextLearningPrompt(StatePlanningPrompt): - """ - A class that generates prompts using the in-context learning strategy. - - This class extends the BasicPrompt abstract base class and implements - the generate_prompt method for creating prompts based on the - in-context learning strategy. - - Attributes: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - prompt (Dict[int, Dict[str, str]]): A dictionary containing the prompts for each round. - turn (int): The round number for which the prompt is being generated. - purpose (Optional[PromptPurpose]): The purpose of the prompt generation, which can be set during the process. - """ - - def __init__(self, context: PromptContext, prompt_helper, context_information: Dict[int, Dict[str, str]]) -> None: - """ - Initializes the InContextLearningPrompt with a specific context, prompt helper, and initial prompt. - - Args: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - context_information (Dict[int, Dict[str, str]]): A dictionary containing the prompts for each round. - round (int): The round number for which the prompt is being generated. - """ - super().__init__(context=context, prompt_helper=prompt_helper, strategy=PromptStrategy.IN_CONTEXT) - self.prompt: Dict[int, Dict[str, str]] = context_information - self.purpose: Optional[PromptPurpose] = None - - def generate_prompt( - self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] - ) -> str: - """ - Generates a prompt using the in-context learning strategy. - - Args: - move_type (str): The type of move to generate. - hint (Optional[str]): An optional hint to guide the prompt generation. - previous_prompt (List[Dict[str, str]]): A list of previous prompt entries, each containing a "content" key. - - Returns: - str: The generated prompt. 
- """ - history_content = [entry["content"] for entry in previous_prompt] - prompt_content = self.prompt.get(turn, {}).get("content", "") - - # Add hint if provided - if hint: - prompt_content += f"\n{hint}" - - return "\n".join(history_content + [prompt_content]) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py deleted file mode 100644 index 5cbb936b..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/state_planning_prompt.py +++ /dev/null @@ -1,39 +0,0 @@ -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PlanningType, - PromptContext, - PromptStrategy, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts import ( - BasicPrompt, -) - - -class StatePlanningPrompt(BasicPrompt): - """ - A class for generating state planning prompts, including strategies like In-Context Learning (ICL). - - This class extends BasicPrompt to provide specific implementations for state planning strategies, focusing on - adapting prompts based on the current context or state of information provided. - - Attributes: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - strategy (PromptStrategy): The strategy used for prompt generation, typically state-oriented like ICL. - pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. - """ - - def __init__(self, context: PromptContext, prompt_helper, strategy: PromptStrategy): - """ - Initializes the StatePlanningPrompt with a specific context, prompt helper, and strategy. - - Args: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - strategy (PromptStrategy): The state planning strategy used for prompt generation. - """ - super().__init__( - context=context, - planning_type=PlanningType.STATE_PLANNING, - prompt_helper=prompt_helper, - strategy=strategy, - ) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py deleted file mode 100644 index 9825d17c..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import List, Optional - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptContext, - PromptPurpose, - PromptStrategy, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning.task_planning_prompt import ( - TaskPlanningPrompt, -) - - -class ChainOfThoughtPrompt(TaskPlanningPrompt): - """ - A class that generates prompts using the chain-of-thought strategy. - - This class extends the BasicPrompt abstract base class and implements - the generate_prompt method for creating prompts based on the - chain-of-thought strategy. - - Attributes: - context (PromptContext): The context in which prompts are generated. 
- prompt_helper (PromptHelper): A helper object for managing and generating prompts. - explored_steps (List[str]): A list of steps that have already been explored in the chain-of-thought strategy. - purpose (Optional[PromptPurpose]): The purpose of the current prompt. - """ - - def __init__(self, context: PromptContext, prompt_helper): - """ - Initializes the ChainOfThoughtPrompt with a specific context and prompt helper. - - Args: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - """ - super().__init__(context=context, prompt_helper=prompt_helper, strategy=PromptStrategy.CHAIN_OF_THOUGHT) - self.explored_steps: List[str] = [] - self.purpose: Optional[PromptPurpose] = None - - def generate_prompt( - self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] - ) -> str: - """ - Generates a prompt using the chain-of-thought strategy. - - Args: - move_type (str): The type of move to generate. - hint (Optional[str]): An optional hint to guide the prompt generation. - previous_prompt (Optional[str]): The previous prompt content based on the conversation history. - - Returns: - str: The generated prompt. - """ - common_steps = self._get_common_steps() - chain_of_thought_steps = self._get_chain_of_thought_steps(common_steps, move_type) - - if hint: - chain_of_thought_steps.append(hint) - - return self.prompt_helper.check_prompt(previous_prompt=previous_prompt, steps=chain_of_thought_steps) - - def _get_common_steps(self) -> List[str]: - """ - Provides a list of common steps for generating prompts. - - Returns: - List[str]: A list of common steps for generating prompts. - """ - if self.context == PromptContext.DOCUMENTATION: - return [ - "Identify common data structures returned by various endpoints and define them as reusable schemas. " - "Determine the type of each field (e.g., integer, string, array) and define common response structures as components that can be referenced in multiple endpoint definitions.", - "Create an OpenAPI document including metadata such as API title, version, and description, define the base URL of the API, list all endpoints, methods, parameters, and responses, and define reusable schemas, response types, and parameters.", - "Ensure the correctness and completeness of the OpenAPI specification by validating the syntax and completeness of the document using tools like Swagger Editor, and ensure the specification matches the actual behavior of the API.", - "Refine the document based on feedback and additional testing, share the draft with others, gather feedback, and make necessary adjustments. Regularly update the specification as the API evolves.", - "Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes.", - ] - else: - return [ - "Identify common data structures returned by various endpoints and define them as reusable schemas, specifying field types like integer, string, and array.", - "Create an OpenAPI document that includes API metadata (title, version, description), the base URL, endpoints, methods, parameters, and responses.", - "Ensure the document's correctness and completeness using tools like Swagger Editor, and verify it matches the API's behavior. 
Refine the document based on feedback, share drafts for review, and update it regularly as the API evolves.", - "Make the specification available to developers through the API documentation site, keeping it current with any API changes.", - ] - - def _get_chain_of_thought_steps(self, common_steps: List[str], move_type: str) -> List[str]: - """ - Provides the steps for the chain-of-thought strategy based on the current context. - - Args: - common_steps (List[str]): A list of common steps for generating prompts. - move_type (str): The type of move to generate. - - Returns: - List[str]: A list of steps for the chain-of-thought strategy. - """ - if self.context == PromptContext.DOCUMENTATION: - self.purpose = PromptPurpose.DOCUMENTATION - return self._get_documentation_steps(common_steps, move_type) - else: - return self._get_pentesting_steps(move_type) - - def _get_documentation_steps(self, common_steps: List[str], move_type: str) -> List[str]: - """ - Provides the steps for the chain-of-thought strategy when the context is documentation. - - Args: - common_steps (List[str]): A list of common steps for generating prompts. - move_type (str): The type of move to generate. - - Returns: - List[str]: A list of steps for the chain-of-thought strategy in the documentation context. - """ - if move_type == "explore": - return self.prompt_helper.get_initial_steps(common_steps) - else: - return self.prompt_helper.get_endpoints_needing_help() - - def _get_pentesting_steps(self, move_type: str) -> List[str]: - """ - Provides the steps for the chain-of-thought strategy when the context is pentesting. - - Args: - move_type (str): The type of move to generate. - - Returns: - List[str]: A list of steps for the chain-of-thought strategy in the pentesting context. - """ - if move_type == "explore": - purpose = list(self.pentesting_information.explore_steps.keys())[0] - step = self.pentesting_information.explore_steps[purpose] - if step not in self.explored_steps: - if len(step) > 1: - step = self.pentesting_information.explore_steps[purpose][0] - if len(self.pentesting_information.explore_steps[purpose]) == 0: - del self.pentesting_information.explore_steps[purpose][0] - prompt = step - self.purpose = purpose - self.explored_steps.append(step) - if len(step) == 1: - del self.pentesting_information.explore_steps[purpose] - - print(f"prompt: {prompt}") - return prompt - else: - return ["Look for exploits."] diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py deleted file mode 100644 index 181f30ab..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/task_planning_prompt.py +++ /dev/null @@ -1,39 +0,0 @@ -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PlanningType, - PromptContext, - PromptStrategy, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts import ( - BasicPrompt, -) - - -class TaskPlanningPrompt(BasicPrompt): - """ - A class for generating task planning prompts, including strategies like Chain-of-Thought (CoT) and Tree-of-Thought (ToT). - - This class extends BasicPrompt to provide specific implementations for task planning strategies, allowing for - detailed step-by-step reasoning or exploration of multiple potential reasoning paths. 
- - Attributes: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - strategy (PromptStrategy): The strategy used for prompt generation, which could be CoT, ToT, etc. - pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. - """ - - def __init__(self, context: PromptContext, prompt_helper, strategy: PromptStrategy): - """ - Initializes the TaskPlanningPrompt with a specific context, prompt helper, and strategy. - - Args: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - strategy (PromptStrategy): The task planning strategy used for prompt generation. - """ - super().__init__( - context=context, - planning_type=PlanningType.TASK_PLANNING, - prompt_helper=prompt_helper, - strategy=strategy, - ) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py deleted file mode 100644 index 028a79da..00000000 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Optional - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptContext, - PromptPurpose, - PromptStrategy, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompts.task_planning import ( - TaskPlanningPrompt, -) -from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt - - -class TreeOfThoughtPrompt(TaskPlanningPrompt): - """ - A class that generates prompts using the tree-of-thought strategy. - - This class extends the BasicPrompt abstract base class and implements - the generate_prompt method for creating prompts based on the - tree-of-thought strategy. - - Attributes: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - rest_api (str): The REST API endpoint for which prompts are generated. - round (int): The round number for the prompt generation process. - purpose (Optional[PromptPurpose]): The purpose of the prompt generation, which can be set during the process. - """ - - def __init__(self, context: PromptContext, prompt_helper, rest_api: str) -> None: - """ - Initializes the TreeOfThoughtPrompt with a specific context and prompt helper. - - Args: - context (PromptContext): The context in which prompts are generated. - prompt_helper (PromptHelper): A helper object for managing and generating prompts. - rest_api (str): The REST API endpoint. - round (int): The round number for the prompt generation process. - """ - super().__init__(context=context, prompt_helper=prompt_helper, strategy=PromptStrategy.TREE_OF_THOUGHT) - self.rest_api: str = rest_api - self.purpose: Optional[PromptPurpose] = None - - def generate_prompt(self, move_type: str, hint: Optional[str], previous_prompt: Prompt, turn: Optional[int]) -> str: - """ - Generates a prompt using the tree-of-thought strategy. - - Args: - move_type (str): The type of move to generate. - hint (Optional[str]): An optional hint to guide the prompt generation. 
- previous_prompt (List[Union[ChatCompletionMessage, ChatCompletionMessageParam]]): A list of previous prompt entries, each containing a "content" key. - turn (Optional[int]): The current turn or step in the conversation. - - Returns: - str: The generated prompt. - """ - if self.context == PromptContext.DOCUMENTATION: - tree_of_thoughts_steps = [ - ( - "Imagine three different OpenAPI specification specialists.\n" - "All experts will write down one step of their thinking,\n" - "then share it with the group.\n" - "After that, all remaining specialists will proceed to the next step, and so on.\n" - "If any specialist realizes they're wrong at any point, they will leave.\n" - f"The question is: Create an OpenAPI specification for this REST API {self.rest_api} " - ) - ] - else: - tree_of_thoughts_steps = [ - ( - "Imagine three different Pentest experts are answering this question.\n" - "All experts will write down one step of their thinking,\n" - "then share it with the group.\n" - "After that, all experts will proceed to the next step, and so on.\n" - "If any expert realizes they're wrong at any point, they will leave.\n" - f"The question is: Create pentests for this REST API {self.rest_api} " - ) - ] - - # Assuming ChatCompletionMessage and ChatCompletionMessageParam have a 'content' attribute - previous_content = previous_prompt[turn].content if turn is not None else "initial_prompt" - - self.purpose = PromptPurpose.AUTHENTICATION_AUTHORIZATION - - return "\n".join([previous_content] + tree_of_thoughts_steps) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py index 9b2c2ac9..ff9fa4ca 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer.py @@ -2,7 +2,7 @@ import re from typing import Any, Dict, Optional, Tuple -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptPurpose +from hackingBuddyGPT.utils.prompt_generation.information import PromptPurpose class ResponseAnalyzer: @@ -49,8 +49,12 @@ def parse_http_response(self, raw_response: str) -> Tuple[Optional[int], Dict[st if body != {} and bool(body and not body.isspace()): body = json.loads(body)[0] - else: - body = "Empty" + + if body == "": + for line in header_lines: + if line.startswith("{") or line.startswith("["): + body = line + body = json.loads(body) status_line = header_lines[0].strip() headers = { @@ -77,7 +81,7 @@ def analyze_response(self, raw_response: str) -> Optional[Dict[str, Any]]: return self.analyze_parsed_response(status_code, headers, body) def analyze_parsed_response( - self, status_code: Optional[int], headers: Dict[str, str], body: str + self, status_code: Optional[int], headers: Dict[str, str], body: str ) -> Optional[Dict[str, Any]]: """ Analyzes the parsed HTTP response based on the purpose, invoking the appropriate method. @@ -91,7 +95,7 @@ def analyze_parsed_response( Optional[Dict[str, Any]]: The analysis results based on the purpose. 
""" analysis_methods = { - PromptPurpose.AUTHENTICATION_AUTHORIZATION: self.analyze_authentication_authorization( + PromptPurpose.AUTHENTICATION: self.analyze_authentication_authorization( status_code, headers, body ), PromptPurpose.INPUT_VALIDATION: self.analyze_input_validation(status_code, headers, body), @@ -99,7 +103,7 @@ def analyze_parsed_response( return analysis_methods.get(self.purpose) def analyze_authentication_authorization( - self, status_code: Optional[int], headers: Dict[str, str], body: str + self, status_code: Optional[int], headers: Dict[str, str], body: str ) -> Dict[str, Any]: """ Analyzes the HTTP response with a focus on authentication and authorization. @@ -134,7 +138,7 @@ def analyze_authentication_authorization( return analysis def analyze_input_validation( - self, status_code: Optional[int], headers: Dict[str, str], body: str + self, status_code: Optional[int], headers: Dict[str, str], body: str ) -> Dict[str, Any]: """ Analyzes the HTTP response with a focus on input validation. @@ -176,12 +180,12 @@ def is_valid_input_response(self, status_code: Optional[int], body: str) -> str: return "Unexpected" def document_findings( - self, - status_code: Optional[int], - headers: Dict[str, str], - body: str, - expected_behavior: str, - actual_behavior: str, + self, + status_code: Optional[int], + headers: Dict[str, str], + body: str, + expected_behavior: str, + actual_behavior: str, ) -> Dict[str, Any]: """ Documents the findings from the analysis, comparing expected and actual behavior. @@ -203,9 +207,7 @@ def document_findings( "Expected Behavior": expected_behavior, "Actual Behavior": actual_behavior, } - print("Documenting Findings:") - print(json.dumps(document, indent=4)) - print("-" * 50) + return document def report_issues(self, document: Dict[str, Any]) -> None: diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py index 204eba13..02e03663 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_analyzer_with_llm.py @@ -4,10 +4,10 @@ from unittest.mock import MagicMock from hackingBuddyGPT.capabilities.http_request import HTTPRequest -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information import ( +from hackingBuddyGPT.utils.prompt_generation.information import ( PenTestingInformation, ) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( +from hackingBuddyGPT.utils.prompt_generation.information import ( PromptPurpose, ) from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler @@ -23,7 +23,8 @@ class ResponseAnalyzerWithLLM: purpose (PromptPurpose): The specific purpose for analyzing the HTTP response. """ - def __init__(self, purpose: PromptPurpose = None, llm_handler: LLMHandler = None): + def __init__(self, purpose: PromptPurpose = None, llm_handler: LLMHandler = None, + pentesting_info: PenTestingInformation = None, capacity: Any = None, prompt_helper: Any = None): """ Initializes the ResponseAnalyzer with an optional purpose and an LLM instance. 
@@ -34,7 +35,10 @@ def __init__(self, purpose: PromptPurpose = None, llm_handler: LLMHandler = None """ self.purpose = purpose self.llm_handler = llm_handler - self.pentesting_information = PenTestingInformation() + self.pentesting_information = pentesting_info + self.capacity = capacity + self.prompt_helper = prompt_helper + self.token = "" def set_purpose(self, purpose: PromptPurpose): """ @@ -57,7 +61,7 @@ def print_results(self, results: Dict[str, str]): print(f"Response: {response}") print("-" * 50) - def analyze_response(self, raw_response: str, prompt_history: list) -> tuple[dict[str, Any], list]: + def analyze_response(self, raw_response: str, prompt_history: list, analysis_context: Any) -> tuple[list[str], Any]: """ Parses the HTTP response, generates prompts for an LLM, and processes each step with the LLM. @@ -67,20 +71,24 @@ def analyze_response(self, raw_response: str, prompt_history: list) -> tuple[dic Returns: dict: A dictionary with the final results after processing all steps through the LLM. """ - status_code, headers, body = self.parse_http_response(raw_response) - full_response = f"Status Code: {status_code}\nHeaders: {json.dumps(headers, indent=4)}\nBody: {body}" # Start processing the analysis steps through the LLM llm_responses = [] - steps_dict = self.pentesting_information.analyse_steps(full_response) - for steps in steps_dict.values(): - response = full_response # Reset to the full response for each purpose + + + steps = analysis_context.get("steps") + if len(steps) > 1: # multisptep test case for step in steps: - prompt_history, response = self.process_step(step, prompt_history) - llm_responses.append(response) - print(f"Response:{response}") + if step != steps[0]: + + current_step = step.get("step") + prompt_history, raw_response = self.process_step(current_step, prompt_history, "http_request") + test_case_responses, status_code = self.analyse_response(raw_response, step, prompt_history) + llm_responses = llm_responses + test_case_responses + else: + llm_responses, status_code = self.analyse_response(raw_response, steps[0], prompt_history) - return llm_responses + return llm_responses, status_code def parse_http_response(self, raw_response: str): """ @@ -95,42 +103,90 @@ def parse_http_response(self, raw_response: str): header_body_split = raw_response.split("\r\n\r\n", 1) header_lines = header_body_split[0].split("\n") body = header_body_split[1] if len(header_body_split) > 1 else "" + if body == "": + for line in header_lines: + if line.startswith("{") or line.startswith("["): + body = line + status_line = header_lines[0].strip() - match = re.match(r"HTTP/1\.1 (\d{3}) (.*)", status_line) - status_code = int(match.group(1)) if match else None - if body.__contains__(""): + body = "" + elif body.startswith("["): body = json.loads(body) - if isinstance(body, list) and len(body) > 1: + print(f'"body:{body}') + elif body.__contains__("{") and (body != '' or body != ""): + if not body.lower().__contains__("png") : + body = json.loads(body) + if "token" in body: + + self.prompt_helper.current_user["token"] = body["token"] + self.token = body["token"] + for account in self.prompt_helper.accounts: + if account.get("x") == self.prompt_helper.current_user.get("x"): + if "token" not in account.keys(): + account["token"] = self.token + else: + if account["token"] != self.token: + account["token"] = self.token + print(f'token:{self.token}') + print(f"accoun:{account}") + if any (value in body.values() for value in self.prompt_helper.current_user.values()): + if "id" in body: 
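+                            # Remember the id returned for the current user on the matching account entry
+                            # (accounts are matched on their "x" field here) so later test steps can reuse it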
+ for account in self.prompt_helper.accounts: + if account.get("x") == self.prompt_helper.current_user.get( + "x") and "id" not in account.keys(): + account["id"] = body["id"] + + + #self.replace_account() + elif isinstance(body, list) and len(body) > 1: body = body[0] + if self.prompt_helper.current_user in body: + self.prompt_helper.current_user["id"] = self.get_id_from_user(body) + if self.prompt_helper.current_user not in self.prompt_helper.accounts: + self.prompt_helper.accounts.append(self.prompt_helper.current_user) + else: + if self.prompt_helper.current_user not in self.prompt_helper.accounts: + self.prompt_helper.accounts.append(self.prompt_helper.current_user) + headers = { key.strip(): value.strip() for key, value in (line.split(":", 1) for line in header_lines[1:] if ":" in line) } - match = re.match(r"HTTP/1\.1 (\d{3}) (.*)", status_line) - status_code = int(match.group(1)) if match else None + if isinstance(body, str) and body.startswith(" ") and body.endswith(""): + body = "" return status_code, headers, body - def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: + def get_id_from_user(self, body) -> str: + id = body.split("id")[1].split(",")[0] + return id + + + def process_step(self, step: str, prompt_history: list, capability:str) -> tuple[list, str]: """ Helper function to process each analysis step with the LLM. """ # Log current step - # print(f'Processing step: {step}') - prompt_history.append({"role": "system", "content": step}) + prompt_history.append({"role": "system", "content": step + "Stay within the output limit."}) # Call the LLM and handle the response - response, completion = self.llm_handler.call_llm(prompt_history) + response, completion = self.llm_handler.execute_prompt_with_specific_capability(prompt_history, capability) message = completion.choices[0].message prompt_history.append(message) tool_call_id = message.tool_calls[0].id @@ -144,6 +200,75 @@ def process_step(self, step: str, prompt_history: list) -> tuple[list, str]: return prompt_history, result + def analyse_response(self, raw_response, step, prompt_history): + llm_responses = [] + + status_code, additional_analysis_context, full_response= self.get_addition_context(raw_response, step) + + expected_responses = step.get("expected_response_code") + + + if step.get("purpose") == PromptPurpose.SETUP: + _, additional_analysis_context, full_response = self.do_setup(status_code, step, additional_analysis_context, full_response, prompt_history) + + if not any(str(status_code) in response for response in expected_responses): + additional_analysis_context += step.get("conditions").get("if_unsuccessful") + else: + additional_analysis_context += step.get("conditions").get("if_successful") + + llm_responses.append(full_response) + if step.get("purpose") != PromptPurpose.SETUP: + for purpose in self.pentesting_information.analysis_step_list: + analysis_step = self.pentesting_information.get_analysis_step(purpose, full_response, + additional_analysis_context) + prompt_history, response = self.process_step(analysis_step, prompt_history, "record_note") + llm_responses.append(response) + full_response = response # make it iterative + + return llm_responses, status_code + + def get_addition_context(self, raw_response: str, step: dict) : + # Parse response + status_code, headers, body = self.parse_http_response(raw_response) + + full_response = f"Status Code: {status_code}\nHeaders: {json.dumps(headers, indent=4)}\nBody: {body}" + expected_responses = step.get("expected_response_code") + 
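Editor's note for readers of this hunk: each pentesting test-case step carries its expected status codes, security requirements, and follow-up conditions, and `analyse_response`/`get_addition_context` fold these into the context handed to the LLM. A minimal, self-contained sketch of that matching logic — the step layout mirrors the keys accessed in the hunk, while the concrete values are invented:

```python
# Sketch only: the step dict layout mirrors the keys used by analyse_response()/
# get_addition_context(); the example values are illustrative, not from the codebase.
def build_analysis_context(status_code: int, step: dict) -> str:
    expected = step.get("expected_response_code", [])   # e.g. ["200 OK", "201 Created"]
    security = step.get("security", "")
    conditions = step.get("conditions", {})
    context = (
        f"\nEnsure that the status code is one of the expected responses: '{expected}'"
        f"\nAlso ensure that the following security requirements have been met: {security}"
    )
    # A step counts as successful when the numeric code occurs in one of the expected entries.
    if any(str(status_code) in response for response in expected):
        context += conditions.get("if_successful", "")
    else:
        context += conditions.get("if_unsuccessful", "")
    return context

example_step = {
    "expected_response_code": ["200 OK"],
    "security": "the response must not leak credentials",
    "conditions": {"if_successful": " Inspect the body.",
                   "if_unsuccessful": " Retry with different input."},
}
print(build_analysis_context(404, example_step))
```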
security = step.get("security") + additional_analysis_context = f"\n Ensure that the status code is one of the expected responses: '{expected_responses}\n Also ensure that the following security requirements have been met: {security}" + return status_code, additional_analysis_context, full_response + + def do_setup(self, status_code, step, additional_analysis_context, full_response, prompt_history): + counter = 0 + if not any(str(status_code) in response for response in step.get("expected_response_code")): + add_info = "Unsuccessful. Try a different input for the schema." + while not any(str(status_code) in response for response in step.get("expected_response_code")): + prompt_history, response = self.process_step(step.get("step") + add_info, prompt_history, "http_request") + status_code, additional_analysis_context, full_response = self.get_addition_context(response, step) + counter += 1 + + if counter == 5: + full_response += "Unsuccessful:" + step.get("conditions").get("if_unsuccessful") + break + + + + return status_code, additional_analysis_context, full_response + + def replace_account(self): + # Now let's replace the existing account if it exists, otherwise add it + replaced = False + for i, account in enumerate(self.prompt_helper.accounts): + # Compare the 'id' (or any unique field) to find the matching account + if account.get("x") == self.prompt_helper.current_user.get("x"): + self.prompt_helper.accounts[i] = self.prompt_helper.current_user + replaced = True + break + + # If we did not replace any existing account, append this as a new account + if not replaced: + self.prompt_helper.accounts.append(self.prompt_helper.current_user) + + if __name__ == "__main__": # Example HTTP response to parse diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py index c7ac733d..c3b33dd2 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/response_processing/response_handler.py @@ -1,10 +1,19 @@ +import copy import json import re +from collections import Counter +from itertools import cycle from typing import Any, Dict, Optional, Tuple - +import random +from urllib.parse import urlencode +import pydantic_core from bs4 import BeautifulSoup +from rich.panel import Panel -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.pentesting_information import ( +from hackingBuddyGPT.usecases.web_api_testing.documentation.pattern_matcher import PatternMatcher +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PromptContext +from hackingBuddyGPT.utils.prompt_generation.information import ( PenTestingInformation, ) from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer_with_llm import ( @@ -12,6 +21,7 @@ ) from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt +from hackingBuddyGPT.utils import tool_message class ResponseHandler: @@ -25,18 +35,110 @@ class ResponseHandler: response_analyzer (ResponseAnalyzerWithLLM): An instance for analyzing responses with the LLM. 
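The `do_setup` helper above is a bounded retry: while the setup request does not return one of the expected status codes, the step is re-sent with a corrective hint, and after five attempts the failure condition is recorded instead. A compact sketch of that pattern, with the hypothetical `send_step` standing in for `process_step()` plus `get_addition_context()`:

```python
from typing import Callable, Tuple

# `send_step` is a placeholder that issues the HTTP request and returns the status code.
def retry_setup(step: dict, status_code: int, send_step: Callable[[str], int],
                max_attempts: int = 5) -> Tuple[int, str]:
    expected = step.get("expected_response_code", [])
    note = ""
    attempts = 0
    while not any(str(status_code) in resp for resp in expected):
        status_code = send_step(step.get("step", "")
                                + "Unsuccessful. Try a different input for the schema.")
        attempts += 1
        if attempts == max_attempts:
            note = "Unsuccessful:" + step.get("conditions", {}).get("if_unsuccessful", "")
            break
    return status_code, note
```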
""" - def __init__(self, llm_handler: LLMHandler) -> None: + def __init__(self, llm_handler: LLMHandler, prompt_context: PromptContext, config: Any, + prompt_helper: PromptGenerationHelper, pentesting_information: PenTestingInformation = None) -> None: """ Initializes the ResponseHandler with the specified LLM handler. Args: llm_handler (LLMHandler): An instance of the LLM handler for interacting with the LLM. """ + self.no_new_endpoint_counter = 0 + self.all_query_combinations = [] self.llm_handler = llm_handler - self.pentesting_information = PenTestingInformation() - self.response_analyzer = ResponseAnalyzerWithLLM(llm_handler=llm_handler) - - def get_response_for_prompt(self, prompt: str) -> str: + self.no_action_counter = 0 + if prompt_context == PromptContext.PENTESTING: + self.pentesting_information = pentesting_information + + self.common_endpoints = ['autocomplete', '/api', '/auth', '/login', '/admin', '/register', '/users', '/photos', '/images', + '/products', '/orders', + '/search', '/posts', '/todos', '/1', '/resources', '/categories', + '/cart', '/checkout', '/payments', '/transactions', '/invoices', '/teams', '/comments', + '/jobs', + '/notifications', '/messages', '/files', '/settings', '/status', '/health', + '/healthcheck', + '/info', '/docs', '/swagger', '/openapi', '/metrics', '/logs', '/analytics', + '/feedback', + '/support', '/profile', '/account', '/reports', '/dashboard', '/activity', + '/subscriptions', '/webhooks', + '/events', '/upload', '/download', '/images', '/videos', '/user/login', '/api/v1', + '/api/v2', + '/auth/login', '/auth/logout', '/auth/register', '/auth/refresh', '/users/{id}', + '/users/me', '/products/{id}' + '/users/profile', '/users/settings', '/products/{id}', '/products/search', + '/orders/{id}', + '/orders/history', '/cart/items', '/cart/checkout', '/checkout/confirm', + '/payments/{id}', + '/payments/methods', '/transactions/{id}', '/transactions/history', + '/notifications/{id}', + '/messages/{id}', '/messages/send', '/files/upload', '/files/{id}', '/admin/users', + '/admin/settings', + '/settings/preferences', '/search/results', '/feedback/{id}', '/support/tickets', + '/profile/update', + '/password/reset', '/password/change', '/account/delete', '/account/activate', + '/account/deactivate', + '/account/settings', '/account/preferences', '/reports/{id}', '/reports/download', + '/dashboard/stats', + '/activity/log', '/subscriptions/{id}', '/subscriptions/cancel', '/webhooks/{id}', + '/events/{id}', + '/images/{id}', '/videos/{id}', '/files/download/{id}', '/support/tickets/{id}'] + self.common_endpoints_categorized_cycle, self.common_endpoints_categorized = self.categorize_endpoints() + self.query_counter = 0 + self.repeat_counter = 0 + self.variants_of_found_endpoints = [] + self.name = config.get("name") + self.token = config.get("token") + self.last_path = "" + self.prompt_helper = prompt_helper + self.pattern_matcher = PatternMatcher() + self.saved_endpoints = {} + self.response_analyzer = None + + def set_response_analyzer(self, response_analyzer: ResponseAnalyzerWithLLM) -> None: + self.response_analyzer = response_analyzer + + def categorize_endpoints(self) : + root_level = [] + single_parameter = [] + subresource = [] + related_resource = [] + multi_level_resource = [] + + # Iterate through the cycle of endpoints + for endpoint in self.common_endpoints: + parts = [part for part in endpoint.split('/') if part] + + if len(parts) == 1: + root_level.append(endpoint) + elif len(parts) == 2: + if "{id}" in parts[1]: + 
single_parameter.append(endpoint) + else: + subresource.append(endpoint) + elif len(parts) == 3: + if any("{id}" in part for part in parts): + related_resource.append(endpoint) + else: + multi_level_resource.append(endpoint) + else: + multi_level_resource.append(endpoint) + + return { + 1: cycle(root_level), + 2: cycle(single_parameter), + 3: cycle(subresource), + 4: cycle(related_resource), + 5: cycle(multi_level_resource), + }, { + 1: root_level, + 2: single_parameter, + 3: subresource, + 4: related_resource, + 5: multi_level_resource, + } + + + def get_response_for_prompt(self, prompt: str) -> object: """ Sends a prompt to the LLM's API and retrieves the response. @@ -47,9 +149,8 @@ def get_response_for_prompt(self, prompt: str) -> str: str: The response from the API. """ messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] - response, completion = self.llm_handler.call_llm(messages) - response_text = response.execute() - return response_text + response, completion = self.llm_handler.execute_prompt(messages) + return response, completion def parse_http_status_line(self, status_line: str) -> str: """ @@ -95,7 +196,7 @@ def extract_response_example(self, html_content: str) -> Optional[Dict[str, Any] return None def parse_http_response_to_openapi_example( - self, openapi_spec: Dict[str, Any], http_response: str, path: str, method: str + self, openapi_spec: Dict[str, Any], http_response: str, path: str, method: str ) -> Tuple[Optional[Dict[str, Any]], Optional[str], Dict[str, Any]]: """ Parses an HTTP response to generate an OpenAPI example. @@ -118,20 +219,50 @@ def parse_http_response_to_openapi_example( reference, object_name, openapi_spec = self.parse_http_response_to_schema(openapi_spec, body_dict, path) entry_dict = {} + old_body_dict = copy.deepcopy(body_dict) - if len(body_dict) == 1: - entry_dict["id"] = {"value": body_dict} - self.llm_handler.add_created_object(entry_dict, object_name) + if len(body_dict) == 1 and "data" not in body_dict: + entry_dict["id"] = body_dict + self.llm_handler._add_created_object(entry_dict, object_name) else: - if isinstance(body_dict, list): - for entry in body_dict: - key = entry.get("title") or entry.get("name") or entry.get("id") - entry_dict[key] = {"value": entry} - self.llm_handler.add_created_object(entry_dict[key], object_name) + if "data" in body_dict: + body_dict = body_dict["data"] + if isinstance(body_dict, list) and len(body_dict) > 0: + body_dict = body_dict[0] + if isinstance(body_dict, list): + for entry in body_dict: + key = entry.get("title") or entry.get("name") or entry.get("id") + entry_dict[key] = {"value": entry} + self.llm_handler._add_created_object(entry_dict[key], object_name) + if len(entry_dict) > 3: + break + + + if isinstance(body_dict, list) and len(body_dict) > 0: + body_dict = body_dict[0] + if isinstance(body_dict, list): + + for entry in body_dict: + key = entry.get("title") or entry.get("name") or entry.get("id") + entry_dict[key] = entry + self.llm_handler._add_created_object(entry_dict[key], object_name) + if len(entry_dict) > 3: + break else: - key = body_dict.get("title") or body_dict.get("name") or body_dict.get("id") - entry_dict[key] = {"value": body_dict} - self.llm_handler.add_created_object(entry_dict[key], object_name) + if isinstance(body_dict, list) and len(body_dict) == 0: + entry_dict = "" + elif isinstance(body_dict, dict) and "data" in body_dict.keys(): + entry_dict = body_dict["data"] + if isinstance(entry_dict, list) and len(entry_dict) > 0: + entry_dict = entry_dict[0] + 
else: + entry_dict= body_dict + self.llm_handler._add_created_object(entry_dict, object_name) + if isinstance(old_body_dict, dict) and len(old_body_dict.keys()) > 0 and "data" in old_body_dict.keys() and isinstance(old_body_dict, dict) \ + and isinstance(entry_dict, dict): + old_body_dict.pop("data") + entry_dict = {**entry_dict, **old_body_dict} + return entry_dict, reference, openapi_spec @@ -148,36 +279,41 @@ def extract_description(self, note: Any) -> str: return note.action.content def parse_http_response_to_schema( - self, openapi_spec: Dict[str, Any], body_dict: Dict[str, Any], path: str + self, openapi_spec: Dict[str, Any], body_dict: Dict[str, Any], path: str ) -> Tuple[str, str, Dict[str, Any]]: """ Parses an HTTP response body to generate an OpenAPI schema. Args: openapi_spec (Dict[str, Any]): The OpenAPI specification to update. - body_dict (Dict[str, Any]): The HTTP response body as a dictionary. + body_dict (Dict[str, Any]): The HTTP response body as a dictionary or list. path (str): The API path. Returns: Tuple[str, str, Dict[str, Any]]: A tuple containing the reference, object name, and updated OpenAPI specification. """ + if "/" not in path: + return None, None, openapi_spec + object_name = path.split("/")[1].capitalize().rstrip("s") properties_dict = {} - if len(body_dict) == 1: - properties_dict["id"] = {"type": "int", "format": "uuid", "example": str(body_dict["id"])} - else: - for param in body_dict: - if isinstance(body_dict, list): - for key, value in param.items(): - properties_dict = self.extract_keys(key, value, properties_dict) - break - else: - for key, value in body_dict.items(): - properties_dict = self.extract_keys(key, value, properties_dict) + # Handle different structures of `body_dict` + if isinstance(body_dict, dict): + for key, value in body_dict.items(): + # If it's a nested dictionary, extract keys recursively + properties_dict = self.extract_keys(key, value, properties_dict) + elif isinstance(body_dict, list) and len(body_dict) > 0: + first_item = body_dict[0] + if isinstance(first_item, dict): + for key, value in first_item.items(): + properties_dict = self.extract_keys(key, value, properties_dict) + + # Create the schema object for this response object_dict = {"type": "object", "properties": properties_dict} + # Add the schema to OpenAPI spec if not already present if object_name not in openapi_spec["components"]["schemas"]: openapi_spec["components"]["schemas"][object_name] = object_dict @@ -252,7 +388,7 @@ def extract_keys(self, key: str, value: Any, properties_dict: Dict[str, Any]) -> return properties_dict - def evaluate_result(self, result: Any, prompt_history: Prompt) -> Any: + def evaluate_result(self, result: Any, prompt_history: Prompt, analysis_context: Any) -> Any: """ Evaluates the result using the LLM-based response analyzer. @@ -263,5 +399,603 @@ def evaluate_result(self, result: Any, prompt_history: Prompt) -> Any: Returns: Any: The evaluation result from the LLM response analyzer. 
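`parse_http_response_to_schema` now accepts both dict and list bodies; conceptually it derives the schema name from the first path segment and turns the keys of the (first) body object into properties. A stripped-down sketch of that idea — the type inference and the `$ref` string below are simplified placeholders, not the handler's exact output:

```python
def body_to_schema(openapi_spec: dict, body, path: str):
    if "/" not in path:
        return None, None, openapi_spec
    object_name = path.split("/")[1].capitalize().rstrip("s")   # "/users/1" -> "User"
    sample = body[0] if isinstance(body, list) and body else body
    properties = {}
    if isinstance(sample, dict):
        for key, value in sample.items():
            properties[key] = {"type": "integer" if isinstance(value, int) else "string",
                               "example": str(value)}
    schemas = openapi_spec.setdefault("components", {}).setdefault("schemas", {})
    schemas.setdefault(object_name, {"type": "object", "properties": properties})
    return f"#/components/schemas/{object_name}", object_name, openapi_spec

spec = {"components": {"schemas": {}}}
print(body_to_schema(spec, [{"id": 1, "name": "Alice"}], "/users/1")[1])   # User
```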
""" - llm_responses = self.response_analyzer.analyze_response(result, prompt_history) - return llm_responses + self.response_analyzer._prompt_helper = self.prompt_helper + llm_responses, status_code = self.response_analyzer.analyze_response(result, prompt_history, analysis_context) + return llm_responses, status_code + + def extract_key_elements_of_response(self, raw_response: Any) -> str: + status_code, headers, body = self.response_analyzer.parse_http_response(raw_response) + return "Status Code: " + str(status_code) + "\nHeaders:" + str(headers) + "\nBody" + str(body) + + def handle_response(self, response, completion, prompt_history, log, categorized_endpoints, move_type): + """ + Evaluates the response to determine if it is acceptable. + + Args: + response (str): The response to evaluate. + completion (Completion): The completion object with tool call results. + prompt_history (list): History of prompts and responses. + log (Log): Logging object for console output. + + Returns: + tuple: (bool, prompt_history, result, result_str) indicating if response is acceptable. + """ + # Extract message and tool call information + message = completion.choices[0].message + tool_call_id = message.tool_calls[0].id + if "undefined" in response.action.path : + response.action.path = response.action.path.replace("undefined", "1") + if "Id" in response.action.path: + path = response.action.path.split("/") + if len(path) > 2: + response.action.path = f"/{path[0]}/1/{path[2]}" + else: + response.action.path = f"/{path[0]}/1" + + + + + if self.repeat_counter == 3: + self.repeat_counter = 0 + if self.prompt_helper.current_step == 2: + adjusted_path = self.adjust_path_if_necessary(response.action.path) + self.prompt_helper.hint_for_next_round = f'Try this endpoint in the next round {adjusted_path}' + self.no_action_counter += 1 + return False, prompt_history, None, None + + if response.__class__.__name__ == "RecordNote": + prompt_history.append(tool_message(response, tool_call_id)) + return False, prompt_history, None, None + + else: + return self.handle_http_response(response, prompt_history, log, completion, message, categorized_endpoints, + tool_call_id, move_type) + + def normalize_path(self, path): + # Use regex to strip trailing digits + return re.sub(r'\d+$', '', path) + + def check_path_variants(self, path, paths): + # Normalize the paths + normalized_paths = [self.normalize_path(path) for path in paths] + + # Count each normalized path + path_counts = Counter(normalized_paths) + + # Extract paths that have more than one variant + variants = {path: count for path, count in path_counts.items() if count > 1} + if len(variants) != 0: + return True + return False + + def handle_http_response(self, response: Any, prompt_history: Any, log: Any, completion: Any, message: Any, + categorized_endpoints, tool_call_id, move_type) -> Any: + + response = self.adjust_path(response, move_type) + # Add Authorization header if token is available + if self.token: + response.action.headers = {"Authorization": f"Bearer {self.token}"} + if self.name.__contains__("ballardtide"): + response.action.headers = {"Authorization": f"{self.token}"} + + # Convert response to JSON and display it + command = json.loads(pydantic_core.to_json(response).decode()) + log.console.print(Panel(json.dumps(command, indent=2), title="assistant")) + + # Execute the command and parse the result + with log.console.status("[bold green]Executing command..."): + + + result = response.execute() + self.query_counter += 1 + result_dict = 
self.extract_json(result) + log.console.print(Panel(result, title="tool")) + if "Could not request" in result: + return False, prompt_history, result, "" + + if response.action.__class__.__name__ != "RecordNote": + self.prompt_helper.tried_endpoints.append(response.action.path) + + # Parse HTTP status and request path + result_str = self.parse_http_status_line(result) + request_path = response.action.path + + if "action" not in command: + return False, prompt_history, response, completion + + # Check response success + is_successful = result_str.startswith("200") + prompt_history.append(message) + self.last_path = request_path + + status_message = self.check_if_successful(is_successful, request_path, result_dict, result_str, categorized_endpoints) + log.console.print(Panel(status_message, title="system")) + + prompt_history.append(tool_message(status_message, tool_call_id)) + + else: + prompt_history.append(tool_message(result, tool_call_id)) + is_successful = False + result_str = result[:20] + + return is_successful, prompt_history, result, result_str + + def extract_params(self, url): + + params = re.findall(r'(\w+)=([^&]*)', url) + extracted_params = {key: value for key, value in params} + + return extracted_params + + def get_next_key(self, current_key, dictionary): + keys = list(dictionary.keys()) # Convert keys to a list + try: + current_index = keys.index(current_key) # Find the index of the current key + return keys[current_index + 1] # Return the next key + except (ValueError, IndexError): + return None # Return None if the current key is not found or there is no next key + + def extract_json(self, response: str) -> dict: + try: + # Find the start of the JSON body by locating the first '{' character + json_start = response.index('{') + # Extract the JSON part of the response + json_data = response[json_start:] + # Convert the JSON string to a dictionary + data_dict = json.loads(json_data) + return data_dict + except (ValueError, json.JSONDecodeError) as e: + print(f"Error extracting JSON: {e}") + return {} + + def generate_variants_of_found_endpoints(self, type_of_variant): + for endpoint in self.prompt_helper.found_endpoints: + if endpoint + "/1" in self.variants_of_found_endpoints: + self.variants_of_found_endpoints.remove(endpoint + "/1") + if "id" not in endpoint and endpoint + "/{id}" not in self.prompt_helper.found_endpoints and endpoint.endswith( + 's'): + self.variants_of_found_endpoints.append(endpoint + "/1") + if "/1" not in self.variants_of_found_endpoints or self.prompt_helper.found_endpoints: + self.variants_of_found_endpoints.append("/1") + + def get_next_path(self, path): + counter = 0 + if self.prompt_helper.current_step >= 6: + new_path = self.create_common_query_for_endpoint(path) + if path == "params": + return path + return new_path + try: + + new_path = next(self.common_endpoints_categorized_cycle[self.prompt_helper.current_step]) + while not new_path in self.prompt_helper.found_endpoints or not new_path in self.prompt_helper.unsuccessful_paths: + new_path = next(self.common_endpoints_categorized_cycle[self.prompt_helper.current_step]) + counter = counter + 1 + if counter >= 6: + return new_path + + return new_path + except StopIteration: + return path + + + def finalize_path(self, path: str) -> str: + """ + Final processing on the path before returning: + - Replace any '{id}' with '1' + - Then ALWAYS replace '1' with 'bitcoin' (no more 'if "Coin" in self.name') + - If "OWASP API" in self.name, capitalize the path + """ + # Replace {id} with '1' + # 
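`extract_params` and `extract_json` in this hunk are small parsing utilities: query parameters are pulled from a URL with a key=value regex, and a JSON body is recovered from a raw tool result by jumping to the first `{`. Equivalent standalone versions:

```python
import json
import re

def extract_params(url: str) -> dict:
    return dict(re.findall(r"(\w+)=([^&]*)", url))

def extract_json(response: str) -> dict:
    try:
        return json.loads(response[response.index("{"):])
    except (ValueError, json.JSONDecodeError):
        return {}

print(extract_params("/products?page=2&limit=10"))        # {'page': '2', 'limit': '10'}
print(extract_json('HTTP/1.1 200 OK\r\n\r\n{"id": 1}'))   # {'id': 1}
```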
Unconditionally replace '1' with 'bitcoin' + + if path is None: + l = self.common_endpoints_categorized[self.prompt_helper.current_step] + return random.choice(l) + if ("Coin" in self.name or "gbif" in self.name)and self.prompt_helper.current_step == 2: + id = self.prompt_helper.get_possible_id_for_instance_level_ep(path) + if id: + path = path.replace("1", f"{id}") + else: + path = path.replace("{id}", "1") + + # Keep the OWASP API naming convention if needed + if "OWASP API" in self.name: + path = path.capitalize() + + return path + + def adjust_path_if_necessary(self, path: str) -> str: + """ + Adjusts the given path based on the current step in self.prompt_helper and certain conditions. + Always replaces '1' with 'bitcoin', no matter what self.name is. + """ + # Ensure path starts with a slash + if not path.startswith("/"): + path = "/" + path + + parts = [part for part in path.split("/") if part] + pattern_replaced_path = self.pattern_matcher.replace_according_to_pattern(path) + + # Reset logic + if self.no_action_counter == 5: + self.no_action_counter = 0 + # Return next path (finalize it) + return self.finalize_path(self.get_next_path(path)) + + if parts: + root_path = '/' + parts[0] + + if self.prompt_helper.current_step == 1: + if len(parts) > 1: + if root_path not in ( + self.prompt_helper.found_endpoints or self.prompt_helper.unsuccessful_paths): + self.save_endpoint(path) + return self.finalize_path(root_path) + else: + self.save_endpoint(path) + return self.finalize_path(self.get_next_path(path)) + else: + # Single-part path + if (path in self.prompt_helper.found_endpoints or + path in self.prompt_helper.unsuccessful_paths or + path == self.last_path): + return self.finalize_path(self.get_next_path(path)) + + elif self.prompt_helper.current_step == 2: + if len(parts) != 2: + if path in self.prompt_helper.unsuccessful_paths: + ep = self.prompt_helper._get_instance_level_endpoint(self.name) + return self.finalize_path(ep) + + if path in self.prompt_helper.found_endpoints and len(parts) == 1: + if "Coin" in self.name or "gbif" in self.name: + id = self.prompt_helper.get_possible_id_for_instance_level_ep(path) + if id: + path = path.replace("1", f"{id}") + return self.finalize_path(path) + # Append /1 -> becomes /bitcoin after finalize_path + return self.finalize_path(f"{path}/1") + + ep = self.prompt_helper._get_instance_level_endpoint(self.name) + return self.finalize_path(ep) + + elif self.prompt_helper.current_step == 3: + if path in self.prompt_helper.unsuccessful_paths: + ep = self.prompt_helper._get_sub_resource_endpoint( + random.choice(self.prompt_helper.found_endpoints), + self.common_endpoints, self.name + ) + return self.finalize_path(ep) + + ep = self.prompt_helper._get_sub_resource_endpoint(path, self.common_endpoints, self.name) + return self.finalize_path(ep) + + elif self.prompt_helper.current_step == 4: + if path in self.prompt_helper.unsuccessful_paths: + ep = self.prompt_helper._get_related_resource_endpoint( + random.choice(self.prompt_helper.found_endpoints), + self.common_endpoints, + self.name + ) + return self.finalize_path(ep) + + ep = self.prompt_helper._get_related_resource_endpoint(path, self.common_endpoints, self.name) + return self.finalize_path(ep) + + elif self.prompt_helper.current_step == 5: + if path in self.prompt_helper.unsuccessful_paths: + ep = self.prompt_helper._get_multi_level_resource_endpoint( + random.choice(self.prompt_helper.found_endpoints), + self.common_endpoints, + self.name + ) + else: + ep = 
self.prompt_helper._get_multi_level_resource_endpoint(path, self.common_endpoints, self.name) + return self.finalize_path(ep) + + elif (self.prompt_helper.current_step == 6 and + "?" not in path): + new_path = self.create_common_query_for_endpoint(path) + # If "no params", keep original path, else use new_path + return self.finalize_path(path if new_path == "no params" else new_path) + + # Already-handled paths + if (path in {self.last_path, + *self.prompt_helper.unsuccessful_paths, + *self.prompt_helper.found_endpoints} + and self.prompt_helper.current_step != 6): + return self.finalize_path(random.choice(self.common_endpoints)) + + # Pattern-based check + if (pattern_replaced_path in self.prompt_helper.found_endpoints or + pattern_replaced_path in self.prompt_helper.unsuccessful_paths) and self.prompt_helper.current_step != 2: + return self.finalize_path(random.choice(self.common_endpoints)) + + else: + # No parts + if self.prompt_helper.current_step == 1: + root_level_endpoints = self.prompt_helper._get_root_level_endpoints() + chosen = root_level_endpoints[0] if root_level_endpoints else self.get_next_path(path) + return self.finalize_path(chosen) + + if self.prompt_helper.current_step == 2: + ep = self.prompt_helper._get_instance_level_endpoint(self.name) + return self.finalize_path(ep) + + # If none of the above conditions matched, we finalize the path or get_next_path + if path: + return self.finalize_path(path) + return self.finalize_path(self.get_next_path(path)) + + + + def save_endpoint(self, path): + + parts = [part.strip() for part in path.split("/") if part.strip()] + if len(parts) not in self.saved_endpoints.keys(): + self.saved_endpoints[len(parts)] = [] + if path not in self.saved_endpoints[len(parts)]: + self.saved_endpoints[len(parts)].append(path) + if path not in self.prompt_helper.saved_endpoints: + self.prompt_helper.saved_endpoints.append(path) + + def get_saved_endpoint(self): + # First check if there are any saved endpoints for the current step + if self.prompt_helper.current_step in self.saved_endpoints and self.saved_endpoints[ + self.prompt_helper.current_step]: + # Get the first endpoint in the list for the current step + saved_endpoint = self.saved_endpoints[self.prompt_helper.current_step][0] + saved_endpoint = saved_endpoint.replace("{id}", "1") + + # Check if this endpoint has not been found or unsuccessfully tried + if saved_endpoint not in self.prompt_helper.found_endpoints and saved_endpoint not in self.prompt_helper.unsuccessful_paths: + # If it is a valid endpoint, delete it from saved endpoints to avoid reuse + del self.saved_endpoints[self.prompt_helper.current_step][0] + if not saved_endpoint.endswith("s") and not saved_endpoint.endswith("1"): + saved_endpoint = saved_endpoint + "s" + return saved_endpoint + + # Return None or raise an exception if no valid endpoint is found + return None + + def adjust_counter(self, categorized_endpoints): + # Helper function to handle the increment and reset actions + def update_step_and_category(): + if self.prompt_helper.current_step != 6: + self.prompt_helper.current_step += 1 + self.prompt_helper.current_category = self.get_next_key(self.prompt_helper.current_category, + categorized_endpoints) + self.query_counter = 0 + + # Check for step-specific conditions or query count thresholds + if (self.prompt_helper.current_step == 1 and self.query_counter > 150): + update_step_and_category() + elif self.prompt_helper.current_step == 2 and not self.prompt_helper._get_instance_level_endpoints(self.name): + 
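`save_endpoint` above parks interesting paths by their depth so that the exploration step working at that depth can revisit them later. The bucketing in isolation:

```python
def save_endpoint(saved: dict, path: str) -> dict:
    depth = len([part for part in path.split("/") if part.strip()])
    saved.setdefault(depth, [])
    if path not in saved[depth]:
        saved[depth].append(path)
    return saved

saved = {}
save_endpoint(saved, "/users/1/posts")
save_endpoint(saved, "/users")
print(saved)   # {3: ['/users/1/posts'], 1: ['/users']}
```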
update_step_and_category() + elif self.prompt_helper.current_step > 2 and self.query_counter > 30: + update_step_and_category() + elif self.prompt_helper.current_step == 7 and not self.prompt_helper._get_root_level_endpoints(self.name): + update_step_and_category() + + def create_common_query_for_endpoint(self, endpoint): + """ + Constructs complete URLs with one query parameter for each API endpoint. + + + Returns: + list: A list of full URLs with appended query parameters. + """ + + endpoint = endpoint + "?" + # Define common query parameters + common_query_params = [ + "page", "limit", "sort", "filter", "search", "api_key", "access_token", + "callback", "fields", "expand", "since", "until", "status", "lang", + "locale", "region", "embed", "version", "format", "username" + ] + + # Sample dictionary of parameters for demonstration + full_params = { + "page": 2, + "limit": 10, + "sort": "date_desc", + "filter": "status:active", + "search": "example query", + "api_key": "YourAPIKeyHere", + "access_token": "YourAccessToken", + "callback": "myFunction", + "fields": "id,name,status", + "expand": "details,owner", + "since": "2020-01-01T00:00:00Z", + "until": "2022-01-01T00:00:00Z", + "status": "active", + "lang": "en", + "locale": "en_US", + "region": "North America", + "embed": "true", + "version": "1.0", + "format": "json", + "username": "test" + } + + urls_with_params = [] + + # Iterate through all found endpoints + # Pick one random parameter from the common query params + random_param_key = random.choice(common_query_params) + + # Check if the selected key is in the full_params + if random_param_key in full_params: + sampled_params = {random_param_key: full_params[random_param_key]} + else: + sampled_params = {} + + # Encode the parameters into a query string + query_string = urlencode(sampled_params) + + # Ensure the endpoint doesn't end with a slash + if endpoint.endswith('/') or endpoint.endswith("?"): + endpoint = endpoint[:-1] + + # Construct the full URL with the query parameter + full_url = f"{endpoint}?{query_string}" + urls_with_params.append(full_url) + if endpoint in self.prompt_helper.query_endpoints_params.keys(): + if random_param_key not in self.prompt_helper.query_endpoints_params[endpoint]: + if random_param_key not in self.prompt_helper.tried_endpoints_with_params[endpoint]: + return full_url + + if urls_with_params == None: + return "no params" + return random.choice(urls_with_params) + + def adjust_path(self, response, move_type): + """ + Adjusts the response action path based on current step, unsuccessful paths, and move type. + + Args: + response (Any): The HTTP response object containing the action and path. + move_type (str): The type of move (e.g., 'exploit') influencing path adjustment. + + Returns: + Any: The updated response object with an adjusted path. + """ + old_path = response.action.path + + if "?" 
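`create_common_query_for_endpoint` probes query-parameter handling by attaching one randomly chosen, commonly used parameter to an endpoint. A trimmed sketch using only a handful of the parameter/value pairs listed in the hunk:

```python
import random
from urllib.parse import urlencode

# Small subset of the common query parameters from the diff; values are illustrative.
SAMPLE_PARAMS = {"page": 2, "limit": 10, "sort": "date_desc", "search": "example query"}

def with_random_query(endpoint: str) -> str:
    endpoint = endpoint.rstrip("/?")   # avoid a double separator
    key = random.choice(list(SAMPLE_PARAMS))
    return f"{endpoint}?{urlencode({key: SAMPLE_PARAMS[key]})}"

print(with_random_query("/products"))   # e.g. /products?limit=10
```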
not in response.action.path and self.prompt_helper.current_step == 6: + if response.action.path not in self.prompt_helper.saved_endpoints: + if response.action.query is not None: + return response + # Process action if it's not RecordNote + if response.action.__class__.__name__ != "RecordNote": + if self.prompt_helper.current_step == 6 : + response.action.path = self.create_common_query_for_endpoint(response.action.path) + + if response.action.path in self.prompt_helper.unsuccessful_paths: + self.repeat_counter += 1 + + if self.no_action_counter == 5: + response.action.path = self.get_next_path(response.action.path) + self.no_action_counter = 0 + parts = response.action.path.split("/") + len_path = len([part.strip() for part in parts if part.strip()]) + if self.prompt_helper.current_step == 2: + if len_path <2 or len_path > 2 or response.action.path in self.prompt_helper.unsuccessful_paths: + id = self.prompt_helper.get_possible_id_for_instance_level_ep(parts[0]) + if id: + response.action.path = parts[0] + f"/{id}" + else: + if self.prompt_helper.current_step != 6 and not response.action.path.endswith("?"): + adjusted_path = self.adjust_path_if_necessary(response.action.path) + if adjusted_path != None: + response.action.path = adjusted_path + + if move_type == "exploit" and self.repeat_counter == 3: + if len(self.prompt_helper.endpoints_to_try) != 0: + exploit_endpoint = self.prompt_helper.endpoints_to_try[0] + response.action.path = self.create_common_query_for_endpoint(exploit_endpoint) + else: + exploit_endpoint = self.prompt_helper._get_instance_level_endpoint(self.name) + self.repeat_counter = 0 + + if exploit_endpoint and response.action.path not in self.prompt_helper._get_instance_level_endpoints(self.name): + response.action.path = exploit_endpoint + if move_type != "exploit": + response.action.method = "GET" + + if response.action.path == None: + response.action.path = old_path + + return response + + def check_if_successful(self, is_successful, request_path, result_dict, result_str, categorized_endpoints): + if is_successful: + if "?" in request_path and request_path not in self.prompt_helper.found_query_endpoints: + self.prompt_helper.found_query_endpoints.append(request_path) + ep = request_path.split("?")[0] + if ep in self.prompt_helper.endpoints_to_try: + self.prompt_helper.endpoints_to_try.remove(ep) + if ep in self.saved_endpoints: + self.saved_endpoints[1].remove(ep) + if ep in self.prompt_helper.saved_endpoints: + self.prompt_helper.saved_endpoints.remove(ep) + if ep not in self.prompt_helper.found_endpoints: + self.prompt_helper.found_endpoints.append(ep) + + self.prompt_helper.query_endpoints_params.setdefault(ep, []) + self.prompt_helper.tried_endpoints_with_params.setdefault(ep, []) + # ep = self.check_if_crypto(ep) + if ep not in self.prompt_helper.found_endpoints: + if "?" not in ep and ep not in self.prompt_helper.found_endpoints: + self.prompt_helper.found_endpoints.append(ep) + if "?" 
in ep and ep not in self.prompt_helper.found_query_endpoints: + self.prompt_helper.found_query_endpoints.append(ep) + + for key in self.extract_params(request_path): + if ep not in self.prompt_helper.query_endpoints_params: + self.prompt_helper.query_endpoints_params[ep] = [] + if ep not in self.prompt_helper.tried_endpoints_with_params: + self.prompt_helper.tried_endpoints_with_params[ep] = [] + self.prompt_helper.query_endpoints_params[ep].append(key) + self.prompt_helper.tried_endpoints_with_params[ep].append(key) + + status_message = f"{request_path} is a correct endpoint" + self.no_new_endpoint_counter= 0 + else: + error_msg = result_dict.get("error", {}).get("message", "unknown error") if isinstance( + result_dict.get("error", {}), dict) else result_dict.get("error", "unknown error") + self.no_new_endpoint_counter +=1 + if error_msg == "unknown error" and (result_str.startswith("4") or result_str.startswith("5")): + error_msg = result_str + + if result_str.startswith("400") or result_str.startswith("401") or result_str.startswith("403"): + status_message = f"{request_path} is a correct endpoint, but encountered an error: {error_msg}" + self.prompt_helper.endpoints_to_try.append(request_path) + self.prompt_helper.bad_request_endpoints.append(request_path) + self.save_endpoint(request_path) + if request_path not in self.prompt_helper.saved_endpoints: + self.prompt_helper.saved_endpoints.append(request_path) + + if error_msg not in self.prompt_helper.correct_endpoint_but_some_error: + self.prompt_helper.correct_endpoint_but_some_error[error_msg] = [] + self.prompt_helper.correct_endpoint_but_some_error[error_msg].append(request_path) + else: + self.prompt_helper.unsuccessful_paths.append(request_path) + status_message = f"{request_path} is not a correct endpoint; Reason: {error_msg}" + + ep = request_path.split("?")[0] + self.prompt_helper.tried_endpoints_with_params.setdefault(ep, []) + for key in self.extract_params(request_path): + self.prompt_helper.tried_endpoints_with_params[ep].append(key) + + # self.adjust_counter(categorized_endpoints) + + return status_message + + def check_if_crypto(self, path): + + # Default list of cryptos to detect + cryptos = ["bitcoin", "ethereum", "litecoin", "dogecoin", + "cardano", "solana"] + + # Convert to lowercase for the match, but preserve the original path for reconstruction if you prefer + lower_path = path.lower() + + + for crypto in cryptos: + if crypto in lower_path: + # Example approach: split by '/' and replace the segment that matches crypto + parts = path.split('/') + replaced_any = False + for i, segment in enumerate(parts): + if segment.lower() == crypto: + parts[i] = "{id}" + if segment.lower() == crypto: + parts[i] = "{id}" + replaced_any = True + if replaced_any: + return "/".join(parts) + + + return path \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py index 98781cbb..b6114132 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py @@ -1,175 +1,450 @@ +import os from dataclasses import field from typing import Dict +from rich.panel import Panel + from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.record_note import RecordNote from hackingBuddyGPT.usecases.agents import 
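`check_if_successful` boils down to a three-way triage on the status line: 2xx marks a confirmed endpoint, 400/401/403 marks an endpoint that exists but needs different input or authorization, and everything else is recorded as an unsuccessful path. A condensed version (error-message extraction is simplified to the status line itself):

```python
def triage(request_path: str, status_line: str) -> str:
    if status_line.startswith("200"):
        return f"{request_path} is a correct endpoint"
    if status_line.startswith(("400", "401", "403")):
        return f"{request_path} is a correct endpoint, but encountered an error: {status_line}"
    return f"{request_path} is not a correct endpoint; Reason: {status_line}"

print(triage("/users/1", "401 Unauthorized"))
```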
Agent from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case -from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import ( - OpenAPISpecificationHandler, -) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptEngineer, PromptStrategy +from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import \ + OpenAPISpecificationHandler +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PromptContext +from hackingBuddyGPT.utils.prompt_generation.prompt_engineer import PromptEngineer from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler +from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Context, Prompt -from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler +from hackingBuddyGPT.usecases.web_api_testing.utils.evaluator import Evaluator from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib class SimpleWebAPIDocumentation(Agent): """ - SimpleWebAPIDocumentation is an agent that documents REST APIs of a website by interacting with the APIs and - generating an OpenAPI specification. - - Attributes: - llm (OpenAILib): The language model to use for interaction. - host (str): The host URL of the website to test. - _prompt_history (Prompt): The history of prompts and responses. - _context (Context): The context containing notes. - _capabilities (Dict[str, Capability]): The capabilities of the agent. - _all_http_methods_found (bool): Flag indicating if all HTTP methods were found. - _http_method_description (str): Description for expected HTTP methods. - _http_method_template (str): Template to format HTTP methods in API requests. - _http_methods (str): Expected HTTP methods in the API. - """ + SimpleWebAPIDocumentation is an agent class for automating REST API documentation. + Attributes: + llm (OpenAILib): The language model interface used for prompt execution. + _prompt_history (Prompt): Internal history of prompts exchanged with the LLM. + _context (Context): Context information used by capabilities (e.g., notes). + _capabilities (Dict[str, Capability]): Dictionary of active tool capabilities (HTTP requests, notes, etc.). + config_path (str): Path to the configuration file for the API under test. + strategy_string (str): Serialized string representing the documentation strategy to apply. + _http_method_description (str): Description for identifying HTTP methods in responses. + _http_method_template (str): Template string for formatting HTTP methods. + _http_methods (str): Comma-separated list of expected HTTP methods. + explore_steps_done (bool): Flag to indicate if exploration steps are completed. + found_all_http_methods (bool): Flag indicating whether all HTTP methods have been found. + all_steps_done (bool): Flag to indicate whether the full documentation process is complete. 
+ """ llm: OpenAILib - host: str = parameter(desc="The host to test", default="https://jsonplaceholder.typicode.com") _prompt_history: Prompt = field(default_factory=list) _context: Context = field(default_factory=lambda: {"notes": list()}) _capabilities: Dict[str, Capability] = field(default_factory=dict) _all_http_methods_found: bool = False + config_path: str = parameter( + desc="Configuration file path", + default="", + ) + + strategy_string: str = parameter( + desc="strategy string", + default="", + ) + + prompt_file: str = parameter( + desc="prompt file name", + default="", + ) + - # Description for expected HTTP methods _http_method_description: str = parameter( desc="Pattern description for expected HTTP methods in the API response", default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).", ) - - # Template for HTTP methods in API requests _http_method_template: str = parameter( desc="Template to format HTTP methods in API requests, with {method} replaced by actual HTTP method names.", default="{method}", ) - - # List of expected HTTP methods _http_methods: str = parameter( desc="Expected HTTP methods in the API, as a comma-separated list.", default="GET,POST,PUT,PATCH,DELETE", ) + def init(self): - """Initializes the agent with its capabilities and handlers.""" + """Initialize the agent with configurations, capabilities, and handlers.""" super().init() + self.explore_steps_done = False + self.found_all_http_methods = False + self.all_steps_done = False + + + config_handler = ConfigurationHandler(self.config_path, self.strategy_string) + config, self.strategy = config_handler.load() + token, self.host, description, self._correct_endpoints, query_params = config_handler._extract_config_values(config) + + self.categorized_endpoints = self.categorize_endpoints(self._correct_endpoints, query_params) + self._setup_capabilities() - self.llm_handler = LLMHandler(self.llm, self._capabilities) - self.response_handler = ResponseHandler(self.llm_handler) - self._setup_initial_prompt() - self.documentation_handler = OpenAPISpecificationHandler(self.llm_handler, self.response_handler) + self._prompt_context = PromptContext.DOCUMENTATION + name, initial_prompt = self._setup_initial_prompt(description=description) + self._initialize_handlers(config=config, description=description, token=token, name=name, + initial_prompt=initial_prompt) + - def _setup_capabilities(self): - """Sets up the capabilities for the agent.""" - notes = self._context["notes"] - self._capabilities = {"http_request": HTTPRequest(self.host), "record_note": RecordNote(notes)} - def _setup_initial_prompt(self): - """Sets up the initial prompt for the agent.""" + def _setup_initial_prompt(self, description: str): + """ + Configures the initial prompt for the API documentation process. + + This prompt provides system-level instructions to the LLM, guiding it to start documenting + the REST API from scratch using an empty OpenAPI specification. + + Args: + description (str): A short description of the website or service being documented. + + Returns: + tuple: + - str: The base project name, extracted from the config file name. + - dict: The initial prompt dictionary to be added to the prompt history. + """ initial_prompt = { "role": "system", - "content": f"You're tasked with documenting the REST APIs of a website hosted at {self.host}. 
" - f"Start with an empty OpenAPI specification.\n" - f"Maintain meticulousness in documenting your observations as you traverse the APIs.", + "content": ( + f"You're tasked with documenting the REST APIs of a website hosted at {self.host}. " + f"The website is {description}. Start with an empty OpenAPI specification and be meticulous in " + f"documenting your observations as you traverse the APIs." + ), } + + base_name = os.path.basename(self.config_path) + + # Split the base name by '_config' and take the first part + name = base_name.split('_config')[0] + + self.prompt_helper = PromptGenerationHelper(self.host, description) + return name, initial_prompt + + def _initialize_handlers(self, config, description, token, name, initial_prompt): + """ + Initializes the core handler components required for API exploration and documentation. + + This includes setting up: + - Capabilities such as HTTP request execution. + - LLM interaction handler. + - Response handling and OpenAPI documentation logic. + - Prompt engineering strategy. + - Evaluator for judging API test or doc performance. + + Args: + config (dict): Configuration dictionary containing API setup options. + description (str): Description of the target API or web service. + token (str): Authorization token (if any) to be used for API interaction. + name (str): Base name of the current documentation session. + initial_prompt (dict): Initial system prompt for the LLM. + """ + self.all_capabilities = { + "http_request": HTTPRequest(self.host)} + self._llm_handler = LLMHandler(self.llm, self._capabilities, all_possible_capabilities=self.all_capabilities) + + self._response_handler = ResponseHandler(llm_handler=self._llm_handler, prompt_context=self._prompt_context, + prompt_helper=self.prompt_helper, config=config) + self._documentation_handler = OpenAPISpecificationHandler( + self._llm_handler, self._response_handler, self.strategy, self.host, description, name + ) + self._prompt_history.append(initial_prompt) - handlers = (self.llm_handler, self.response_handler) - self.prompt_engineer = PromptEngineer( - strategy=PromptStrategy.CHAIN_OF_THOUGHT, - history=self._prompt_history, - handlers=handlers, + + self._prompt_engineer = PromptEngineer( + strategy=self.strategy, context=PromptContext.DOCUMENTATION, - rest_api=self.host, + prompt_helper=self.prompt_helper, + open_api_spec=self._documentation_handler.openapi_spec, + rest_api_info=(token, self.host, self._correct_endpoints, self.categorized_endpoints), + prompt_file=self.prompt_file ) + self._evaluator = Evaluator(config=config) + + def categorize_endpoints(self, endpoints, query: dict): - def all_http_methods_found(self, turn): """ - Checks if all expected HTTP methods have been found. + Categorizes a list of API endpoints based on their path depth and structure. + + Endpoints are grouped into categories such as root-level, instance-level, subresources, + and multi-level/related resources. Useful for prioritizing exploration and testing. + + Args: + endpoints (list[str]): A list of API endpoint paths. + query (dict): Dictionary of query parameters to associate with the categorized endpoints. - Args: - turn (int): The current turn number. 
+ Returns: + dict: A dictionary containing categorized endpoint groups: + - "root_level": Endpoints like `/users` + - "instance_level": Endpoints with one ID parameter like `/users/{id}` + - "subresource": Direct subpaths without ID + - "related_resource": Nested resources with an ID in the middle like `/users/{id}/posts` + - "multi-level_resource": Deeper or complex nested resources + - "query": Query parameter values from the input + """ + root_level = [] + single_parameter = [] + subresource = [] + related_resource = [] + multi_level_resource = [] - Returns: - bool: True if all HTTP methods are found, False otherwise. + for endpoint in endpoints: + # Split the endpoint by '/' and filter out empty strings + parts = [part for part in endpoint.split('/') if part] + + # Determine the category based on the structure + if len(parts) == 1: + root_level.append(endpoint) + elif len(parts) == 2: + if "id" in endpoint: + single_parameter.append(endpoint) + else: + subresource.append(endpoint) + elif len(parts) == 3: + if "id" in endpoint: + related_resource.append(endpoint) + else: + multi_level_resource.append(endpoint) + else: + multi_level_resource.append(endpoint) + + return { + "root_level": root_level, + "instance_level": single_parameter, + "subresource": subresource, + "query": query.values(), + "related_resource": related_resource, + "multi-level_resource": multi_level_resource, + } + + + + def _setup_capabilities(self): """ - found_endpoints = sum(len(value_list) for value_list in self.documentation_handler.endpoint_methods.values()) - expected_endpoints = len(self.documentation_handler.endpoint_methods.keys()) * 4 - print(f"found methods:{found_endpoints}") - print(f"expected methods:{expected_endpoints}") - if ( - found_endpoints > 0 - and (found_endpoints == expected_endpoints) - or turn == 20 - and found_endpoints > 0 - and (found_endpoints == expected_endpoints) - ): - return True - return False + Initializes the LLM agent's capabilities for interacting with the API. + + This sets up tool wrappers that the language model can call, such as: + - `http_request`: For performing HTTP calls against the target API. + - `record_note`: For storing observations, notes, or documentation artifacts. + + Side Effects: + - Populates `self._capabilities` with callable tools used during exploration and documentation. + """ + """Initializes agent's capabilities for API documentation.""" + self._capabilities = { + "http_request": HTTPRequest(self.host), + "record_note": RecordNote(self._context["notes"]) + } - def perform_round(self, turn: int): + def all_http_methods_found(self, turn: int) -> bool: """ - Performs a round of API documentation. + Checks whether all expected HTTP methods (GET, POST, PUT, DELETE) have been discovered + for each endpoint by the documentation engine. - Args: - turn (int): The current turn number. + Args: + turn (int): The current execution round or step index. - Returns: - bool: True if all HTTP methods are found, False otherwise. + Returns: + bool: True if all methods are found and all exploration steps are complete, False otherwise. + + Side Effects: + - Sets `self.found_all_http_methods` to True if conditions are met. 
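`categorize_endpoints` groups the known endpoints by path depth and by whether an id-style segment is present, which later drives the order of exploration. A condensed, runnable version of the same bucketing:

```python
def categorize(endpoints, query_params):
    groups = {"root_level": [], "instance_level": [], "subresource": [],
              "related_resource": [], "multi-level_resource": []}
    for endpoint in endpoints:
        parts = [part for part in endpoint.split("/") if part]
        if len(parts) == 1:
            groups["root_level"].append(endpoint)
        elif len(parts) == 2:
            groups["instance_level" if "id" in endpoint else "subresource"].append(endpoint)
        elif len(parts) == 3:
            groups["related_resource" if "id" in endpoint else "multi-level_resource"].append(endpoint)
        else:
            groups["multi-level_resource"].append(endpoint)
    groups["query"] = list(query_params.values())
    return groups

print(categorize(["/users", "/users/{id}", "/users/{id}/posts"], {"/users": ["page"]}))
```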
+ """ + + found_count = sum(len(endpoints) for endpoints in self._documentation_handler.endpoint_methods.values()) + expected_count = len(self._documentation_handler.endpoint_methods.keys()) * 4 + if found_count >= len(self._correct_endpoints) and self.all_steps_done: + self.found_all_http_methods = True + return self.found_all_http_methods + + def perform_round(self, turn: int) -> bool: """ - if turn == 1: - counter = 0 - new_endpoint_found = 0 - while counter <= new_endpoint_found + 2 and counter <= 10: - self.run_documentation(turn, "explore") - counter += 1 - if len(self.documentation_handler.endpoint_methods) > new_endpoint_found: - new_endpoint_found = len(self.documentation_handler.endpoint_methods) - elif turn == 20: - while len(self.prompt_engineer.prompt_helper.get_endpoints_needing_help()) != 0: - self.run_documentation(turn, "exploit") + Executes a round of the API documentation loop based on the current turn number. + + The method selects between exploration and exploitation modes: + - Turns 1–18: Run exploration (`_explore_mode`) + - Turn 19: Switch to exploit mode until all endpoints are fully documented + - Turn 20+: Resume exploration for completeness + + Args: + turn (int): The current iteration index in the documentation process. + + Returns: + bool: True if all HTTP methods have been discovered by the end of the round. + """ + + if turn <= 18: + self._explore_mode(turn) + elif turn <= 19: + self._exploit_until_no_help_needed(turn) else: - self.run_documentation(turn, "exploit") + self._explore_mode(turn) + return self.all_http_methods_found(turn) - def has_no_numbers(self, path): + def _explore_mode(self, turn: int) -> None: """ - Checks if the path contains no numbers. + Executes the exploration phase for a documentation round. + + In this mode, the agent probes new API endpoints, extracts metadata, + and updates its OpenAPI spec. The process continues until: + - No new endpoints are discovered for several steps. + - A maximum exploration depth is reached. + - All HTTP methods are found. - Args: - path (str): The path to check. + Args: + turn (int): The current round number to be logged and used for prompt context. + """ - Returns: - bool: True if the path contains no numbers, False otherwise. + last_endpoint_found_x_steps_ago, new_endpoint_count = 0, len(self._documentation_handler.endpoint_methods) + last_found_endpoints = len(self._prompt_engineer.prompt_helper.found_endpoints) + + while ( + last_endpoint_found_x_steps_ago <= new_endpoint_count + 5 + and last_endpoint_found_x_steps_ago <= 10 + and not self.found_all_http_methods + ): + if self.explore_steps_done : + self.run_documentation(turn, "exploit") + else: + self.run_documentation(turn, "explore") + current_count = len(self._prompt_engineer.prompt_helper.found_endpoints) + last_endpoint_found_x_steps_ago = last_endpoint_found_x_steps_ago + 1 if current_count == last_found_endpoints else 0 + last_found_endpoints = current_count + if (updated_count := len(self._documentation_handler.endpoint_methods)) > new_endpoint_count: + new_endpoint_count = updated_count + self._prompt_engineer.open_api_spec = self._documentation_handler.openapi_spec + + def _exploit_until_no_help_needed(self, turn: int) -> None: """ - return not any(char.isdigit() for char in path) + Repeatedly performs exploit mode to gather deeper documentation details + for endpoints flagged as needing further clarification. - def run_documentation(self, turn, move_type): + This runs until all such endpoints are fully explained by the LLM agent. 
+ + Args: + turn (int): Current round number, passed to `run_documentation()` for tracking. + + """ + while self._prompt_engineer.prompt_helper.get_endpoints_needing_help(): + self.run_documentation(turn, "exploit") + self._prompt_engineer.open_api_spec = self._documentation_handler.openapi_spec + + def _single_exploit_run(self, turn: int) -> None: """ - Runs the documentation process for a given turn and move type. + Performs a single exploit pass to extract more precise documentation + for endpoints or parameters that may have been incompletely parsed. + + Args: + turn (int): Current step number for context. + + """ + self.run_documentation(turn, "exploit") + self._prompt_engineer.open_api_spec = self._documentation_handler.openapi_spec - Args: - turn (int): The current turn number. - move_type (str): The move type ('explore' or 'exploit'). + def has_no_numbers(self, path: str) -> bool: """ - prompt = self.prompt_engineer.generate_prompt(turn, move_type) - response, completion = self.llm_handler.call_llm(prompt) - self.log, self._prompt_history, self.prompt_engineer = self.documentation_handler.document_response( - completion, response, self.log, self._prompt_history, self.prompt_engineer - ) + Checks whether a given API path contains any numeric characters. + + This is useful for detecting generic vs. instance-level paths (e.g., `/users` vs. `/users/123`). + + Args: + path (str): The API path to analyze. + + Returns: + bool: True if the path contains no digits, False otherwise. + """ + return not any(char.isdigit() for char in path) + + def run_documentation(self, turn: int, move_type: str) -> None: + """ + Runs a full documentation interaction cycle with the LLM agent for the given turn and mode. + + This method forms the core of the documentation loop. It generates prompts, interacts with + the LLM to simulate API calls, handles responses, updates the OpenAPI spec, and determines + when to advance exploration or exploitation steps based on multiple heuristics. + + Args: + turn (int): The current turn index (used for context and state progression). + move_type (str): Either `"explore"` or `"exploit"`, determining the type of documentation logic used. 
+
+        """
+        is_good = False
+        counter = 0
+        while not is_good:
+            prompt = self._prompt_engineer.generate_prompt(turn=turn, move_type=move_type,
+                                                           prompt_history=self._prompt_history)
+            response, completion = self._llm_handler.execute_prompt_with_specific_capability(prompt, "http_request")
+            self.log.console.print(Panel(prompt[-1]["content"], title="system"))
+
+            is_good, self._prompt_history, result, result_str = self._response_handler.handle_response(response,
+                                                                                                       completion,
+                                                                                                       self._prompt_history,
+                                                                                                       self.log,
+                                                                                                       self.categorized_endpoints,
+                                                                                                       move_type)
+
+            if result is None or "Could not request" in result:
+                continue
+            self._prompt_history, self._prompt_engineer = self._documentation_handler.document_response(
+                result, response, result_str, self._prompt_history, self._prompt_engineer
+            )
+            self.prompt_helper.endpoint_examples = self._documentation_handler.endpoint_examples
+
+            # Heuristics that advance the current exploration step or end the round early.
+            if self._prompt_engineer.prompt_helper.current_step == 7 and move_type == "explore":
+                is_good = True
+                self.prompt_helper.current_step += 1
+                self._response_handler.query_counter = 0
+            if self._prompt_engineer.prompt_helper.current_step == 2 and len(self.prompt_helper._get_instance_level_endpoints("")) == 0:
+                is_good = True
+                self.prompt_helper.current_step += 1
+                self._response_handler.query_counter = 0
+
+            if self._response_handler.query_counter == 600 and self.prompt_helper.current_step == 6:
+                is_good = True
+                self.explore_steps_done = True
+                self.prompt_helper.current_step += 1
+                self._response_handler.query_counter = 0
+
+            if move_type == "exploit":
+                if self._response_handler.query_counter >= 50:
+                    is_good = True
+                    self.all_steps_done = True
+
+            if self._prompt_engineer.prompt_helper.current_step < 6 and self._response_handler.no_new_endpoint_counter > 30:
+                is_good = True
+                self._response_handler.no_new_endpoint_counter = 0
+                self.prompt_helper.current_step += 1
+                self._response_handler.query_counter = 0
+
+            if self._prompt_engineer.prompt_helper.current_step < 6 and self._response_handler.query_counter > 200:
+                is_good = True
+                self.prompt_helper.current_step += 1
+                self._response_handler.query_counter = 0
+
+            counter += 1
+        self.prompt_helper.found_endpoints = list(set(self._prompt_engineer.prompt_helper.found_endpoints))
+
+        self._evaluator.evaluate_response(response, self._prompt_engineer.prompt_helper.found_endpoints, self.prompt_helper.current_step,
+                                          self.prompt_helper.found_query_endpoints)
+
+        self._evaluator.finalize_documentation_metrics(
+            file_path=self._documentation_handler.file.split(".yaml")[0] + ".txt")
+
+        self.all_http_methods_found(turn)

 @use_case("Minimal implementation of a web API testing use case")
 class SimpleWebAPIDocumentationUseCase(AutonomousAgentUseCase[SimpleWebAPIDocumentation]):
     """Use case for the SimpleWebAPIDocumentation agent."""
-    pass
diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py
index 6aff0267..9dc6773c 100644
--- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py
+++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py
@@ -1,4 +1,7 @@
+import copy
+import json
 import os.path
+import re
 from dataclasses import field
 from typing import Any, Dict, List

@@ -7,22 +10,31 @@
 from hackingBuddyGPT.capabilities import Capability
 from hackingBuddyGPT.capabilities.http_request import HTTPRequest
+from hackingBuddyGPT.capabilities.parsed_information import ParsedInformation
+from
hackingBuddyGPT.capabilities.python_test_case import PythonTestCase from hackingBuddyGPT.capabilities.record_note import RecordNote from hackingBuddyGPT.usecases.agents import Agent from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PenTestingInformation +from hackingBuddyGPT.utils.prompt_generation.information import PromptPurpose from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser from hackingBuddyGPT.usecases.web_api_testing.documentation.report_handler import ReportHandler -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import PromptContext -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import PromptEngineer, PromptStrategy +from hackingBuddyGPT.utils.prompt_generation.information import PromptContext +from hackingBuddyGPT.utils.prompt_generation.prompt_engineer import PromptEngineer +from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer_with_llm import \ + ResponseAnalyzerWithLLM from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ResponseHandler +from hackingBuddyGPT.usecases.web_api_testing.testing.test_handler import GenerationTestHandler +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Context, Prompt from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler from hackingBuddyGPT.utils import tool_message from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib + # OpenAPI specification file path -openapi_spec_filename = "/home/diana/Desktop/masterthesis/00/hackingBuddyGPT/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_spec/openapi_spec_2024-08-16_14-14-07.yaml" class SimpleWebAPITesting(Agent): @@ -38,43 +50,103 @@ class SimpleWebAPITesting(Agent): _prompt_history (Prompt): The history of prompts sent to the language model. _context (Context): Contextual data for the test session. _capabilities (Dict[str, Capability]): Available capabilities for the agent. - _all_http_methods_found (bool): Flag indicating if all HTTP methods have been found. + _all_test_cases_run (bool): Flag indicating if all HTTP methods have been found. """ llm: OpenAILib host: str = parameter(desc="The host to test", default="https://jsonplaceholder.typicode.com") - http_method_description: str = parameter( - desc="Pattern description for expected HTTP methods in the API response", - default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).", - ) - http_method_template: str = parameter( - desc="Template used to format HTTP methods in API requests. 
The {method} placeholder will be replaced by actual HTTP method names.", - default="{method}", + config_path: str = parameter( + desc="Configuration file path", + default="", ) - http_methods: str = parameter( - desc="Comma-separated list of HTTP methods expected to be used in the API response.", - default="GET,POST,PUT,DELETE", + + strategy_string: str = parameter( + desc="strategy string", + default="", ) + _http_method_description: str = parameter( + desc="Pattern description for expected HTTP methods in the API response", + default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).", + ) _prompt_history: Prompt = field(default_factory=list) - _context: Context = field(default_factory=lambda: {"notes": list()}) + _context: Context = field(default_factory=lambda: {"notes": list(), "test_cases": list(), "parsed": list()}) _capabilities: Dict[str, Capability] = field(default_factory=dict) - _all_http_methods_found: bool = False + _all_test_cases_run: bool = False + + def init(self): + super().init() + configuration_handler = ConfigurationHandler(self.config_path, self.strategy_string) + self.config, self.strategy = configuration_handler.load() + self.token, self.host, self.description, self.correct_endpoints, self.query_params = configuration_handler._extract_config_values( + self.config) + self._load_openapi_specification() + self._setup_environment() + self._setup_handlers() + self._setup_initial_prompt() + self.last_prompt = "" - def init(self) -> None: + def _load_openapi_specification(self): """ - Initializes the SimpleWebAPITesting use case by setting up the context, response handler, - LLM handler, capabilities, and the initial prompt. + Loads the OpenAPI specification from the configured file path. + + If the config path exists, it initializes the `OpenAPISpecificationParser` and stores both + the parser instance and the parsed OpenAPI spec data. + """ + if os.path.exists(self.config_path): + self._openapi_specification_parser = OpenAPISpecificationParser(self.config_path) + self._openapi_specification = self._openapi_specification_parser.api_data + + def _setup_environment(self): """ - super().init() - if os.path.exists(openapi_spec_filename): - self._openapi_specification: Dict[str, Any] = OpenAPISpecificationParser(openapi_spec_filename).api_data + Initializes core environment context for API testing or exploration. + + This includes: + - Setting the target host. + - Configuring capabilities. + - Categorizing endpoints based on relevance and available query parameters. + - Setting the prompt context to `PromptContext.PENTESTING`. + """ self._context["host"] = self.host self._setup_capabilities() - self._llm_handler: LLMHandler = LLMHandler(self.llm, self._capabilities) - self._response_handler: ResponseHandler = ResponseHandler(self._llm_handler) - self._report_handler: ReportHandler = ReportHandler() - self._setup_initial_prompt() + self.categorized_endpoints = self._openapi_specification_parser.categorize_endpoints(self.correct_endpoints, + self.query_params) + self.prompt_context = PromptContext.PENTESTING + + def _setup_handlers(self): + """ + Sets up all core internal components and handlers required for API testing. + + This includes: + - LLM handler for prompt execution and capability routing. + - Prompt helper for managing request state and prompt logic. + - Pentesting information tracker to hold user/resource data and API config. + - Response handler for parsing and reacting to tool responses. 
+ - Response analyzer powered by LLMs for deeper inspection. + - Reporting handler to track and export findings. + - Test case handler for saving and generating test cases. + + If username and password are not found in the config, defaults are used. + """ + self._llm_handler = LLMHandler(self.llm, self._capabilities, all_possible_capabilities=self.all_capabilities) + self.prompt_helper = PromptGenerationHelper(self.host, self.description) + if "username" in self.config.keys() and "password" in self.config.keys(): + username = self.config.get("username") + password = self.config.get("password") + else: + username = "test" + password = "" + self.pentesting_information = PenTestingInformation(self._openapi_specification_parser, self.config) + self._response_handler = ResponseHandler( + llm_handler=self._llm_handler, prompt_context=self.prompt_context, prompt_helper=self.prompt_helper, + config=self.config, pentesting_information=self.pentesting_information) + self.response_analyzer = ResponseAnalyzerWithLLM(llm_handler=self._llm_handler, + pentesting_info=self.pentesting_information, + capacity=self.parse_capacity, + prompt_helper=self.prompt_helper) + self._response_handler.set_response_analyzer(self.response_analyzer) + self._report_handler = ReportHandler(self.config) + self._test_handler = GenerationTestHandler(self._llm_handler) def _setup_initial_prompt(self) -> None: """ @@ -88,30 +160,28 @@ def _setup_initial_prompt(self) -> None: f"Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, " f"and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. " f"Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. " - f"Remember, if you encounter an HTTP method ({self.http_method_description}), promptly submit it as it is of utmost importance." + f"Remember, if you encounter an HTTP method ({self._http_method_description}), promptly submit it as it is of utmost importance." ), } self._prompt_history.append(initial_prompt) - handlers = (self._llm_handler, self._response_handler) - schemas: Dict[str, Any] = ( - self._openapi_specification["components"]["schemas"] if os.path.exists(openapi_spec_filename) else {} - ) - self.prompt_engineer: PromptEngineer = PromptEngineer( - strategy=PromptStrategy.CHAIN_OF_THOUGHT, - history=self._prompt_history, - handlers=handlers, + + self.prompt_engineer = PromptEngineer( + strategy=self.strategy, context=PromptContext.PENTESTING, - rest_api=self.host, - schemas=schemas, + open_api_spec=self._openapi_specification, + rest_api_info=(self.token, self.description, self.correct_endpoints, self.categorized_endpoints), + prompt_helper=self.prompt_helper ) + self.prompt_engineer.set_pentesting_information(self.pentesting_information) + self.purpose = self.pentesting_information.pentesting_step_list[0] - def all_http_methods_found(self) -> None: + def all_test_cases_run(self) -> None: """ Handles the event when all HTTP methods are found. Displays a congratulatory message and sets the _all_http_methods_found flag to True. """ - self.log.console.print(Panel("All HTTP methods found! 
Congratulations!", title="system")) - self._all_http_methods_found = True + self.log.console.print(Panel("All test cases run!", title="system")) + self._all_test_cases_run = True def _setup_capabilities(self) -> None: """ @@ -119,15 +189,18 @@ def _setup_capabilities(self) -> None: note recording capabilities, and HTTP method submission capabilities based on the provided configuration. """ - methods_set: set[str] = { - self.http_method_template.format(method=method) for method in self.http_methods.split(",") - } notes: List[str] = self._context["notes"] + parsed: List[str] = self._context["parsed"] + test_cases = self._context["test_cases"] + self.python_test_case_capability = {"python_test_case": PythonTestCase(test_cases)} + self.parse_capacity = {"parse": ParsedInformation(test_cases)} self._capabilities = { - "submit_http_method": HTTPRequest(self.host), - "http_request": HTTPRequest(self.host), - "record_note": RecordNote(notes), - } + "http_request": HTTPRequest(self.host)} + self.all_capabilities = {"python_test_case": PythonTestCase(test_cases), "parse": ParsedInformation(test_cases), + "http_request": HTTPRequest(self.host), + "record_note": RecordNote(notes)} + self.http_capability = {"http_request": HTTPRequest(self.host), + } def perform_round(self, turn: int) -> None: """ @@ -137,13 +210,30 @@ def perform_round(self, turn: int) -> None: Args: turn (int): The current round number. """ - prompt = self.prompt_engineer.generate_prompt(turn) + self._perform_prompt_generation(turn) + if len(self.prompt_engineer.pentesting_information.pentesting_step_list) == 0: + self.all_test_cases_run() + return + if turn == 20: + self._report_handler.save_report() + + def _perform_prompt_generation(self, turn: int) -> None: response: Any completion: Any - response, completion = self._llm_handler.call_llm(prompt) - self._handle_response(completion, response, self.prompt_engineer.purpose) + while self.purpose == self.prompt_engineer._purpose and not self._all_test_cases_run: + prompt = self.prompt_engineer.generate_prompt(turn=turn, move_type="explore", + prompt_history=self._prompt_history) - def _handle_response(self, completion: Any, response: Any, purpose: str) -> None: + response, completion = self._llm_handler.execute_prompt_with_specific_capability(prompt, "http_request") + self._handle_response(completion, response) + if len(self.prompt_engineer.pentesting_information.pentesting_step_list) == 0: + self.all_test_cases_run() + return + + self.purpose = self.prompt_engineer._purpose + + + def _handle_response(self, completion: Any, response: Any) -> None: """ Handles the response from the LLM. Parses the response, executes the necessary actions, and updates the prompt history. @@ -153,25 +243,304 @@ def _handle_response(self, completion: Any, response: Any, purpose: str) -> None response (Any): The response object from the LLM. purpose (str): The purpose or intent behind the response handling. 
""" + with self.log.console.status("[bold green]Executing that command..."): + if response is None: + return + + + response = self.adjust_action(response) + + result = self.execute_response(response, completion) + + self._report_handler.write_vulnerability_to_report(self.prompt_helper.current_sub_step, + self.prompt_helper.current_test_step, result, + self.prompt_helper.counter) + + analysis, status_code = self._response_handler.evaluate_result( + result=result, + prompt_history=self._prompt_history, + analysis_context=self.prompt_engineer.prompt_helper.current_test_step) + + if self.purpose != PromptPurpose.SETUP: + self._prompt_history = self._test_handler.generate_test_cases( + analysis=analysis, + endpoint=response.action.path, + method=response.action.method, + body=response.action.body, + prompt_history=self._prompt_history, status_code=status_code) + + self._report_handler.write_analysis_to_report(analysis=analysis, purpose=self.prompt_engineer._purpose) + + def extract_ids(self, data, id_resources=None, parent_key=''): + """ + Recursively extracts all string-based identifiers (IDs) from a nested data structure. + + This method traverses a deeply nested dictionary or list (e.g., a parsed JSON response) + and collects all keys that contain `"id"` and have string values. It organizes these IDs + into a dictionary grouped by normalized resource categories based on the key names. + + Args: + data (Union[dict, list]): The input data structure (e.g., API response) to search for IDs. + id_resources (dict, optional): A dictionary used to accumulate found IDs, grouped by category. + If None, a new dictionary is initialized. + parent_key (str, optional): The key path used for context when processing nested fields. + + Returns: + dict: A dictionary where keys are derived categories (e.g., `"user_id"`, `"post_id"`) and + values are lists of extracted ID strings. + + """ + if id_resources is None: + id_resources = {} + if isinstance(data, dict): + for key, value in data.items(): + # Update the key to reflect nested structures + new_key = f"{parent_key}.{key}" if parent_key else key + if 'id' in key and isinstance(value, str): + # Determine the category based on the key name before 'id' + category = key.replace('id', '').rstrip('_').lower() # Normalize the key + if category == '': # If no specific category, it could just be 'id' + category = parent_key.split('.')[-1] # Use parent key as category + category = category.rstrip('s') # Singular form for consistency + if category != "id": + category = category + "_id" + + if category in id_resources: + id_resources[category].append(value) + else: + id_resources[category] = [value] + else: + # Recursively search for ids within nested dictionaries or lists + self.extract_ids(value, id_resources, new_key) + + # If the data is a list, apply the function recursively to each item + elif isinstance(data, list): + for index, item in enumerate(data): + self.extract_ids(item, id_resources, f"{parent_key}[{index}]") + + return id_resources + + def extract_resource_name(self, path: str) -> str: + """ + Extracts the key resource word from a path. 
+ + Examples: + - '/identity/api/v2/user/videos/{video_id}' -> 'video' + - '/workshop/api/shop/orders/{order_id}' -> 'order' + - '/community/api/v2/community/posts/{post_id}/comment' -> 'comment' + """ + # Split into non-empty segments + parts = [p for p in path.split('/') if p] + if not parts: + return "" + + last_segment = parts[-1] + + # 1) If last segment is a placeholder like "{video_id}", return 'video' + # i.e., capture the substring before "_id". + match = re.match(r'^\{(\w+)_id\}$', last_segment) + if match: + return match.group(1) # e.g. 'video', 'order' + + # 2) Otherwise, if the last segment is a word like "videos" or "orders", + # strip a trailing 's' (e.g., "videos" -> "video"). + if last_segment.endswith('s'): + return last_segment[:-1] + + # 3) If it's just "comment" or a similar singular word, return as-is + return last_segment + + def extract_token_from_http_response(self, http_response): + """ + Extracts the token from an HTTP response body. + + Args: + http_response (str): The raw HTTP response as a string. + + Returns: + str: The extracted token if found, otherwise None. + """ + # Split the HTTP headers from the body + try: + headers, body = http_response.split("\r\n\r\n", 1) + except ValueError: + return None + + try: + # Parse the body as JSON + body_json = json.loads(body) + # Extract the token + if "token" in body_json.keys(): + return body_json["token"] + elif "authentication" in body_json.keys(): + return body_json.get("authentication", {}).get("token", None) + except json.JSONDecodeError: + # If the body is not valid JSON, return None + return None + + def save_resource(self, path, data): + """ + Saves a discovered API resource and its associated data to the current user context. + + This method extracts the resource name from the given API path (e.g., from `/users/1/posts` → `posts`), + then stores the provided `data` under that resource for the current user in `prompt_helper.current_user`. + + If the resource does not already exist in the user's data, it initializes it as an empty list. + It also updates the corresponding account entry in `pentesting_information.accounts` to ensure + consistency across known user accounts. + + Args: + path (str): The API endpoint path from which to extract the resource name. + data (Any): The resource data to be saved under the extracted resource name. + """ + resource = self.extract_resource_name(path) + if resource != "" and resource not in self.prompt_helper.current_user.keys(): + self.prompt_helper.current_user[resource] = [] + if data not in self.prompt_helper.current_user[resource]: + self.prompt_helper.current_user[resource].append(data) + for i, account in enumerate(self.prompt_helper.accounts): + if account.get("x") == self.prompt_helper.current_user.get("x"): + self.pentesting_information.accounts[i][resource] = self.prompt_helper.current_user[resource] + + def adjust_user(self, result): + """ + Adjusts the current user and pentesting state based on the contents of an HTTP response. + + This method parses the HTTP response into headers and body, and inspects the body for specific + keys such as `"key"`, `"posts"`, and `"id"` to update user-related data structures accordingly. + + Behavior: + - If the body contains `"html"`, the method returns early (assumed to be an invalid or non-JSON response). + - If `"key"` is found: + - Parses the body and updates the `"key"` field of the matching user in `prompt_helper.accounts`. 
+ - If `"posts"` is found: + - Parses the body, extracts resource IDs, and updates `pentesting_information.resources`. + - If `"id"` is found and the current sub-step purpose is `PromptPurpose.SETUP`: + - Extracts the user ID from the body and stores it in the matching user account. + + Args: + result (str): The full HTTP response string including headers and body (separated by `\r\n\r\n`). + """ + if "Could not" in result: + return + headers, body = result.split("\r\n\r\n", 1) + if "html" in body: + return + + if "key" in body: + data = json.loads(body) + for account in self.prompt_helper.accounts: + if account.get("x") == self.prompt_helper.current_user.get("x"): + account["key"] = data.get("key") + if "posts" in body: + data = json.loads(body) + # Extract ids + id_resources = self.extract_ids(data) + if len(self.pentesting_information.resources) == 0: + self.pentesting_information.resources = id_resources + else: + self.pentesting_information.resources.update(id_resources) + + if "id" in body and self.prompt_helper.current_sub_step.get("purpose") == PromptPurpose.SETUP: + data = json.loads(body) + user_id = data.get('id') + for account in self.prompt_helper.accounts: + + if account.get("x") == self.prompt_helper.current_user.get("x"): + account["id"] = user_id + break + + def adjust_action(self, response: Any): + """ + Modifies the action of an API response object based on the current prompt context and configuration. + + This method is typically used during API test setup or fuzzing to: + - Modify the HTTP method (e.g., set to POST during setup). + - Inject authorization tokens into the request headers based on the API type (`vAPI`, `crapi`, etc.). + - Correct or override request paths and bodies with current user context. + - Optionally save resource data if the path contains identifiable parameters (e.g., `_id`). + + Args: + response (Any): The response object containing an `action` field (with `method`, `headers`, `path`, `body`, etc.). + + Returns: + Any: The updated response object with modified action values based on prompt context and configuration. 
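+
+        Example (illustrative; the config name and token value are hypothetical):
+            With `config["name"] == "crapi"` and a resolved token `"abc123"`, the request
+            headers are set to `{"Authorization": "Bearer abc123"}`.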
+ """ + old_response = copy.deepcopy(response) + if self.prompt_engineer._purpose == PromptPurpose.SETUP: + response.action.method = "POST" + + token = self.prompt_helper.current_sub_step.get("token") + if token is not None and "{{" in token: + for account in self.prompt_helper.accounts: + if account["x"] == self.prompt_helper.current_user["x"]: + token = account["token"] + break + if token and (token != "" or token is not None): + if self.config.get("name") == "vAPI": + response.action.headers = {"Authorization-Token": f"{token}"} + elif self.config.get("name") == "crapi": + response.action.headers = {"Authorization": f"Bearer {token}"} + + else: + + response.action.headers = {"Authorization-Token": f"Bearer {token}"} + + if response.action.path != self.prompt_helper.current_sub_step.get("path"): + response.action.path = self.prompt_helper.current_sub_step.get("path") + + if response.action.path and "_id}" in response.action.path: + if response.action.__class__.__name__ != "HTTPRequest": + self.save_resource(response.action.path, response.action.data) + + if isinstance(response.action.path, dict): + response.action.path = response.action.path.get("path") + + if response.action.body is None: + response.action.body = self.prompt_helper.current_user + + if response.action.path is None: + response.action.path = old_response.action.path + + return response + + def execute_response(self, response, completion): + """ + Executes the API response, logs it, and updates internal state for documentation and testing. + + This method performs the following actions: + - Converts the `response` object to JSON and prints it as an assistant message. + - Executes the response as a tool call (i.e., performs the API request). + - Logs and prints the tool response. + - If the result is not a string, it attempts to extract the endpoint name and write it to a report. + - Appends a tool message with key elements extracted from the result to the prompt history. + - Adjusts user-related state based on the result (e.g., tokens, user IDs). + - Prints the state of user accounts after the request for debugging. + + Args: + response (Any): The response object that encapsulates the tool call to be executed. + completion (Any): The LLM completion object, including metadata like the tool call ID. + + Returns: + Any: The result of executing the tool call (typically a string or structured object). 
+ """ message = completion.choices[0].message tool_call_id: str = message.tool_calls[0].id command: str = pydantic_core.to_json(response).decode() self.log.console.print(Panel(command, title="assistant")) self._prompt_history.append(message) - with self.log.console.status("[bold green]Executing that command..."): - result: Any = response.execute() - self.log.console.print(Panel(result[:30], title="tool")) - if not isinstance(result, str): - endpoint: str = str(response.action.path).split("/")[1] - self._report_handler.write_endpoint_to_report(endpoint) - self._prompt_history.append(tool_message(str(result), tool_call_id)) - - analysis = self._response_handler.evaluate_result(result=result, prompt_history=self._prompt_history) - self._report_handler.write_analysis_to_report(analysis=analysis, purpose=self.prompt_engineer.purpose) - # self._prompt_history.append(tool_message(str(analysis), tool_call_id)) - - self.all_http_methods_found() + result: Any = response.execute() + self.log.console.print(Panel(result, title="tool")) + if not isinstance(result, str): + endpoint: str = str(response.action.path).split("/")[1] + self._report_handler.write_endpoint_to_report(endpoint) + + self._prompt_history.append( + tool_message(self._response_handler.extract_key_elements_of_response(result), tool_call_id)) + + self.adjust_user(result) + return result @use_case("Minimal implementation of a web API testing use case") diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/testing/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/testing/__init__.py new file mode 100644 index 00000000..be3b5ebc --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/testing/__init__.py @@ -0,0 +1 @@ +from .test_handler import GenerationTestHandler diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/testing/test_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/testing/test_handler.py new file mode 100644 index 00000000..b1ff44b3 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/testing/test_handler.py @@ -0,0 +1,246 @@ +import json +import os +import re +from datetime import datetime +from typing import Any, Dict, Tuple + + +class GenerationTestHandler: + """ + A class responsible for parsing, generating, and saving structured API test cases, + including generating pytest-compatible test functions using an LLM. + + Attributes: + _llm_handler: Handler to communicate with a language model (LLM). + test_path (str): Directory path for saving test case data. + file (str): Path to the file for saving structured test case data. + test_file (str): Path to the file for saving pytest test functions. + """ + + def __init__(self, llm_handler): + """ + Initializes the TestHandler with paths for saving generated test case data. + + Args: + llm_handler: LLM handler instance used for generating test logic from prompts. + """ + self._llm_handler = llm_handler + current_path = os.path.dirname(os.path.abspath(__file__)) + self.test_path = os.path.join(current_path, "tests", f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}") + os.makedirs(self.test_path, exist_ok=True) + + self.file = os.path.join(self.test_path, "test_cases.txt") + self.test_file = os.path.join(self.test_path, "python_test.py") + + def parse_test_case(self, note: str) -> Dict[str, Any]: + """ + Parses a text note into a structured test case dictionary. + + Args: + note (str): A human-readable note that describes the test case. + + Returns: + dict: A structured test case with description, input, and expected output. 
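+
+        Example (illustrative only; the note text below is hypothetical):
+            A note of the form
+                "Test case for GET /users:
+                 Description: Retrieve the list of users.
+                 Input Data: {}
+                 Expected Output: 200 OK with a JSON array"
+            is parsed into
+                {"description": "Test case for GET /users",
+                 "input": "{}",
+                 "expected_output": "200 OK with a JSON array"}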
+ """ + method_endpoint_pattern = re.compile(r"Test case for (\w+) (\/\S+):") + description_pattern = re.compile(r"Description: (.+)") + input_data_pattern = re.compile(r"Input Data: (\{.*\})") + expected_output_pattern = re.compile(r"Expected Output: (.+)") + + method_endpoint_match = method_endpoint_pattern.search(note) + if method_endpoint_match: + method, endpoint = method_endpoint_match.groups() + else: + raise ValueError("Method and endpoint not found in the note") + + description = description_pattern.search(note).group(1) if description_pattern.search( + note) else "No description found" + input_data = input_data_pattern.search(note).group(1) if input_data_pattern.search(note) else "{}" + expected_output = expected_output_pattern.search(note).group(1) if expected_output_pattern.search( + note) else "No expected output found" + + return { + "description": f"Test case for {method} {endpoint}", + "input": input_data, + "expected_output": expected_output + } + + def generate_test_case(self, analysis: str, endpoint: str, method: str, body:str, status_code: Any, prompt_history) -> Tuple[ + str, Dict[str, Any], list]: + """ + Uses LLM to generate a test case dictionary from analysis and test metadata. + + Args: + analysis (str): Textual analysis of API behavior. + endpoint (str): API endpoint. + method (str): HTTP method used. + status_code (Any): Expected HTTP status code. + prompt_history (list): History of prompts exchanged with the LLM. + + Returns: + tuple: Test case description, test case dictionary, and updated prompt history. + """ + prompt_text = f""" + Based on the following analysis of the API response, generate a detailed test case: + + Analysis: {analysis} + + Endpoint: {endpoint} + HTTP Method: {method} + + The test case should include: + - Description of the test. + - Example input data in JSON format. + - Expected result or assertion based on method and endpoint call. + + Format: + {{ + "description": "Test case for {method} {endpoint}", + "input": {body}, + "expected_output": {{"expected_body": body, "expected_status_code": status_code}} + }} + + return a PythonTestCase object + """ + prompt_history.append({"role": "system", "content": prompt_text}) + response, completion = self._llm_handler.execute_prompt_with_specific_capability(prompt_history, + capability="python_test_case") + test_case = response.execute() + + test_case["method"] = method + test_case["endpoint"] = endpoint + + return test_case["description"], test_case, prompt_history + + def write_test_case_to_file(self, description: str, test_case: Dict[str, Any]) -> None: + """ + Saves a structured test case to a text file. + + Args: + description (str): Description of the test. + test_case (dict): Test case dictionary. + """ + entry = { + "description": description, + "test_case": test_case + } + with open(self.file, "a") as f: + f.write(json.dumps(entry, indent=2) + "\n\n") + print(f"Test case written to {self.file}") + + def write_pytest_case(self, description: str, test_case: Dict[str, Any], prompt_history) -> list: + """ + Uses LLM to generate a pytest-compatible test function and saves it to a `.py` file. + + Args: + description (str): Description of the test case. + test_case (dict): Test case dictionary. + prompt_history (list): Prompt history for LLM context. + + Returns: + list: Updated prompt history. + """ + prompt = f""" + As a testing expert, you are tasked with creating pytest-compatible test functions using the Python 'requests' library (also import it). 
+ + Test Details: + - Description: {description} + - Endpoint: {test_case['endpoint']} + - Method: {test_case['method'].upper()} + - Input: {json.dumps(test_case.get("input", {}), indent=4)} + - Expected Status: {test_case['expected_output'].get('expected_status_code')} + - Expected Body: {test_case['expected_output'].get('expected_body', {})} + + Instructions: + Write a syntactically and semantically correct pytest function that: + - Includes a docstring explaining the purpose of the test. + - Sends the appropriate HTTP request to the specified endpoint. + - Asserts the correctness of both the response status code and the response body. + + Test Function Name: + Use the description to create a meaningful and relevant test function name, following Python's naming conventions for functions. + + Example: + If the description is "Test for successful login", the function name could be 'test_successful_login'. + + Code Example: + def test_function_name(): + \"""Docstring describing the test purpose.\""" + response = requests.METHOD('http://example.com/api/endpoint', json={{"key": "value"}}) + assert response.status_code == 200 + assert response.json() == {{"expected": "output"}} + + Replace 'METHOD', 'http://example.com/api/endpoint', and other placeholders with actual data based on the test details provided.""" + + prompt_history.append({"role": "system", "content": prompt}) + response, completion = self._llm_handler.execute_prompt_with_specific_capability(prompt_history, "record_note") + result = response.execute() + + test_function = self.extract_pytest_from_string(result) + if test_function: + with open(self.test_file, "a") as f: + f.write(test_function) + print(f"Pytest case written to {self.test_file}") + + return prompt_history + + def extract_pytest_from_string(self, text: str) -> str: + """ + Extracts the first Python function definition from a string. + + Args: + text (str): Raw string potentially containing Python code. + + Returns: + str: Extracted function block, or None if not found. + """ + func_start = text.find("import ") + if func_start == -1: + func_start = text.find("def ") + if func_start == -1: + return None + + func_end = text.find("import ", func_start + 1) + if func_end == -1: + func_end = len(text) + + return text[func_start:func_end] + + def generate_test_cases(self, analysis: str, endpoint: str, method: str, body:str, status_code: Any, prompt_history) -> list: + """ + Generates and stores both JSON and Python test cases based on analysis. + + Args: + analysis (str): Analysis summary of the API behavior. + endpoint (str): API endpoint. + method (str): HTTP method. + status_code (Any): Expected status code. + prompt_history (list): Prompt history. + + Returns: + list: Updated prompt history. + """ + description, test_case, prompt_history = self.generate_test_case(analysis, endpoint, method, body, status_code, + prompt_history) + self.write_test_case_to_file(description, test_case) + prompt_history = self.write_pytest_case(description, test_case, prompt_history) + return prompt_history + + def get_status_code(self, description: str) -> int: + """ + Extracts the first HTTP status code (3-digit integer) from a description string. + + Args: + description (str): A string potentially containing a status code. + + Returns: + int: The extracted status code. + + Raises: + ValueError: If no 3-digit status code is found. 
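+
+        Example (illustrative; the description string is hypothetical):
+            get_status_code("Expected Output: 404 Not Found") returns 404.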
+        """
+        match = re.search(r"\b(\d{3})\b", description)
+        if match:
+            return int(match.group(1))
+        raise ValueError("No valid status code found in the description.")
+
diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/configuration_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/configuration_handler.py
new file mode 100644
index 00000000..68771316
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/configuration_handler.py
@@ -0,0 +1,59 @@
+import json
+import os
+
+from hackingBuddyGPT.utils.prompt_generation.information import PromptStrategy
+
+
+class ConfigurationHandler(object):
+
+    def __init__(self, config_file, strategy_string=None):
+        self.config_file = config_file
+        self.strategy_string = strategy_string
+
+    def load(self, strategy_string=None):
+        if self.config_file != "":
+            current_file_path = os.path.dirname(os.path.abspath(__file__))
+            self.config_path = os.path.join(current_file_path, "configs", self.config_file)
+            config = self._load_config()
+
+            if "spotify" in self.config_path:
+                os.environ['SPOTIPY_CLIENT_ID'] = config['client_id']
+                os.environ['SPOTIPY_CLIENT_SECRET'] = config['client_secret']
+                os.environ['SPOTIPY_REDIRECT_URI'] = config['redirect_uri']
+
+            return config, self.get_strategy(strategy_string)
+
+    def get_strategy(self, strategy_string=None):
+        strategies = {
+            "cot": PromptStrategy.CHAIN_OF_THOUGHT,
+            "tot": PromptStrategy.TREE_OF_THOUGHT,
+            "icl": PromptStrategy.IN_CONTEXT
+        }
+        if strategy_string:
+            return strategies.get(strategy_string, PromptStrategy.IN_CONTEXT)
+
+        return strategies.get(self.strategy_string, PromptStrategy.IN_CONTEXT)
+
+    def _load_config(self, config_path=None):
+        """Loads JSON configuration from the specified path."""
+        if config_path is None:
+            config_path = self.config_path
+        if not os.path.exists(config_path):
+            raise FileNotFoundError(f"Configuration file not found at {config_path}")
+        with open(config_path, 'r') as file:
+            return json.load(file)
+
+    def _extract_config_values(self, config):
+        token = config.get("token")
+        host = config.get("host")
+        description = config.get("description")
+        correct_endpoints = config.get("correct_endpoints", {})
+        query_params = config.get("query_params", {})
+        return token, host, description, correct_endpoints, query_params
+
diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/documentation_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/documentation_handler.py
new file mode 100644
index 00000000..32aa8317
--- /dev/null
+++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/documentation_handler.py
@@ -0,0 +1,129 @@
+import os
+import yaml
+from datetime import datetime
+from hackingBuddyGPT.capabilities.yamlFile import YAMLFile
+
+
+class DocumentationHandler:
+    """
+    Handles the generation and updating of an OpenAPI specification document based on dynamic API responses.
+
+    Attributes:
+        response_handler (object): An instance of the response handler for processing API responses.
+        schemas (dict): A dictionary to store API schemas.
+        filename (str): The filename for the OpenAPI specification file.
+        openapi_spec (dict): The OpenAPI specification document structure.
+        llm_handler (object): An instance of the LLM handler for interacting with the LLM.
+        api_key (str): The API key for accessing the LLM.
+        file_path (str): The path to the directory where the OpenAPI specification file will be stored.
+        file (str): The complete path to the OpenAPI specification file.
+ _capabilities (dict): A dictionary to store capabilities related to YAML file handling. + """ + + def __init__(self, llm_handler, response_handler): + """ + Initializes the handler with a template OpenAPI specification. + + Args: + llm_handler (object): An instance of the LLM handler for interacting with the LLM. + response_handler (object): An instance of the response handler for processing API responses. + """ + self.response_handler = response_handler + self.schemas = {} + self.filename = f"openapi_spec_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml" + self.openapi_spec = { + "openapi": "3.0.0", + "info": { + "title": "Generated API Documentation", + "version": "1.0", + "description": "Automatically generated description of the API." + }, + "servers": [{"url": "https://localhost:8080"}], + "endpoints": {}, + "components": {"schemas": {}} + } + self.llm_handler = llm_handler + self.api_key = llm_handler.llm.api_key + current_path = os.path.dirname(os.path.abspath(__file__)) + self.file_path = os.path.join(current_path, "openapi_spec") + self.file = os.path.join(self.file_path, self.filename) + self._capabilities = { + "yaml": YAMLFile() + } + + def update_openapi_spec(self, resp, result): + """ + Updates the OpenAPI specification based on the API response provided. + + Args: + resp (object): The response object containing details like the path and method which should be documented. + result (str): The result of the API call. + """ + request = resp.action + + if request.__class__.__name__ == 'RecordNote': # TODO: check why isinstance does not work + self.check_openapi_spec(resp) + if request.__class__.__name__ == 'HTTPRequest': + path = request.path + method = request.method + # Ensure that path and method are not None and method has no numeric characters + if path and method: + # Initialize the path if not already present + if path not in self.openapi_spec['endpoints']: + self.openapi_spec['endpoints'][path] = {} + # Update the method description within the path + example, reference, self.openapi_spec = self.response_handler.parse_http_response_to_openapi_example( + self.openapi_spec, result, path, method) + if example is not None or reference is not None: + self.openapi_spec['endpoints'][path][method.lower()] = { + "summary": f"{method} operation on {path}", + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": reference + }, + "examples": example + } + } + } + } + } + + def write_openapi_to_yaml(self): + """ + Writes the updated OpenAPI specification to a YAML file with a timestamped filename. + """ + try: + # Prepare data to be written to YAML + openapi_data = { + "openapi": self.openapi_spec["openapi"], + "info": self.openapi_spec["info"], + "servers": self.openapi_spec["servers"], + "components": self.openapi_spec["components"], + "paths": self.openapi_spec["endpoints"] + } + + # Create directory if it doesn't exist and generate the timestamped filename + os.makedirs(self.file_path, exist_ok=True) + + # Write to YAML file + with open(self.file, 'w') as yaml_file: + yaml.dump(openapi_data, yaml_file, allow_unicode=True, default_flow_style=False) + print(f"OpenAPI specification written to {self.filename}.") + except Exception as e: + raise Exception(f"Error writing YAML file: {e}") + + def check_openapi_spec(self, note): + """ + Uses OpenAI's GPT model to generate a complete OpenAPI specification based on a natural language description. 
+ + Args: + note (object): The note object containing the description of the API. + """ + description = self.response_handler.extract_description(note) + from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing.yaml_assistant import YamlFileAssistant + yaml_file_assistant = YamlFileAssistant(self.file_path, self.llm_handler) + yaml_file_assistant.run(description) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/endpoint_categorizer.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/endpoint_categorizer.py new file mode 100644 index 00000000..e69de29b diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/evaluator.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/evaluator.py new file mode 100644 index 00000000..acc7205d --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/evaluator.py @@ -0,0 +1,296 @@ +import copy +from itertools import chain + +from hackingBuddyGPT.usecases.web_api_testing.documentation.pattern_matcher import PatternMatcher + + +class Evaluator: + def __init__(self, num_runs=10, config=None): + self._pattern_matcher = PatternMatcher() + self.documented_query_params = config.get("query_params") + self.num_runs = num_runs + self.ids = [] + self.query_params_found = {} + self.name = config.get("name") + self.documented_routes = config.get("correct_endpoints") # Example documented GET routes + self.query_params_documented = len(config.get("query_params")) # Example documented query parameters + self.results = { + "routes_found": [], + "query_params_found": {}, + "false_positives": [], + } + + def calculate_metrics(self): + """ + Calculate evaluation metrics. + """ + # Average percentages of documented routes and parameters found + percent_params_found_values = 0 + percent_params_found_keys = 0 + + self.results["routes_found"] = list(set(self.results["routes_found"])) + # Calculate percentages + percent_routes_found = self.get_percentage(self.results["routes_found"], self.documented_routes) + if len(self.documented_query_params) > 0: + percent_params_found_values = self.calculate_match_percentage(self.documented_query_params, self.results["query_params_found"])["Value Match Percentage"] + percent_params_found_keys = self.calculate_match_percentage(self.documented_query_params, self.results["query_params_found"])["Key Match Percentage"] + else: + percent_params_found = 0 + + # Average false positives + avg_false_positives = len(self.results["false_positives"]) / self.num_runs + + + # Best and worst for routes and parameters + if len(self.results["routes_found"]) >0: + + r_best = max(self.results["routes_found"]) + r_worst = min(self.results["routes_found"]) + else: + r_best = 0 + r_worst = 0 + self.documented_routes = list(set(self.documented_routes)) + + metrics = { + "Percent Routes Found": percent_routes_found, + "Percent Parameters Values Found": percent_params_found_values, + "Percent Parameters Keys Found": percent_params_found_keys, + "Average False Positives": avg_false_positives, + "Routes Best/Worst": (r_best, r_worst), + "Additional_Params Best/Worst": set( + tuple(value) if isinstance(value, list) else value for value in self.documented_query_params.values() +).difference( + set(tuple(value) if isinstance(value, list) else value for value in self.query_params_found.values()) +), + "Additional_routes Found": set(self.results["routes_found"]).difference(set(self.documented_routes)), + "Missing routes Found": set(self.documented_routes).difference(set(self.results["routes_found"])), + } + + return 
metrics + + def check_false_positives(self, path): + """ + Identify and count false positive query parameters in the response. + + Args: + response (dict): The response data to check for false positive parameters. + + Returns: + int: The count of false positive query parameters. + """ + # Example list of documented query parameters + # Extract the query parameters from the response + response_query_params = self._pattern_matcher.extract_query_params(path).keys() + + # Identify false positives + false_positives = [param for param in response_query_params if param not in self.documented_query_params] + + return len(false_positives) + + def extract_query_params_from_response_data(self, response): + """ + Extract query parameters from the actual response data. + + Args: + response (dict): The response data. + + Returns: + list: A list of query parameter names found in the response. + """ + return response.get("query_params", []) + + def all_query_params_found(self, path, response): + """ + Count the number of documented query parameters found in a response. + + Args: + turn (int): The current turn number for the documentation process. + + Returns: + int: The count of documented query parameters found in this turn. + """ + if response.action.query is not None: + query = response.action.query.split("?")[0] + path = path + "&"+ query + # Simulate response query parameters found (this would usually come from the response data) + response_query_params = self._pattern_matcher.extract_query_params(path) + valid_query_params = [] + if "?" in path: + ep = path.split("?")[0] # Count the valid query parameters found in the response + if response_query_params: + for param, value in response_query_params.items(): + if ep in self.documented_query_params.keys(): + x = self.documented_query_params[ep] + if param in x: + valid_query_params.append(param) + if ep not in self.results["query_params_found"].keys(): + self.results["query_params_found"][ep] = [] + if param not in self.results["query_params_found"][ep]: + self.results["query_params_found"][ep].append(param) + if ep not in self.query_params_found.keys(): + self.query_params_found[ep] = [] + if param not in self.query_params_found[ep]: + self.query_params_found[ep].append(param) + self.results["query_params_found"] = self.query_params_found + + def extract_query_params_from_response(self, path): + """ + Extract query parameters from the response in a specific turn. + + Args: + turn (int): The current turn number for the documentation process. + + Returns: + list: A list of query parameter names found in the response. 
+ """ + # Placeholder code: Replace this with actual extraction logic + return self._pattern_matcher.extract_query_params(path).keys() + + def calculate_match_percentage(self, documented, result): + total_keys = len(documented) + matching_keys = 0 + value_matches = 0 + total_values = 0 + + for key in documented: + # Check if the key exists in the result + if key in result: + matching_keys += 1 + # Compare values as sets (ignoring order) + documented_values = set(documented[key]) + result_values = set(result[key]) + + # Count the number of matching values + value_matches += len(documented_values & result_values) # Intersection + total_values += len(documented_values) # Total documented values for the key + else: + total_values += len(documented[key]) # Add documented values for missing keys + + # Calculate percentages + key_match_percentage = (matching_keys / total_keys) * 100 + value_match_percentage = (value_matches / total_values) * 100 if total_values > 0 else 0 + + return { + "Key Match Percentage": key_match_percentage, + "Value Match Percentage": value_match_percentage, + } + + def evaluate_response(self, response, routes_found, current_step, query_endpoints): + query_params_found = 0 + routes_found = copy.deepcopy(routes_found) + + false_positives = 0 + for idx, route in enumerate(routes_found): + routes_found = self.add_if_is_cryptocurrency(idx, route, routes_found, current_step) + # Use evaluator to record routes and parameters found + if response.action.__class__.__name__ != "RecordNote": + for path in query_endpoints : + self.all_query_params_found(path, response) # This function should return the number found + false_positives = self.check_false_positives(path) # Define this function to determine FP count + + # Record these results in the evaluator + self.results["routes_found"] += routes_found + #self.results["query_params_found"].append(query_params_found) + self.results["false_positives"].append(false_positives) + + def add_if_is_cryptocurrency(self, idx, path,routes_found,current_step): + """ + If the path contains a known cryptocurrency name, replace that part with '{id}' + and add the resulting path to `self.prompt_helper.found_endpoints`. 
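+
+        Example (illustrative path): a discovered route `/coins/bitcoin` is rewritten
+        to `/coins/{id}` in the returned `routes_found` list.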
+ """ + # Default list of cryptos to detect + routes_found = list(set(routes_found)) + cryptos = ["bitcoin", "ethereum", "litecoin", "dogecoin", + "cardano", "solana", "binance", "polkadot", "tezos",] + + # Convert to lowercase for the match, but preserve the original path for reconstruction if you prefer + lower_path = path.lower() + + parts = [part.strip() for part in path.split("/") if part.strip()] + + for crypto in cryptos: + if crypto in lower_path: + # Example approach: split by '/' and replace the segment that matches crypto + parts = path.split('/') + replaced_any = False + for i, segment in enumerate(parts): + if segment.lower() == crypto: + parts[i] = "{id}" + replaced_any = True + + # Only join and store once per path + if replaced_any: + replaced_path = "/".join(parts) + if path in routes_found: + for i, route in enumerate(routes_found): + if route == path: + routes_found[i] = replaced_path + + else: + routes_found.append(replaced_path) + if len(parts) == 3 and current_step == 4: + if "/"+ parts[0] + "/{id}/" + parts[2] not in routes_found: + for i, route in enumerate(routes_found): + if route == path: + routes_found[i] = "/" + parts[0] + "/{id}/" + parts[2] + break + if len(parts) == 2 and current_step == 2: + if "/"+parts[0] + "/{id}" not in routes_found: + for i, route in enumerate(routes_found): + if route == path: + routes_found[i] ="/"+parts[0] + "/{id}" + break + + if "/1" in path: + if idx < len(routes_found): + routes_found[idx] = routes_found[idx].replace("/1", "/{id}") + return routes_found + + + def get_percentage(self, param, documented_param): + found_set = set(param) + documented_set = set(documented_param) + + common_items = documented_set.intersection(found_set) + common_count = len(common_items) + percentage = (common_count / len(documented_set)) * 100 + + return percentage + + def finalize_documentation_metrics(self, file_path): + """Calculate and log the final effectiveness metrics after documentation process is complete.""" + metrics = self.calculate_metrics() + # Specify the file path + + + print(f'Appending metrics to {file_path}') + + # Appending the formatted data to a text file + with open(file_path, 'a') as file: # 'a' is for append mode + file.write("\n\nDocumentation Effectiveness Metrics:\n") + file.write(f"Percent Routes Found: {metrics['Percent Routes Found']:.2f}%\n") + file.write(f"Percent Parameters Values Found: {metrics['Percent Parameters Values Found']:.2f}%\n") + file.write(f"Percent Parameters Keys Found: {metrics['Percent Parameters Keys Found']:.2f}%\n") + file.write(f"Average False Positives: {metrics['Average False Positives']}\n") + file.write( + f"Routes Found - Best: {metrics['Routes Best/Worst'][0]}, Worst: {metrics['Routes Best/Worst'][1]}\n") + file.write( + f"Additional Query Parameters Found - Best: {', '.join(map(str, metrics['Additional_Params Best/Worst']))}\n") + file.write(f"Additional Routes Found: {', '.join(map(str, metrics['Additional_routes Found']))}\n") + file.write(f"Missing Routes Found: {', '.join(map(str, metrics['Missing routes Found']))}\n") + + # Adding a summary or additional information + total_documented_routes = len(self.documented_routes) + total_additional_routes = len(metrics['Additional_routes Found']) + total_missing_routes = len(metrics['Missing routes Found']) + file.write("\nSummary:\n") + file.write(f"Total Params Found: {self.query_params_found}\n") + file.write(f"Total Documented Routes: {total_documented_routes}\n") + file.write(f"Total Additional Routes Found: {total_additional_routes}\n") 
+ file.write(f"Total Missing Routes: {total_missing_routes}\n") + file.write(f" Missing Parameters: {total_missing_routes}\n") + + # Optionally include a timestamp or other metadata + from datetime import datetime + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + file.write(f"Metrics generated on: {current_time}\n") diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py index 16b0dff1..7547b7f1 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List import openai +from instructor.exceptions import IncompleteOutputException from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model @@ -17,7 +18,7 @@ class LLMHandler: created_objects (Dict[str, List[Any]]): A dictionary to keep track of created objects by their type. """ - def __init__(self, llm: Any, capabilities: Dict[str, Any]) -> None: + def __init__(self, llm: Any, capabilities: Dict[str, Any], all_possible_capabilities= None) -> None: """ Initializes the LLMHandler with the specified LLM and capabilities. @@ -29,8 +30,14 @@ def __init__(self, llm: Any, capabilities: Dict[str, Any]) -> None: self._capabilities = capabilities self.created_objects: Dict[str, List[Any]] = {} self._re_word_boundaries = re.compile(r"\b") + self.adjusting_counter = 0 + self.all_possible_capabilities = all_possible_capabilities - def call_llm(self, prompt: List[Dict[str, Any]]) -> Any: + + def get_specific_capability(self, capability_name: str) -> Any: + return {f"{capability_name}": self.all_possible_capabilities[capability_name]} + + def execute_prompt(self, prompt: List[Dict[str, Any]]) -> Any: """ Calls the LLM with the specified prompt and retrieves the response. @@ -40,45 +47,183 @@ def call_llm(self, prompt: List[Dict[str, Any]]) -> Any: Returns: Any: The response from the LLM. """ - print(f"Initial prompt length: {len(prompt)}") def call_model(prompt: List[Dict[str, Any]]) -> Any: - """Helper function to avoid redundancy in making the API call.""" + """Helper function to make the API call with the adjusted prompt.""" + if isinstance(prompt, list): + if isinstance(prompt[0], list): + prompt = prompt[0] + return self.llm.instructor.chat.completions.create_with_completion( model=self.llm.model, messages=prompt, response_model=capabilities_to_action_model(self._capabilities), + #max_tokens=200 # adjust as needed + ) + + # Helper to adjust the prompt based on its length. 
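+        # Retry strategy used below: the conversation is first trimmed to the last
+        # 10 messages; if the call still fails with openai.BadRequestError or
+        # IncompleteOutputException, the prompt is shrunk further via adjust_prompt()
+        # and tool messages are re-validated; a final fallback retries with only the
+        # most recent message(s).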
+ + try: + if isinstance(prompt, list) and len(prompt) >= 10: + prompt = prompt[-10:] + if isinstance(prompt, str): + prompt = [prompt] + return call_model(prompt) + + except (openai.BadRequestError, IncompleteOutputException) as e: + + try: + # First adjustment attempt based on prompt length + self.adjusting_counter = 1 + if isinstance(prompt, list) and len(prompt) >= 5: + adjusted_prompt = self.adjust_prompt(prompt, num_prompts=1) + adjusted_prompt = self._ensure_that_tool_messages_are_correct(adjusted_prompt, prompt) + prompt= adjusted_prompt + if isinstance(prompt, str): + adjusted_prompt = [prompt] + prompt= adjusted_prompt + + + + return call_model(prompt) + + except (openai.BadRequestError, IncompleteOutputException) as e: + # Second adjustment based on token size if the first attempt fails + adjusted_prompt = self.adjust_prompt(prompt) + if isinstance(adjusted_prompt, str): + adjusted_prompt = [adjusted_prompt] + if adjusted_prompt == [] or adjusted_prompt == None: + adjusted_prompt = prompt[-1:] + if isinstance(adjusted_prompt, list): + if isinstance(adjusted_prompt[0], list): + adjusted_prompt = adjusted_prompt[0] + adjusted_prompt = self._ensure_that_tool_messages_are_correct(adjusted_prompt, prompt) + self.adjusting_counter = 2 + return call_model(adjusted_prompt) + + def execute_prompt_with_specific_capability(self, prompt: List[Dict[str, Any]], capability: Any) -> Any: + """ + Calls the LLM with the specified prompt and retrieves the response. + + Args: + prompt (List[Dict[str, Any]]): The prompt messages to send to the LLM. + + Returns: + Any: The response from the LLM. + """ + + def call_model(adjusted_prompt: List[Dict[str, Any]], capability: Any) -> Any: + """Helper function to make the API call with the adjusted prompt.""" + capability = self.get_specific_capability(capability) + + return self.llm.instructor.chat.completions.create_with_completion( + model=self.llm.model, + messages=adjusted_prompt, + response_model=capabilities_to_action_model(capability), + #max_tokens=1000 # adjust as needed ) + # Helper to adjust the prompt based on its length. 
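+        # adjust_prompt_based_on_length keeps roughly half of the messages when the
+        # prompt has 20 or more entries, otherwise roughly 70%, and drops back to the
+        # last 10 messages once a second adjustment has already been made
+        # (adjusting_counter == 2).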
+ def adjust_prompt_based_on_length(prompt: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + if self.adjusting_counter == 2: + num_prompts = 10 + self.adjusting_counter = 0 + else: + num_prompts = int( + len(prompt) - 0.5 * len(prompt) if len(prompt) >= 20 else len(prompt) - 0.3 * len(prompt)) + return self.adjust_prompt(prompt, num_prompts=num_prompts) + try: - if len(prompt) > 30: - return call_model(self.adjust_prompt(prompt, num_prompts=5)) + # First adjustment attempt based on prompt length + if len(prompt) >= 10: + prompt = prompt[-10:] + return call_model(prompt, capability) + + except (openai.BadRequestError, IncompleteOutputException) as e: - return call_model(self.adjust_prompt_based_on_token(prompt)) - except openai.BadRequestError as e: try: - print(f"Error: {str(e)} - Adjusting prompt size and retrying.") - # Reduce prompt size; removing elements and logging this adjustment - return call_model(self.adjust_prompt_based_on_token(self.adjust_prompt(prompt))) - except openai.BadRequestError as e: - new_prompt = self.adjust_prompt_based_on_token(self.adjust_prompt(prompt, num_prompts=2)) - print("New prompt:") - print(f"Len New prompt:{len(new_prompt)}") - - for prompt in new_prompt: - print(f"{prompt}") - return call_model(new_prompt) + # Second adjustment based on token size if the first attempt fails + adjusted_prompt = self.adjust_prompt(prompt) + adjusted_prompt = self._ensure_that_tool_messages_are_correct(adjusted_prompt, prompt) + + self.adjusting_counter = 2 + adjusted_prompt = call_model(adjusted_prompt, capability) + return adjusted_prompt + + except (openai.BadRequestError, IncompleteOutputException) as e: + + # Final fallback with the smallest prompt size + shortened_prompt = self.adjust_prompt(prompt) + shortened_prompt = self._ensure_that_tool_messages_are_correct(shortened_prompt, prompt) + if isinstance(shortened_prompt, list): + if isinstance(shortened_prompt[0], list): + shortened_prompt = shortened_prompt[0] + print(f'shortened_prompt;{shortened_prompt}') + return call_model(shortened_prompt, capability) def adjust_prompt(self, prompt: List[Dict[str, Any]], num_prompts: int = 5) -> List[Dict[str, Any]]: - adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2) : len(prompt)] - if not isinstance(adjusted_prompt[0], dict): - adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2) - 1 : len(prompt)] + """ + Adjusts the prompt list to contain exactly `num_prompts` items. - print(f"Adjusted prompt length: {len(adjusted_prompt)}") - print(f"adjusted prompt:{adjusted_prompt}") - return prompt + Args: + prompt (List[Dict[str, Any]]): The list of prompts to adjust. + num_prompts (int): The desired number of prompts. Defaults to 5. + + Returns: + List[Dict[str, Any]]: The adjusted list containing exactly `num_prompts` items. 
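+
+        Notes:
+            If the list has fewer than `num_prompts` items it is returned as-is.
+            Otherwise the last `num_prompts` messages are kept, trimmed to an even
+            count, re-cut if the window does not start with a dict message, and
+            finally validated via _ensure_that_tool_messages_are_correct().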
+ """ + # Ensure the number of prompts does not exceed the total available + if len(prompt) < num_prompts: + return prompt # Return all available if there are fewer prompts than requested + + # Limit to the last `num_prompts` items + # Ensure not to exceed the available prompts + adjusted_prompt = prompt[-num_prompts:] + adjusted_prompt = adjusted_prompt[:len(adjusted_prompt) - len(adjusted_prompt) % 2] + if adjusted_prompt == []: + return prompt + + # Ensure adjusted_prompt starts with a dict item + + if not isinstance(adjusted_prompt, str): + if not isinstance(adjusted_prompt[0], dict): + adjusted_prompt = prompt[len(prompt) - num_prompts - (len(prompt) % 2) - 1: len(prompt)] + + # If adjusted_prompt is None, fallback to the full prompt + if not adjusted_prompt: + adjusted_prompt = prompt - def add_created_object(self, created_object: Any, object_type: str) -> None: + # Ensure adjusted_prompt items are valid dicts and follow `tool` message constraints + validated_prompt = self._ensure_that_tool_messages_are_correct(adjusted_prompt, prompt) + + return validated_prompt + + def _ensure_that_tool_messages_are_correct(self, adjusted_prompt, prompt): + # Ensure adjusted_prompt items are valid dicts and follow `tool` message constraints + validated_prompt = [] + last_item = None + adjusted_prompt.reverse() + + for item in adjusted_prompt: + if isinstance(item, dict): + # Remove `tool` messages without a preceding `tool_calls` message + if item.get("role") == "tool" and (last_item is None or last_item.get("role") != "tool_calls"): + continue + + # Track valid items + validated_prompt.append(item) + last_item = item + + # Reverse back if `prompt` is not a string (just in case) + if not isinstance(validated_prompt, str): + validated_prompt.reverse() + if validated_prompt == []: + validated_prompt = [prompt[-1]] + if isinstance(validated_prompt, object): + validated_prompt = [validated_prompt] + return validated_prompt + + def _add_created_object(self, created_object: Any, object_type: str) -> None: """ Adds a created object to the dictionary of created objects, categorized by object type. @@ -91,34 +236,57 @@ def add_created_object(self, created_object: Any, object_type: str) -> None: if len(self.created_objects[object_type]) < 7: self.created_objects[object_type].append(created_object) - def get_created_objects(self) -> Dict[str, List[Any]]: + def _get_created_objects(self) -> Dict[str, List[Any]]: """ Retrieves the dictionary of created objects and prints its contents. Returns: Dict[str, List[Any]]: The dictionary of created objects. 
""" - print(f"created_objects: {self.created_objects}") return self.created_objects def adjust_prompt_based_on_token(self, prompt: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - prompt.reverse() + if not isinstance(prompt, str): + prompt.reverse() + + last_item = None tokens = 0 - max_tokens = 10000 + max_tokens = 100 + last_action = "" + removed_item = 0 for item in prompt: if tokens > max_tokens: - prompt.remove(item) + if not isinstance(last_item, dict): + prompt.remove(item) + else: + prompt.remove(item) + last_action = "remove" + removed_item = removed_item + 1 else: + + if last_action == "remove": + if isinstance(last_item, dict) and last_item.get('role') == 'tool': + prompt.remove(item) + last_action = "" if isinstance(item, dict): new_token_count = tokens + self.get_num_tokens(item["content"]) - if new_token_count <= max_tokens: - tokens = new_token_count + tokens = new_token_count else: - continue + new_token_count = tokens + 100 + tokens = new_token_count + + last_item = item - print(f"tokens:{tokens}") - prompt.reverse() + if removed_item == 0: + counter = 5 + for item in prompt: + prompt.remove(item) + counter = counter + 1 + if not isinstance(prompt, str): + prompt.reverse() return prompt def get_num_tokens(self, content: str) -> int: + if not isinstance(content, str): + content = str(content) return len(self._re_word_boundaries.findall(content)) >> 1 diff --git a/src/hackingBuddyGPT/utils/local_shell/__init__.py b/src/hackingBuddyGPT/utils/local_shell/__init__.py new file mode 100644 index 00000000..93e07699 --- /dev/null +++ b/src/hackingBuddyGPT/utils/local_shell/__init__.py @@ -0,0 +1,3 @@ +from .local_shell import LocalShellConnection + +__all__ = ["LocalShellConnection"] diff --git a/src/hackingBuddyGPT/utils/local_shell/local_shell.py b/src/hackingBuddyGPT/utils/local_shell/local_shell.py new file mode 100755 index 00000000..0ecf913c --- /dev/null +++ b/src/hackingBuddyGPT/utils/local_shell/local_shell.py @@ -0,0 +1,335 @@ +from dataclasses import dataclass, field +from typing import Optional, Tuple +import time +import uuid +import subprocess +import re +import signal +import getpass + +from hackingBuddyGPT.utils.configurable import configurable + +@configurable("local_shell", "attaches to a running local shell inside tmux using tmux") +@dataclass +class LocalShellConnection: + tmux_session: str = field(metadata={"help": "tmux session name of the running shell inside tmux"}) + delay: float = field(default=0.5, metadata={"help": "delay between commands"}) + max_wait: int = field(default=300, metadata={"help": "maximum wait time for command completion"}) + + # Static attributes for connection info + username: str = field(default_factory=getpass.getuser, metadata={"help": "username for the connection"}) + password: str = field(default="", metadata={"help": "password for the connection"}) + host: str = field(default="localhost", metadata={"help": "host for the connection"}) + hostname: str = field(default="localhost", metadata={"help": "hostname for the connection"}) + port: Optional[int] = field(default=None, metadata={"help": "port for the connection"}) + keyfilename: str = field(default="", metadata={"help": "key filename for the connection"}) + + # Internal state + last_output_hash: Optional[int] = field(default=None, init=False) + _initialized: bool = field(default=False, init=False) + + def init(self): + if not self.check_session(): + raise RuntimeError(f"Tmux session '{self.tmux_session}' does not exist. 
Please create it first or use an existing session name.") + else: + print(f"Connected to existing tmux session: {self.tmux_session}") + self._initialized = True + + def new_with(self, *, tmux_session=None, delay=None, max_wait=None) -> "LocalShellConnection": + return LocalShellConnection( + tmux_session=tmux_session or self.tmux_session, + delay=delay or self.delay, + max_wait=max_wait or self.max_wait, + ) + + def run(self, cmd, *args, **kwargs) -> Tuple[str, str, int]: + """ + Run a command and return (stdout, stderr, return_code). + This is the main interface method that matches the project pattern. + """ + if not self._initialized: + self.init() + + if not cmd.strip(): + return "", "", 0 + + try: + output = self.run_with_unique_markers(cmd) + + return output, "", 0 + except Exception as e: + return "", str(e), 1 + + def send_command(self, command): + """Send a command to the tmux session.""" + try: + subprocess.run(['tmux', 'send-keys', '-t', self.tmux_session, command, 'Enter'], check=True) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to send command to tmux: {e}") + + def capture_output(self, history_lines=10000): + """Capture the entire tmux pane content including scrollback.""" + try: + # Capture with history to get more content + result = subprocess.run( + ['tmux', 'capture-pane', '-t', self.tmux_session, '-p', '-S', f'-{history_lines}'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True + ) + return result.stdout + except subprocess.CalledProcessError as e: + raise RuntimeError(f"Failed to capture tmux output: {e}") + + def get_cursor_position(self): + """Get cursor position to detect if command is still running.""" + try: + result = subprocess.run( + ['tmux', 'display-message', '-t', self.tmux_session, '-p', '#{cursor_x},#{cursor_y}'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True + ) + return result.stdout.strip() + except subprocess.CalledProcessError: + return None + + def wait_for_command_completion(self, timeout=None, check_interval=0.5): + """ + Advanced method to wait for command completion using multiple indicators. 
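+
+        Completion is inferred heuristically: the pane content (hashed) and the
+        cursor position must remain unchanged for roughly 1.5 seconds, after which
+        the last pane line is checked against common shell prompt patterns
+        (see _has_prompt_at_end). Returns True when completion is detected and
+        False once `timeout` (default: self.max_wait) is exceeded.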
+ """ + if timeout is None: + timeout = self.max_wait + + start_time = time.time() + last_output_hash = None + last_cursor_pos = None + stable_count = 0 + min_stable_time = 1.5 # Reduced for faster detection + + while time.time() - start_time < timeout: + # Use hash for large outputs to detect changes more efficiently + current_output = self.capture_output(1000) # Smaller buffer for speed + current_output_hash = hash(current_output) + current_cursor = self.get_cursor_position() + + # Check if output and cursor position are stable + if (current_output_hash == last_output_hash and + current_cursor == last_cursor_pos and + current_cursor is not None): + stable_count += 1 + + # If stable for enough cycles, check for prompt + if stable_count >= (min_stable_time / check_interval): + if self._has_prompt_at_end(current_output): + return True + else: + stable_count = 0 + + last_output_hash = current_output_hash + last_cursor_pos = current_cursor + + time.sleep(check_interval) + + return False + + def _has_prompt_at_end(self, output): + if not output.strip(): + return False + + lines = output.strip().split('\n') + if not lines: + return False + + last_line = lines[-1].strip() + + prompt_patterns = [ + r'.*[$#]\s*$', # Basic $ or # prompts + r'.*>\s*$', # > prompts + r'.*@.*:.*[$#]\s*$', # user@host:path$ format + r'.*@.*:.*>\s*$', # user@host:path> format + r'^\S+:\S*[$#]\s*$', # Simple host:path$ format + r'.*\$\s*$', # Ends with $ (catch-all) + r'.*#\s*$', # Ends with # (catch-all) + ] + + for pattern in prompt_patterns: + if re.match(pattern, last_line): + return True + + if len(last_line) < 100 and any(char in last_line for char in ['$', '#', '>', ':']): + if not any(keyword in last_line.lower() for keyword in + ['error', 'warning', 'failed', 'success', 'completed', 'finished']): + return True + + return False + + def run_with_unique_markers(self, command): + """Run command using unique markers - improved version for large outputs.""" + start_marker = f"CMDSTART{uuid.uuid4().hex[:8]}" + end_marker = f"CMDEND{uuid.uuid4().hex[:8]}" + + try: + self.send_command(f"echo '{start_marker}'") + time.sleep(0.5) + + self.send_command(command) + + if not self.wait_for_command_completion(): + raise RuntimeError(f"Command timed out after {self.max_wait}s") + + self.send_command(f"echo '{end_marker}'") + time.sleep(0.8) + + final_output = self.capture_output(50000) + + # Extract content between markers + result = self._extract_between_markers(final_output, start_marker, end_marker, command) + return result + + except Exception as e: + return self.run_simple_fallback(command) + + def _extract_between_markers(self, output, start_marker, end_marker, original_command): + lines = output.splitlines() + start_idx = -1 + end_idx = -1 + + for i, line in enumerate(lines): + if start_marker in line: + start_idx = i + elif end_marker in line and start_idx != -1: + end_idx = i + break + + if start_idx == -1 or end_idx == -1: + return self.run_simple_fallback(original_command) + + extracted_lines = [] + for i in range(start_idx + 1, end_idx): + line = lines[i] + if not self._is_command_echo(line, original_command): + extracted_lines.append(line) + + return '\n'.join(extracted_lines).strip() + + def _is_command_echo(self, line, command): + stripped = line.strip() + if not stripped: + return False + + for prompt_char in ['$', '#', '>']: + if prompt_char in stripped: + after_prompt = stripped.split(prompt_char, 1)[-1].strip() + if after_prompt == command: + return True + + return stripped == command + + def 
run_simple_fallback(self, command): + try: + subprocess.run(['tmux', 'set-option', '-t', self.tmux_session, 'history-limit', '50000'], + capture_output=True) + + clear_marker = f"__CLEAR_{uuid.uuid4().hex[:8]}__" + self.send_command('clear') + time.sleep(0.3) + self.send_command(f'echo "{clear_marker}"') + time.sleep(0.3) + + self.send_command(command) + + self.wait_for_command_completion() + + end_marker = f"__END_{uuid.uuid4().hex[:8]}__" + self.send_command(f'echo "{end_marker}"') + time.sleep(0.5) + + output = self.capture_output(50000) + + lines = output.splitlines() + start_idx = -1 + end_idx = -1 + + for i, line in enumerate(lines): + if clear_marker in line: + start_idx = i + elif end_marker in line and start_idx != -1: + end_idx = i + break + + if start_idx != -1 and end_idx != -1: + result_lines = lines[start_idx + 1:end_idx] + if result_lines and command in result_lines[0]: + result_lines = result_lines[1:] + result = '\n'.join(result_lines).strip() + else: + result = self._extract_recent_output(output, command) + + subprocess.run(['tmux', 'set-option', '-t', self.tmux_session, 'history-limit', '10000'], + capture_output=True) + + return result + + except Exception as e: + subprocess.run(['tmux', 'set-option', '-t', self.tmux_session, 'history-limit', '10000'], + capture_output=True) + raise RuntimeError(f"Error executing command: {e}") + + def _extract_recent_output(self, output, command): + lines = output.splitlines() + + for i in range(len(lines) - 1, -1, -1): + line = lines[i] + if command in line and any(prompt in line for prompt in ['$', '#', '>', '└─']): + return '\n'.join(lines[i + 1:]).strip() + + return '\n'.join(lines[-50:]).strip() if lines else "" + + def run_with_timeout(self, command, timeout=60): + old_max_wait = self.max_wait + self.max_wait = timeout + try: + return self.run(command) + finally: + self.max_wait = old_max_wait + + def interrupt_command(self): + try: + subprocess.run(['tmux', 'send-keys', '-t', self.tmux_session, 'C-c'], check=True) + time.sleep(1) + return True + except subprocess.CalledProcessError: + return False + + def check_session(self): + try: + result = subprocess.run( + ['tmux', 'list-sessions', '-F', '#{session_name}'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True + ) + + session_names = result.stdout.strip().split('\n') + return self.tmux_session in session_names + + except subprocess.CalledProcessError: + return False + + def get_session_info(self): + try: + result = subprocess.run( + ['tmux', 'display-message', '-t', self.tmux_session, '-p', + 'Session: #{session_name}, Window: #{window_name}, Pane: #{pane_index}'], + stdout=subprocess.PIPE, + text=True, + check=True + ) + return result.stdout.strip() + except subprocess.CalledProcessError: + return "Session info unavailable" + diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/__init__.py b/src/hackingBuddyGPT/utils/prompt_generation/__init__.py similarity index 100% rename from src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/__init__.py rename to src/hackingBuddyGPT/utils/prompt_generation/__init__.py diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/__init__.py b/src/hackingBuddyGPT/utils/prompt_generation/information/__init__.py similarity index 100% rename from src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/__init__.py rename to src/hackingBuddyGPT/utils/prompt_generation/information/__init__.py diff --git 
a/src/hackingBuddyGPT/utils/prompt_generation/information/pentesting_information.py b/src/hackingBuddyGPT/utils/prompt_generation/information/pentesting_information.py new file mode 100644 index 00000000..695bc745 --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/information/pentesting_information.py @@ -0,0 +1,3377 @@ +import base64 +import copy +import glob +import os +import random +import re +import secrets +from typing import List + +import pandas + +from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PromptPurpose, +) +from faker import Faker + + +class PenTestingInformation: + def __init__(self, openapi_spec_parser: OpenAPISpecificationParser, config) -> None: + """ + Initializes the PenTestingInformation with optional authentication credentials. + + Args: + openapi_spec_parser (OpenAPISpecificationParser): An instance of OpenAPISpecificationParser. + username (str, optional): Username for authentication, if necessary. Defaults to an empty string. + password (str, optional): Password for authentication, if necessary. Defaults to an empty string. + """ + + # Set basic authentication details + if "admin" in config: + self.admin = config["admin"] + else: + self.admin = None + self.guest = None + self.credentials = {} + self.resources = {} + self.valid_token = None + self.current_post_endpoint = None # TODO + self.faker = Faker() + self.username = self.faker.email().lower() + self.password = self.faker.password() + self.available_numbers = [] + self.config = config + file = self.get_file(self.config.get("csv_file")) + if file == "Not found": + self.df = pandas.DataFrame() + else: + self.df = pandas.read_csv(file[0], names=["username", "password"]) + + # Parse endpoints and their categorization from the given parser instance + categorized_endpoints = openapi_spec_parser.classify_endpoints(self.config.get("name")) + + # Assign schema and endpoint attributes directly from the parser methods + self.schemas = openapi_spec_parser.get_schemas() + self.endpoints = openapi_spec_parser.get_endpoints() + self.openapi_spec_parser = openapi_spec_parser + self.get_comment_ep() + + # Assign categorized endpoint types to attributes + self.assign_endpoint_categories(categorized_endpoints) + self.accounts = [] + opt_endpoints = [ep for ep in self.endpoints if "otp" in ep] + + self.brute_force_accounts = [] + if self.admin is not None: + admin = self.config.get("admin").get("email") + self.assign_brute_force_endpoints(admin) + + self.pentesting_step_list = [PromptPurpose.SETUP, + PromptPurpose.VERIY_SETUP, + PromptPurpose.AUTHENTICATION, + PromptPurpose.AUTHORIZATION, + PromptPurpose.SPECIAL_AUTHENTICATION, + PromptPurpose.INPUT_VALIDATION, + PromptPurpose.ERROR_HANDLING_INFORMATION_LEAKAGE, + PromptPurpose.SESSION_MANAGEMENT, + PromptPurpose.CROSS_SITE_SCRIPTING, + PromptPurpose.CROSS_SITE_FORGERY, + PromptPurpose.BUSINESS_LOGIC_VULNERABILITIES, + PromptPurpose.RATE_LIMITING_THROTTLING, + PromptPurpose.SECURITY_MISCONFIGURATIONS, + PromptPurpose.LOGGING_MONITORING + ] + + def assign_endpoint_categories(self, categorized_endpoints): + """ + Assign categorized endpoint types to instance attributes from given categorized endpoints dictionary. + + Args: + categorized_endpoints (dict): A dictionary containing categorized endpoints. 
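+                Expected keys include, for example, 'resource_intensive_endpoint',
+                'secure_action_endpoint', 'role_access_endpoint',
+                'sensitive_data_endpoint', 'sensitive_action_endpoint',
+                'login_endpoint', 'account_creation', 'auth_endpoint',
+                'public_endpoint', 'protected_endpoint' and 'refresh_endpoint';
+                keys that are missing simply leave the corresponding attribute None
+                (or an empty iterator for the iterator-backed categories).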
+ """ + self.resource_intensive_endpoint = categorized_endpoints.get('resource_intensive_endpoint') + + self.secure_action_endpoint = categorized_endpoints.get('secure_action_endpoint') + self.role_access_endpoint = categorized_endpoints.get('role_access_endpoint') + self.sensitive_data_endpoint = categorized_endpoints.get('sensitive_data_endpoint') + self.sensitive_action_endpoint = categorized_endpoints.get('sensitive_action_endpoint') + self.login_endpoint = categorized_endpoints.get('login_endpoint') + self.account_endpoint = categorized_endpoints.get('account_creation') + self.auth_endpoint = categorized_endpoints.get('auth_endpoint') + self.generate_iter_and_assign_current_endpoints(categorized_endpoints) + self.analysis_step_list = [PromptPurpose.ANALYSIS, PromptPurpose.DOCUMENTATION, + PromptPurpose.REPORTING] + self.categorized_endpoints = categorized_endpoints + self.tokens = {} + self.counter = 0 + + def set_valid_token(self, token: str) -> None: + self.valid_token = token + + def generate_iter_and_assign_current_endpoints(self, categorized_endpoints): + for key in ['public_endpoint', 'protected_endpoint', 'refresh_endpoint']: + endpoint_list = categorized_endpoints.get(key, []) + if endpoint_list: + setattr(self, f"{key}_iterator", iter(endpoint_list)) + setattr(self, f"current_{key}", next(getattr(self, f"{key}_iterator"), None)) + else: + setattr(self, f"{key}_iterator", iter([])) + setattr(self, f"current_{key}", None) + + def explore_steps(self, purpose: PromptPurpose) -> List[str]: + """ + Provides initial penetration testing steps for the given purpose. + + Args: + purpose (PromptPurpose): The purpose for which testing steps are required. + + Returns: + list: A list of steps corresponding to the specified purpose. + """ + # Map purposes to their corresponding methods + purpose_methods = { + PromptPurpose.SETUP: self.setup_test, + PromptPurpose.VERIY_SETUP: self.verify_setup, + PromptPurpose.AUTHENTICATION: self.generate_authentication_prompts, + PromptPurpose.AUTHORIZATION: self.generate_authorization_prompts, + PromptPurpose.SPECIAL_AUTHENTICATION: self.generate_special_authentication, + PromptPurpose.INPUT_VALIDATION: self.generate_input_validation_prompts, + PromptPurpose.ERROR_HANDLING_INFORMATION_LEAKAGE: self.generate_error_handling_prompts, + PromptPurpose.SESSION_MANAGEMENT: self.generate_session_management_prompts, + PromptPurpose.CROSS_SITE_SCRIPTING: self.generate_xss_prompts, + PromptPurpose.CROSS_SITE_FORGERY: self.generate_csrf_prompts, + PromptPurpose.BUSINESS_LOGIC_VULNERABILITIES: self.generate_business_logic_vul_prompts, + PromptPurpose.RATE_LIMITING_THROTTLING: self.generate_rate_limit_throttling, + PromptPurpose.SECURITY_MISCONFIGURATIONS: self.generate_security_misconfiguration_prompts, + PromptPurpose.LOGGING_MONITORING: self.generate_logging_monitoring_prompts + } + + # Call the appropriate method based on the purpose + if purpose in purpose_methods: + return purpose_methods[purpose]() + else: + raise ValueError(f"Invalid purpose: {purpose}") + + def get_analysis_step(self, purpose: PromptPurpose = None, response: str = "", additional_context: str = "") -> str: + """ + Provides prompts for analysis based on the provided response for various purposes using an LLM. + + Args: + response (str, optional): The HTTP response to analyze. Default is an empty string. + + Returns: + dict: A dictionary where each key is a PromptPurpose and each value is a list of prompts. 
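+
+        Example (illustrative):
+            >>> step = pentesting_info.get_analysis_step(
+            ...     PromptPurpose.ANALYSIS, response=parsed_response)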
+ """ + if purpose == PromptPurpose.ANALYSIS: + return f"Given the following parsed HTTP response:\n{response}\n" \ + f"Based on this context: {additional_context}\n" \ + "Analyze this response to determine in form of a RecordNote:\n" \ + "1. Whether the status code is appropriate for this type of request.\n" \ + "2. If the headers indicate proper security and rate-limiting practices. \n" \ + "3. If the headers include sensitive information like 'Server', 'X-Powered', 'X-Frame-Options', 'Cache-Control', 'Strict-Transport-Security', 'Set-Cookie', 'X-Request-ID', 'Accept-Encoding', 'Referer', and 'X-API-Version' ALso add why this can cause a vulnerability.\n" \ + "4. Whether the response body is correctly handled." + # "Keep your analysis short." + + if purpose == PromptPurpose.DOCUMENTATION: + return f"Based on the analysis provided, document the findings in form of a RecordNote:\n{response}." + # f" Keep your analysis short." + + if purpose == PromptPurpose.REPORTING: + return (f"Based on the documented findings : {response}.\n" + f"Suggest any improvements or issues that should be reported based on the findings to the API developers in form of a RecordNote.") + # f"Keep your analysis short." + + def get_steps_of_phase(self, purpose): + steps = self.explore_steps(purpose) + return steps + + def next_testing_endpoint(self): + self.current_public_endpoint = next(self.public_endpoint_iterator, None) + self.current_protected_endpoint = next(self.protected_endpoint_iterator, None) + self.current_refresh_endpoint = next(self.refresh_endpoint_iterator, None) + + def setup_test(self) -> List: + prompts = [] + counter = 0 + post_account = self.get_correct_endpoints_for_method("account_creation", "POST") + prompts, counter = self.generate_user(post_account, counter, prompts) + if len(self.accounts) == 1: # ensure that there are at least two users + prompts, counter = self.generate_user(post_account, 1, prompts) + + return prompts + + def verify_setup(self) -> List: + prompts = [] + + get_account = self.get_correct_endpoints_for_method("public_endpoint", + "GET") + self.get_correct_endpoints_for_method( + "protected_endpoint", "GET") + + get_account = [ep for ep in get_account if ep.get("path").endswith("user") or ep.get("path").endswith("login")] + if len(get_account) == 0: + get_account = [ + self.endpoints[path]for path, methods in self.endpoints.items() + if path.endswith("/users/{id}") and "get" in methods + ] + + for acc in get_account: + for account in self.accounts: + account_path = acc.get("path") + account_schema = acc.get("schema") + if "api" in account_path: + if account["api"] in account_path: + if "user" and "id" in account_path: + account_path = account_path.replace("{id}", str(account.get("id"))) + prompts = prompts + [{ + "objective": "Check if user was created", + "steps": [ + f"Endpoint to use : {account_path}\n" + f"Send a GET request to the {account_path} with the with the correct schema {account_schema} with user:{account}.\n" + ], + "path": [account_path], + "token": [account.get("token")], + "expected_response_code": ["200 OK", "201 Created"], + "security": [ + f"Ensure that the returned user matches this user {account}"] + }] + else: + if "id}" in account_path: + + if isinstance(account.get("example"), dict): + if "example" in account.keys(): + if "id" in account.get("example").keys(): + account_path = account_path.replace("{id}", + str(account_schema.get("example").get("id"))) + else: + account_path = account_path.replace("{id}", str(account_schema.get("example"))) + else: + 
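+                            # No usable example object on this account, so fall back to
+                            # substituting the account's own id into the path template.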
account_path = self.replace_placeholders_with_1(account_path, account.get("id")) + + if account_schema: + if "Authorization-Token" in account_schema.values(): + if "example" in account.keys() and "id" in account.get("example") and account.get( + "example").get("id") not in self.tokens.keys(): + description = account_schema.get("description") \ + .replace("username", account.get("example").get("username")) \ + .replace("password", account.get("example").get("password")) + account_schema["description"] = description + credentials = account.get("example").get("username") + ":" + account.get("example").get( + "password") + self.tokens[account.get("example").get("id")] = base64.b64encode( + credentials.encode('utf-8')).decode('utf-8') + + prompts = prompts + [{ + "objective": "Check if user was created", + "steps": [ + f"Endpoint to use : {account_path}\n" + f"Send a GET request to the {account_path} with the with the correct schema {account_schema} with user:{account}.\n" + ], + "path": [account_path], + "token": [account.get("token")], + "expected_response_code": ["200 OK", "201 Created"], + "security": [ + f"Ensure that the returned user matches this user {account}"] + }] + for login in self.login_endpoint: + for account in self.accounts: + + login_path, login_schema = self.get_path_and_schema(login) + if login_schema is None: + continue + login_schema = self.fill_in_schema_with_account(login_schema, account) + + if "api" in account.keys(): + if account["api"] in login_path: + prompts = prompts + [ + { + "objective": "Get Valid token", + "steps": [ + f"Endpoint to use : {login_path}\n" + f"Send a POST request to the {login_path} with the correct credentials of user:{account}.\n" + f"Request body should be in application/json and look similar to this: {{ {login_schema}}}" + ], + "path": [login_path], + "token": [account.get("token")], + "expected_response_code": ["200 OK", "201 Created"], + "security": [ + "Ensure that the correct expected code is received confirming that access is restricted and no sensitive data is included in the response body."] + } + ] + else: + + if account is None: + account = self.create_account(login_schema, login_path) + if account is None: + continue + + prompts = prompts + [ + { + "objective": "Get Valid token", + "steps": [ + f"Endpoint to use : {login_path}\n" + f"Send a POST request to the {login_schema} with the correct credentials of user:{account}.\n" + f"Request body should be in application/json and look similar to this: {{ {account}}}" + ], + "path": [login_path], + "token": [""], + "expected_response_code": ["200 OK", "201 Created"], + "security": [ + "Ensure that the correct expected code is received confirming that access is restricted and no sensitive data is included in the response body."] + } + ] + + account = None + if len(prompts) == 0: + for login in self.login_endpoint: + login_path, login_schema = self.get_path_and_schema(login) + if login_schema is None: + continue + if account is None: + account = self.create_account(login_schema, login_path) + if account is None: + continue + + prompts = prompts + [ + { + "objective": "Get Valid token", + "steps": [ + f"Endpoint to use : {login_path}\n" + f"Send a POST request to the {login_schema} with the correct credentials of user:{account}.\n" + f"Request body should be in application/json and look similar to this: {{ {login_schema}}}" + ], + "path": [login_path], + "token": [""], + "expected_response_code": ["200 OK", "201 Created"], + "security": [ + "Ensure that the correct expected code is received 
confirming that access is restricted and no sensitive data is included in the response body."] + } + ] + + return prompts + + def generate_request_body_string(self, schema, endpoint): + """ + Generate a request body string based on the updated schema. + + Args: + schema (dict): A schema dictionary containing an example. + username (str): The username to populate in the example. + password (str): The password to populate in the example. + + Returns: + str: A formatted request body string. + """ + updated_schema = self.get_credentials(schema, endpoint) + example = updated_schema.get("example", {}) + + # Generate key-value pairs from the schema example dynamically + key_value_pairs = [f"'{key}': '{value}'" for key, value in example.items() if value != ""] + return key_value_pairs + + def replace_placeholders_with_1(self, path: str, id) -> str: + """ + Replaces any curly-brace placeholders (e.g., '{videoid}', '{orderid}', '{someid}') + with the number '1' in the given path. + + Example: + "/identity/api/v2/user/videos/{videoid}" -> "/identity/api/v2/user/videos/1" + "/workshop/api/shop/orders/{orderid}" -> "/workshop/api/shop/orders/1" + "{somethingid}" -> "1" + """ + + def substitute(match): + # Extract the placeholder from the match + placeholder = match.group(0).strip('{}') + # Return the replacement for the placeholder if it exists + return id.get(placeholder, match.group(0)) + + # Regex to match anything in curly braces + + # Regex to match anything in curly braces, e.g. {videoid}, {postid}, etc. + + if id is None: + return path + if isinstance(id, int): + id = str(id) + + return re.sub(r"\{[^}]+\}", id, path) + + def generate_authentication_prompts(self): + """ + Generate a list of prompts for testing authentication mechanisms on protected endpoints. + + This function constructs test prompts for various authentication scenarios, including: + - Accessing protected endpoints with different user accounts. + - Using login credentials to acquire tokens. + - Testing endpoints that require path parameters like user IDs. + - Verifying refresh token mechanisms if applicable. + + Returns: + list: A list of prompts for testing authentication and authorization. 
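+
+        Path parameters such as '{id}' are filled in with the test account's id
+        (or a literal '1') before the prompts are built.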
+ """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "GET") + + prompts = self.resource_endpoints(prompts) + + if len(endpoints) != 0: + for endpoint, login in zip(endpoints, self.login_endpoint): + for account in self.accounts: + endpoint_dict = endpoint + if isinstance(endpoint, dict): + endpoint = endpoint.get("path") + + login_path, login_schema = self.get_path_and_schema(login) + if login_schema is None: + continue + login_schema = self.fill_in_schema_with_account(login_schema, account) + + if "api" in endpoint and len(endpoint.split("/")) > 0 and "api" in account: + if account["api"] in endpoint: + prompts = self.test_authentication(endpoint, account, prompts) + else: + prompts = self.test_authentication(endpoint, account, prompts) + if "_id}" in endpoint: + endpoint = self.replace_id_placeholder(endpoint, "1") + + if login_path: + + if "api" in endpoint and len(endpoint.split("/")) > 0: + if account["api"] in endpoint: + id = account.get("id") + if id and "{id}" in endpoint: + new_endpoint = endpoint.replace("{id}", str(account.get("id"))) + prompts = self.test_token(login_path, new_endpoint, account, login_schema, prompts) + prompts = self.random_common_users(new_endpoint, login_path, login_schema, prompts) + else: + + prompts = self.test_token(login_path, endpoint, account, login_schema, prompts) + prompts = self.random_common_users(endpoint, login_path, login_schema, prompts) + + else: + if "id}" in endpoint: + endpoint = self.replace_placeholders_with_1(endpoint, f"{account.get('id')}") + prompts = self.random_common_users(endpoint, login_path, login_schema, prompts) + prompts = self.test_token(login_path, endpoint, account, login_schema, prompts) + + if self.current_refresh_endpoint: + refresh_get_endpoints = self.get_correct_endpoints_for_method("refresh_endpoint", "GET") + refresh_post_endpoints = self.get_correct_endpoints_for_method("refresh_endpoint", "POST") + if len(refresh_get_endpoints) != 0 and refresh_post_endpoints: + for account in self.accounts: + + for refresh_get_endpoint, refresh_post_endpoint in zip(refresh_get_endpoints, + refresh_post_endpoints): + if "id}" in refresh_get_endpoint: + refresh_get_endpoint = self.replace_placeholders_with_1(refresh_get_endpoint, + account.get("id")) + if account["api"] in refresh_get_endpoint: + prompts = self.test_refresh_token(refresh_post_endpoint, refresh_get_endpoint, + account, prompts) + else: + prompts = self.test_refresh_token(refresh_post_endpoint, refresh_get_endpoint, + account, prompts) + + return prompts + + def generate_authorization_prompts(self): + """ + Generate prompts to test authorization and role-based access control (RBAC) + on protected endpoints. + + This method performs a series of authorization tests by: + - Generating prompts for GET requests to protected endpoints. + - Including user-specific endpoints like "users". + - Excluding endpoints related to "community". + - Testing RBAC behavior across different roles (admin, user, guest). + - Verifying proper handling of user-specific data and ID placeholders. + - Checking for data masking mechanisms. + - Running CRUD operation tests on relevant endpoints. + + Returns: + list: A list of authorization prompts for validating access control. 
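+
+        The three-role check (admin / user / guest) is only emitted when both
+        self.admin and self.guest are configured; otherwise the generic rbac()
+        prompts are used instead.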
+ """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "GET") + + users_ep = [ep + "s" for ep in self.endpoints if ep.endswith("user")] + endpoints = endpoints + users_ep + endpoints = [ep for ep in endpoints if "community" not in ep] + if len(endpoints) != 0: + for endpoint in endpoints: + for account in self.accounts: + if isinstance(endpoint, dict): + endpoint = endpoint.get("path") + if "api" in endpoint and len( + endpoint.split("/")) > 0 and "id" in endpoint and not "identity" in endpoint: + + if "api" in account and account["api"] in endpoint: + + prompts = self.resource_prompts(endpoint, account, prompts) + + prompts = self.mechanic_report(endpoint, account, prompts) + if "{id}" in endpoint: + id = account.get("id") + prompts = self.rbac(endpoint, account, prompts, id=id) + else: + + prompts = self.rbac(endpoint, account, prompts) + + # Verify Data Masking + prompts = self.verify_data_masking(endpoint, account, prompts) + + + else: + + prompts = self.resource_prompts(endpoint, account, prompts) + + prompts = self.mechanic_report(endpoint, account, prompts) + if "_id}" in endpoint: + if "id" in account: + id = account.get("id") + else: + id = 1 + endpoint = self.replace_id_placeholder(endpoint, str(id)) + + if self.admin and self.guest: + prompts.append( + + # Verify Role-Based Access Control (RBAC) + + { + "objective": "Verify Role-Based Access Control (RBAC)", + "steps": [ + f"send a GET request to {endpoint} with an 'admin' role authentication token in the Authorization header. Expect a 200 OK response, indicating access to all resources. Verify that the response data includes complete access to the endpoint, reflecting the admin’s permissions.user:{account}.\n", + f"send a GET request to {endpoint} with a 'user' role authentication token in the Authorization header. Expect a 200 OK response, but confirm that only authorized resources or data for a regular user are included in the response. Verify that any restricted admin-only resources are not accessible to the user role.user:{account}.\n", + f"send a GET request to {endpoint} with a 'guest' role authentication token (or without an authentication token if guests are unauthenticated). Expect either a 403 Forbidden or a 404 Not Found response if guests are not permitted to access the endpoint, or a 200 OK with limited data if some resources are publicly accessible to guests.user:{account}.\n" + ], + "path": [endpoint, endpoint, endpoint], + "expected_response_code": [ + "200 OK for admin, confirming full access.", + "200 OK for users, confirming access is limited to non-admin resources.", + "403 Forbidden or 404 Not Found for guests if no access is granted, or 200 OK if limited access is applicable." + ], + "token": [self.admin.get("token"), account.get("token"), + self.guest.get("token")], + "security": [ + "Verify that the response data includes complete access to the endpoint, reflecting the admin’s permissions.", + "Verify that any restricted admin-only resources are not accessible to the user role.", + "Verify that guest role has no or limited access."], + + } + + # This prompt tests the RBAC implementation by confirming that each user role has the appropriate access level: + # - Admin has full access + # - User has restricted access + # - Guest has no or limited access. 
+ + ) + + else: + prompts = self.rbac(endpoint, account, prompts) + + # Verify Data Masking + prompts = self.verify_data_masking(endpoint, account, prompts) + + prompts = self.test_crud(endpoints, prompts) + return prompts + + def key_in_path(self, path, dictionary): + # Split the path into components + path_parts = path.split('/') + + # Check each part of the path to see if it is a key in the dictionary + for part in path_parts: + part = part.replace("{", "") + part = part.replace("}", "") + + if part in dictionary: + return True, part # Return True and the matching part if found + + return False, None # Return False and None if no part matches + + def generate_special_authentication(self): + """ + Generate prompts for advanced authentication testing. + + This method focuses on testing the security robustness of login and authentication mechanisms. + It includes: + - Injection tests with special accounts (e.g., "bender" accounts). + - Brute force protection validation using repeated login attempts. + - CSS (Content Sniffing) vulnerability checks on authentication endpoints. + - Token revocation checks after login. + - ID substitution for user-specific paths. + - Custom test cases for comments and advanced login workflows. + + Returns: + list: A list of prompts targeting special authentication test cases. + """ + prompts = [] + + self.counter = self.counter + 1 + + for login in self.login_endpoint: + + login_path, login_schema = self.get_path_and_schema(login) + + prompts = self.test_injection_bender_account(login_path, prompts) + + if login_path: + prompts = self.test_brute_force(login_path, login_schema, prompts) + + for acc in self.accounts: + + if "api" in login_path and "api" in acc.keys(): + if acc["api"] in login_path: + str_id = f"{acc.get('id')}" + login_path = login_path.replace("{id}", str_id) + + str = [ + f"Attempt Login with Invalid Credentials - Send repeated POST requests to the login endpoint {login_path} using invalid credentials (e.g., incorrect password) {self.get_credentials(login_schema, login_path)}. " + + f"Simulate a high rate of attempts (e.g., 10 requests within 1 minute) to trigger potential brute force protection mechanisms. "] * 10 + + if login_path: + prompts.append( # Brute Force Protection + { + "objective": "Brute Force Protection", + "steps": str, + "expected_response_code": + [[ + "429 Too Many Requests to indicate effective rate-limiting measures against brute force attacks.", + "403 Forbidden as an alternative response indicating account lockout or other security measures preventing further attempts." + ]] * 10, + "token": [""] * 10, + "path": [login_path] * 10, + "security": + ["Ensure that user cannot login with invalid credentials.\n" + + "Ensure that rate-limiting mechanisms are robust and properly configured to prevent brute force attacks by limiting the number of allowed failed attempts within a given time frame. 
This prevents attackers from trying a large number of combinations in a short period.\n" + + "Check that account lockout mechanisms or other access denial policies are effective in disabling further login attempts after a certain number of failures, protecting against continuous brute force attempts and securing user accounts from unauthorized access."] + + } + ) + if self.auth_endpoint: + if acc["api"] in login_path: + str_id = f"{acc.get('id')}" + login_path = login_path.replace("{id}", str_id) + + get_paths = self.get_correct_endpoints_for_method("auth_endpoint", "GET") + post_paths = self.get_correct_endpoints_for_method("auth_endpoint", "POST") + + for get_path in get_paths: + if acc["api"] in get_path: + str_id = f"{acc.get('id')}" + get_path = get_path.replace("{id}", str_id) + prompts = self.test_css(get_path, prompts) + + for post_path in post_paths: + if acc["api"] in post_path: + str_id = f"{acc.get('id')}" + post_path = post_path.replace("{id}", str_id) + schema = self.openapi_spec_parser.get_schema_for_endpoint(post_path, "POST") + prompts = self.test_css(post_path, prompts, schema=schema) + + if self.current_protected_endpoint: + get_endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "GET") + + for get_endpoint in get_endpoints: + for account in self.accounts: + if acc["api"] in get_endpoint: + str_id = f"{acc.get('id')}" + get_endpoint = get_endpoint.replace("{id}", str_id) + prompts = self.test_token(login_path, get_endpoint, account, login_schema, + prompts, revocation=True) + + + else: + + if login_path: + prompts = self.test_brute_force(login_path, login_schema, prompts, number=10) + + if self.auth_endpoint: + + get_paths = self.get_correct_endpoints_for_method("auth_endpoint", "GET") + post_paths = self.get_correct_endpoints_for_method("auth_endpoint", "POST") + + for get_path in get_paths: + prompts = self.test_css(get_path, prompts) + + for post_path in post_paths: + schema = self.openapi_spec_parser.get_schema_for_endpoint(post_path, "POST") + prompts = self.test_css(post_path, prompts, schema=schema) + + if self.current_protected_endpoint: + get_endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "GET") + + for get_endpoint in get_endpoints: + for account in self.accounts: + if "id}" in get_endpoint: + get_endpoint = self.replace_placeholders_with_1(get_endpoint, account.get("id")) + + prompts = self.test_token(login_path, get_endpoint, account, login_schema, prompts, + revocation=True) + + # return prompts + + prompts = self.test_comment(acc, prompts) + + + return prompts + + return prompts + + def generate_input_validation_prompts(self): + """ + Generate prompts for testing input validation vulnerabilities on POST endpoints. + + This method targets both protected and public POST endpoints and performs: + - SQL Injection testing using the account context and endpoint schema. + - General input validation testing (e.g., missing fields, invalid types). + + For each endpoint and account combination, the method replaces path parameters (like {id}) + and checks the relevant OpenAPI schema to craft test cases. + + Returns: + list: A list of prompts designed to evaluate input validation robustness. 
+ """ + prompts = [] + + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "POST") + self.get_correct_endpoints_for_method( + "public_endpoint", "POST") + if self.current_protected_endpoint: + + for ep in endpoints: + for account in self.accounts: + post_endpoint = ep.get("path") + schema = self.openapi_spec_parser.get_schema_for_endpoint(post_endpoint, "POST") + if "api" in account.keys() and account["api"] in ep: + str_id = f"{account.get('id')}" + post_endpoint = ep.replace("{id}", str_id) + + prompts = self.test_sql_injection(account, post_endpoint, schema, prompts) + + prompts = self.test_inputs(post_endpoint, schema, account, prompts) + + + + else: + prompts = self.test_sql_injection(account, post_endpoint, schema, prompts) + + prompts = self.test_inputs(post_endpoint, schema, account, prompts) + + return prompts + + def generate_error_handling_prompts(self): + """ + Generate prompts for testing error handling on POST endpoints. + + This method verifies that endpoints respond with meaningful and secure error messages + when provided with incorrect or malformed input. + + It combines protected and public POST endpoints, retrieves their schemas, and uses + account information to inject malformed or edge-case data to observe error behavior. + + Returns: + list: A list of prompts to test the robustness and clarity of error handling. + """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "POST") + self.get_correct_endpoints_for_method( + "public_endpoint", "POST") + + for ep in endpoints: + post_endpoint = ep.get("path") + schema = self.openapi_spec_parser.get_schema_for_endpoint(post_endpoint, "POST") + for account in self.accounts: + if "api" in post_endpoint and "api" in account.keys() and account["api"] in ep: + str_id = f"{account.get('id')}" + post_endpoint = ep.replace("{id}", str_id) + + prompts = self.test_error_handling(post_endpoint, account, schema, prompts) + + + else: + prompts = self.test_error_handling(post_endpoint, account, schema, prompts) + + return prompts + + def generate_session_management_prompts(self): + """ + Generate prompts for testing session management and security. + + This method checks GET endpoints (both protected and public) for: + - Proper session validation. + - Session hijacking resistance. + - Session-related cookie attributes (e.g., HttpOnly, Secure). + - Other session vulnerabilities. + + It also evaluates login endpoints to simulate authentication flows and test + how sessions are managed, maintained, and secured afterward. + + Returns: + list: A list of prompts testing session integrity, hijacking protections, and cookie configurations. 
+ """ + prompts = [] + + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "GET") + self.get_correct_endpoints_for_method( + "public_endpoint", "GET") + + for get_endpoint, _, _ in endpoints: + # Check if API Uses Session Management + for account in self.accounts: + if "api" in account and account["api"] in get_endpoint: + str_id = f"{account.get('id')}" + get_endpoint = get_endpoint.replace("{id}", str_id) + + prompts = self.test_session_management(get_endpoint, account, prompts) + + + else: + if "id}" in get_endpoint: + get_endpoint = self.replace_placeholders_with_1(get_endpoint, account.get("id")) + prompts = self.test_session_management(get_endpoint, account, prompts) + + if self.login_endpoint: + for login in self.login_endpoint: + + login_path, login_schema = self.get_path_and_schema(login) + if login_schema is None: + continue + if "api" in account and account["api"] in login_path: + str_id = f"{account.get('id')}" + login_path = login_path.replace("{id}", str_id) + + prompts = self.test_session_hijacking(login_path, get_endpoint, login_schema, account, + prompts) + + else: + prompts = self.test_session_hijacking(login_path, get_endpoint, login_schema, account, + prompts) + prompts = self.test_sessions_vulnerabilitiy(login_path, login_schema, account, prompts) + + prompts = self.test_cookies(login_path, login_schema, prompts) + + return prompts + + def generate_xss_prompts(self): + """ + Generate prompts for detecting Cross-Site Scripting (XSS) vulnerabilities. + + This method covers both POST and GET endpoints, targeting public and protected resources. + It attempts to inject malicious XSS payloads into input fields (via POST) and query parameters (via GET), + and then observes the responses for evidence of unescaped rendering or unsafe HTML reflection. + + Returns: + list: A list of prompts designed to detect XSS vulnerabilities. + """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "POST") + self.get_correct_endpoints_for_method( + "public_endpoint", "POST") + + for account in self.accounts: + + for post_endpoint, _, _ in endpoints: + schema = self.openapi_spec_parser.get_schema_for_endpoint(post_endpoint, "POST") + prompts = self.test_xss(post_endpoint, account, schema, prompts) + + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "GET") + self.get_correct_endpoints_for_method( + "public_endpoint", "GET") + for get_endpoint, _, _ in endpoints: + if "id}" in get_endpoint: + get_endpoint = self.replace_placeholders_with_1(get_endpoint, account.get("id")) + + prompts = self.test_xss_query(get_endpoint, account, prompts) + + return prompts + + def generate_csrf_prompts(self): + """ + Generate prompts to test Cross-Site Request Forgery (CSRF) protection. + + This method tests if sensitive endpoints are protected from unauthorized or forged requests + by simulating actions like changing user data without a valid CSRF token. + It also checks cookie configurations to ensure proper CSRF defense mechanisms are in place. + + Returns: + list: A list of CSRF-related prompts covering POST and GET requests on sensitive endpoints. 
+ """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("sensitive_data_endpoint", + "POST") + self.get_correct_endpoints_for_method( + "sensitive_data_endpoint", "POST") + for account in self.accounts: + for sensitive_action_endpoint in endpoints: + schema = sensitive_action_endpoint.get("schema") + prompts = self.test_csrf(sensitive_action_endpoint, schema, prompts, method="POST") + endpoints = self.get_correct_endpoints_for_method("sensitive_data_endpoint", + "GET") + self.get_correct_endpoints_for_method( + "sensitive_data_endpoint", "GET") + for sensitive_data_endpoint in endpoints: + if "id}" in sensitive_data_endpoint: + sensitive_data_endpoint = self.replace_placeholders_with_1(sensitive_data_endpoint, + account.get("id")) + prompts = self.test_csrf(sensitive_data_endpoint, None, prompts, method="GET") + + # This prompt tests if the API applies CSRF protection to GET requests that handle sensitive data. + + for login in self.login_endpoint: + + login_path, login_schema = self.get_path_and_schema(login) + if login_schema is None: + continue + if login_path: + prompts = self.test_cookies(login_path, login_schema, prompts, account=account) + + return prompts + + def generate_business_logic_vul_prompts(self): + """ + Generate prompts to test for business logic vulnerabilities. + + These include logic flaws like privilege escalation, incorrect role validation, + or bypassing user controls. The method targets both protected and public POST endpoints, + as well as sensitive GET endpoints and role-based POST operations. + + Returns: + list: A list of prompts to test business rules and role validation logic. + """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "POST") + self.get_correct_endpoints_for_method( + "public_endpoint", "POST") + for account in self.accounts: + + for endpoint, _, _ in endpoints: + schema = self.openapi_spec_parser.get_schema_for_endpoint(endpoint, "POST") + prompts = self.test_buisness_logic(endpoint, schema, account, prompts, method="POST") + + get_endpoints = self.get_correct_endpoints_for_method("sensitive_data_endpoint", "GET") + + for endpoint, _, _ in get_endpoints: + if "id}" in endpoint: + get_endpoint = self.replace_placeholders_with_1(endpoint, account.get("id")) + prompts = self.test_buisness_logic(endpoint, None, account, prompts, method="GET") + post_endpoints = self.get_correct_endpoints_for_method("role_access_endpoint", "POST") + + for post_endpoint, _, _ in post_endpoints: + schema = self.openapi_spec_parser.get_schema_for_endpoint(post_endpoint, "POST") + prompts = self.test_roles(post_endpoint, schema, account, prompts) + + return prompts + + def generate_rate_limit_throttling(self): + """ + Generate prompts to test rate limiting and throttling protections. + + This method simulates high-frequency requests to login and OTP endpoints + to detect absence of proper rate-limiting mechanisms, which could lead to + brute-force attacks or denial-of-service (DoS). + + Returns: + list: A list of prompts that test for request throttling, OTP abuse, and rate limits. 
+ """ + prompts = [] + for login in self.login_endpoint: + + login_path, login_schema = self.get_path_and_schema(login) + request = [ + f"Send a large number of POST requests to {login_path} based on schema {login_schema} within a short timeframe."] * 10 + if login_path: + prompts = self.test_limiting(request, login_path, prompts) + + # OPT test cases + opt_endpoints = [ep for ep in self.endpoints if "otp" in ep] + + for opt_endpoint in opt_endpoints: + numbers = list(range(0000, 10000)) + + # Pick 10 random numbers from the list + random_numbers = random.sample(numbers, 10) + random_numbers = [1872] + random_numbers + for number in random_numbers: + otp = {"otp": f"{number}"} + + prompts = prompts + [ + + { + "objective": "Test for Unrestricted Resource Consumption", + "steps": [ + f"Send a POST request to trigger the sending of an OTP to a mobile phone number. This is intended to test the system's resilience against DoS attacks through resource exhaustion. Request should look like the following:{otp}" + ], + "token": [""], + # Assuming no authentication is required for this endpoint, or specify where needed + "path": [opt_endpoint], # Specify the actual endpoint for sending OTP + "expected_response_code": [ + "403 Forbidden" + ], + "security": [ + "Verify that msg is not '4 Digit OPT sent on mobile no.'", + + ] + } + , + { + "objective": "Testing for Unrestricted Resource Consumption", + "steps": [ + f"Send a POST requests to trigger the verify an OTP to a mobile phone number. Request should look like the following:{otp}" + ], + "token": [""], + # Assuming no authentication is required for this endpoint, or specify where needed + "path": [opt_endpoint], # Specify the actual endpoint for sending OTP + "expected_response_code": [ + "403 Forbidden" + ], + "security": [ + "Verify that request is okay", + + ] + } + ] + + prompts = self.test_throttling(prompts) + return prompts + + def generate_security_misconfiguration_prompts(self): + """ + Generate prompts for identifying security misconfigurations in the API. + + This method checks for: + - Missing or misconfigured security-related HTTP headers (e.g., CSP, X-Frame-Options). + - Insecure key exposure through endpoint patterns. + - Incorrect CORS settings that may lead to data leakage. + - Disclosure of sensitive debug or configuration information in error responses. + + Returns: + list: A list of prompts targeting potential security misconfiguration vulnerabilities. 
+ """ + prompts = [] + + user_endpoints = [ep for ep in self.endpoints if ep.endswith("user")] + + for key_endpoint in user_endpoints: + for account in self.accounts: + if account["api"] in key_endpoint: + if "key" in account: + prompts = self.test_limiting("", key_endpoint, prompts, account.get("key")) + + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "GET") + self.get_correct_endpoints_for_method( + "public_endpoint", "GET") + user_endpoints = [ep for ep in self.endpoints if "key" in ep] + endpoints = endpoints + user_endpoints + + for endpoint in endpoints: + if isinstance(endpoint, dict): + endpoint = endpoint.get("path") + for account in self.accounts: + id = account.get("id") + if f"{id}".isdigit() and "{id}" in endpoint: + endpoint = endpoint.replace("{id}", f"{id}") + prompts.append( + + # Check for Security Misconfigurations in Headers + { + "objective": "Check for Security Misconfigurations in Headers", + "steps": [ + f"Send a GET request to {endpoint}.", + + ], + "path": [endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "200 OK with response headers correctly set to include security-related configurations such as X-Content-Type-Options: nosniff, X-Frame-Options: SAMEORIGIN, and Content-Security-Policy." + ], + "security": [ + "Inspect the response headers, including X-Content-Type-Options, X-Frame-Options, and Content-Security-Policy.", + "Verify if these headers are properly configured to mitigate security risks." + "Ensure that response headers are configured to prevent security vulnerabilities such as clickjacking, MIME type sniffing, and XSS.", + "Verify that no unnecessary or insecure headers are present in the response." + "Check if 'Access-Control-Allow-Origin: * Access-Control-Allow-Credentials: true' is set as these headers indicate that the CORS policy for the given request is misconfigured, allowing it to be invoked from a third-party web application to access the resource and obtain sensitive information like the user's ID, username, password, and auth key." + ] + } + + # This prompt tests if security headers are properly configured to prevent exposure to security risks. + ) + endpoints = self.get_correct_endpoints_for_method("secure_action_endpoint", + "POST") + for endpoint in endpoints: + for account in self.accounts: + id = account.get("id") + endpoint = endpoint.replace("{id}", f"{id}") + + schema = self.openapi_spec_parser.get_schema_for_endpoint(endpoint, "POST") + for account in self.accounts: + prompts.append( + # Inspect API for Debug or Error Information Exposure + { + "objective": "Inspect API for Debug or Error Information Exposure", + "steps": [ + f"Send an invalid POST request to {endpoint} based on schema {schema}.", + + ], + "token": [account.get("token")], + "path": [endpoint], + "expected_response_code": [ + "400 Bad Request or 403 Forbidden with a generic error message that does not disclose sensitive internal details." + ], + "security": [ + "Check if the response contains detailed error messages or debug information.", + "Verify that the response provides only minimal and generic error messages." + "Ensure that the API avoids exposing debug or configuration details in error responses.", + "Verify that sensitive system information is not included in error messages." + ] + } + + # This prompt tests if the API avoids exposing sensitive configuration or debug information. 
+ ) + + return prompts + + def generate_logging_monitoring_prompts(self): + """ + Generate prompts to test whether the API performs proper logging and monitoring. + + This method simulates: + - Invalid request submissions to see if they are logged appropriately. + - Potentially malicious requests (e.g., SQL injection attempts). + - Logging of unexpected or unauthorized access attempts. + + It ensures that logs are meaningful, timely, and privacy-compliant without exposing sensitive data. + + Returns: + list: A list of prompts to verify proper logging and monitoring mechanisms. + """ + prompts = [] + endpoints = self.get_correct_endpoints_for_method("secure_action_endpoint", + "POST") + flag_endpoints = [ep for ep in self.endpoints if "flag" in ep] + endpoints = flag_endpoints + endpoints + + for endpoint in endpoints: + if isinstance(endpoint, dict): + endpoint = endpoint.get("path") + schema = self.openapi_spec_parser.get_schema_for_endpoint(endpoint, "POST") + for account in self.accounts: + prompts.append( + # Test Logging for Incorrect Requests + { + "objective": "Test Logging for Incorrect Requests", + "steps": [ + f"Send an invalid POST request to {endpoint} based on schema {schema}.", + ], + "path": [endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "400 Bad Request or 403 Forbidden, with the API logging details of the failed attempt." + ], + "security": [ + "Verify if the API logs the details of the failed request, including timestamp, user information, and the nature of the error." + "Ensure that the API logs invalid requests comprehensively to assist in debugging and identifying potential misuse.", + "Verify that logs do not include sensitive user data or details beyond what is necessary for monitoring." + ] + } + + # This prompt tests if the API logs incorrect requests, including any errors encountered. + ) + endpoints = self.get_correct_endpoints_for_method("secure_data_endpoint", + "GET") + params = [{"username": "' OR 1=1 #", "password": "password"}, {"username": "'", "password": "password"}, + {"username": "admin", "password": "'"}] + + for endpoint in endpoints: + if isinstance(endpoint, dict): + endpoint = endpoint.get("path") + + for account in self.accounts: + if "id}" in endpoint: + endpoint = self.replace_placeholders_with_1(endpoint, account.get("id")) + for paramerts in params: + prompts.append( + # Test Logging for Potentially Malicious Requests + { + "objective": "Test Logging for Potentially Malicious Requests", + "steps": [ + f"Send a GET request to {endpoint} with unusual or malicious parameters, such as SQL injection attempts. Request body should look like this: {paramerts}", + ], + "token": [account.get("token")], + "path": [endpoint], + "expected_response_code": [ + "200 OK", + "400 Bad Request with the API logging the details of the suspicious request, including the malicious parameters and IP address." + ], + "security": [ + "Check if the API logs the suspicious request and flags it for further analysis or monitoring." + "Ensure that the API logs potentially malicious requests and flags them for review to mitigate security threats.", + "Verify that logs are monitored for patterns of abuse or repeated attacks." + ] + } + + # This prompt tests if the API logs and monitors potentially malicious requests to help detect and respond to security threats. 
+ ) + return prompts + + def get_correct_endpoints_for_method(self, type_of_endpoint, method): + endpoints = [] + for type_ep in self.categorized_endpoints.keys(): + if type_of_endpoint == type_ep: + x = self.categorized_endpoints[type_of_endpoint] + if x is not None: + for entry in x: # Assuming x is a list of dictionaries + if entry.get('method') == method: + endpoints.append(entry) + return endpoints + + def generate_random_numbers(self, length=10): + + number = ''.join(str(random.randint(0, 9)) for _ in range(length)) + while number in self.available_numbers: + number = ''.join(str(random.randint(0, 9)) for _ in range(length)) + + self.available_numbers.append(number) + return number + + def get_credentials(self, schema, endpoint, new_user=False): + """ + Fill username and password fields in the provided schema. + + Args: + schema (dict): A schema dictionary containing an example. + username (str): The username to populate in the example. + password (str): The password to populate in the example. + + Returns: + dict: Updated schema with username and password fields filled. + """ + # Deep copy the schema to avoid modifying the original + updated_schema = copy.deepcopy(schema) + example = None + + if schema is not None: + if "example" in updated_schema.keys(): + updated_schema["example"] = self.fill_schema(updated_schema["example"]) + else: + updated_schema = self.adjust_schema_with_examples(updated_schema) + + if "example" in updated_schema: + example = updated_schema["example"] + if endpoint not in self.credentials.keys() or new_user: + + # Check if 'example' exists and is a dictionary + if updated_schema is not None and "example" in updated_schema.keys(): + example = updated_schema.get("example") + + if example is None: + example = {} + if "email" not in example or example["email"].startswith("{{"): + example['email'] = self.faker.email() + if "name" not in example or example["name"].startswith("{{"): + example["name"] = self.faker.name() + if "number" not in example: + if schema is not None and "properties" in schema.keys(): + example["number"] = int(self.generate_random_numbers()) + else: + example["number"] = 1 + if "username" in example and example["username"].startswith("{{"): + example["username"] = self.faker.user_name() + else: + if "email" in example and "{{" in example["email"]: + example["email"] = self.faker.email() + if "password" in example and "{{" in example["password"]: + password = self.faker.password(special_chars=False) + if "passwordRepeat" in example and "{{" in example["passwordRepeat"]: + example["passwordRepeat"] = password + example["password"] = password + if "number" in example: + if "{{" in example["number"] or "phone" in example["number"]: + example["number"] = int(self.generate_random_numbers()) + if "username" in example: + example["username"] = self.faker.user_name() + + if updated_schema is None: + updated_schema = {} + updated_schema["example"] = example + self.credentials[endpoint] = updated_schema + + else: + updated_schema = self.credentials[endpoint] + + return updated_schema + + def fill_schema(self, schema, params=None): + if params: + field_to_faker = params + else: + + field_to_faker = { + 'name': self.faker.name, + 'email': self.faker.email, + 'phone': self.faker.phone_number, + 'password': self.faker.password, + 'address': self.faker.address, + 'city': self.faker.city, + 'username': self.faker.user_name, + "old_email": "adam007@example.com", + "new_email": self.faker.email, + "price": -2000, + "number_of_repeats": 10000, + } + filled_schema 
= {} + if schema: + for key, value in schema.items(): + # Attempt to find a Faker provider for the key + + provider = field_to_faker.get(key) + if provider: + # If a provider is found, use it to generate fake data + if key == "password": + filled_schema[key] = self.faker.password(special_chars=False) + else: + if not callable(provider): + filled_schema[key] = self.faker.random_letters() + else: + filled_schema[key] = provider() + else: + # If no provider is found, revert to a default or keep the original value + filled_schema[key] = value + return filled_schema + + def set_login_schema(self, account, login_schema): + if "username" in login_schema.keys(): + if "username" in account.keys(): + login_schema["username"] = account["username"] + elif "email" in account.keys(): + login_schema["username"] = account["email"] + + if "password" in login_schema.keys(): + login_schema["password"] = account["password"] + + return login_schema + + def adjust_schema_with_examples(self, schema: dict) -> dict: + """ + Move 'example' values from each property into a separate 'example' dict. + """ + new_schema = schema.copy() + example_dict = {} + + if 'properties' in schema: + for field, field_props in schema['properties'].items(): + if 'example' in field_props: + example_dict[field] = field_props.pop('example') + if "properties" in field_props: + if field not in example_dict or not isinstance(example_dict[field], dict): + example_dict[field] = {} + + for field1, field_props1 in field_props['properties'].items(): + if 'example' in field_props1: + example_dict[field][field1] = field_props1.pop('example') + + + # Add collected examples + if example_dict: + new_schema['example'] = example_dict + + return new_schema + + def create_random_bearer_token(self, length=16): + """ + Generates a random token using hex encoding and prefixes it with "Bearer ". + :param length: Number of bytes for the random token (each byte becomes two hex characters). + :return: A string in the format "Bearer ". 
+ """ + token_value = secrets.token_hex(length) + return f"{token_value}" + + def get_invalid_credentials(self, account): + invalid_account = {} + for values, keys in account.items(): + if isinstance(values, str): + invalid_account[keys] = values + "1" + elif values.isnumeric(): + invalid_account[keys] = values + 1 + else: + invalid_account[keys] = "_" + values + return invalid_account + + def create_account(self, login_schema, login_path): + account = self.get_credentials(login_schema, login_path, new_user=True) + if account is None: + return None + if "example" in account: + account = account.get("example") + account["x"] = self.counter + parts = [api for api in login_path.split("/") if api.strip()] + api = [part for part in parts if part.startswith("api")] + if len(api) > 0: + api = api[0] + account["api"] = api + return account + + def assign_brute_force_endpoints(self, admin): + file = self.config.get("password_file") + if file == "Not found": + return + else: + password_list_path = self.get_file(file)[0] + + # Open the password list file + with open(password_list_path, "r") as file: + passwords = file.readlines() + + # Strip any extra whitespace characters (newlines, spaces) + passwords = [password.strip() for password in passwords] + + # Start brute-force attack + for password in passwords: + # Create the data for the POST request + data = { + 'username': admin, + 'password': password + } + self.brute_force_accounts.append(data) + + def get_comment_ep(self): + for ep in self.endpoints: + if "comment" in ep: + return ep + + def fill_in_schema_with_account(self, login_schema, account): + if login_schema is not None: + if "token" in login_schema and "token" in account: + login_schema["token"] = account["token"] + if "email" in login_schema and "email" in account: + login_schema["email"] = account["email"] + if "password" in login_schema and "password" in account: + login_schema["password"] = account["password"] + if "username" in login_schema and "username" in account: + login_schema["username"] = account["username"] + return login_schema + + def mechanic_report(self, endpoint, account, prompts): + if "mechanic_report" in endpoint: + users = list(range(1, 100)) + random_numbers = random.sample(users, 10) + random_numbers = [5] + random_numbers + + for i in random_numbers: + new_endpoint = endpoint + f"?report_id={i}" + prompts.append( + + # Access Control to Specific Resources + { + "objective": "Access Control to Specific Resource of other used", + "steps": [ + f"Attempting Access with Non-Owner's Credentials - Send a GET request to {new_endpoint} using an authentication token for User B (a different user who does not own the resource). user:{account}.\n", + ], + "path": [new_endpoint], + "expected_response_code": [ + "403 Forbidden or 404 Not Found when accessed by a non-owner, confirming that access is appropriately restricted.", + + ], + "token": [account.get("token")], + "security": [ + "Ensure robust validation of user credentials and proper enforcement of access controls. Verify that the system correctly identifies ownership and prevents unauthorized access based on user authentication and authorization levels." 
+ ] + } + + # This prompt tests access control for resources that are user-owned by verifying: + # - Only the resource owner can access the data (User A) + # - Non-owners receive a restricted response (User B) + # - Unauthenticated requests are denied access (if applicable) + ) + return prompts + + def random_common_users(self, endpoint, login_path, login_schema, prompts): + + random_entries = self.df.sample(n=10, + random_state=42) # Adjust random_state for different samples + + for index, random_entry in random_entries.iterrows(): + username = random_entry['username'] + password = random_entry['password'] + # Now you can print or use username and password as needed + common_account = {"email": username, "password": password} + prompts = prompts + [{ + "objective": "Attempt login with common username and password", + "steps": [ + f"Attempt to send a POST request to the login endpoint {login_path} using common login credentials {login_schema} with user:{common_account}.\n", ], + "path": [login_path], + "expected_response_code": [ + "401 Unauthorized when attempting to authenticate with invalid credentials.", + ], + "token": [""], + "security": [ + "Ensure that the system rejects invalid credentials with a 401 Unauthorized status.", + "Ensure no sensitive information is leaked in the error responses, such as specific reasons for the failure." + ] + }] + username = "savanna48@ortiz.com" + password = "zTyBwV/9" + common_account = {"email": username, "password": password} + prompts = prompts + [{ + "objective": "Attempt login with common username and password", + "steps": [ + f"Attempt to send a POST request to the login endpoint {login_path} using common login credentials {login_schema} with user:{common_account}.\n", + ], + "path": [login_path], + "expected_response_code": [ + "401 Unauthorized when attempting to authenticate with invalid credentials.", + ], + "token": [""], + "security": [ + "Ensure that the system rejects invalid credentials with a 401 Unauthorized status.", + "Ensure no sensitive information is leaked in the error responses, such as specific reasons for the failure." + ] + }] + + return prompts + + def resource_prompts(self, endpoint, account, prompts): + key_found, key = self.key_in_path(endpoint, self.resources) + if key_found: + for resource in self.resources[key]: + new_endpoint = endpoint.replace("{", "").replace("}", "").replace(key, resource) + + prompts.append( + + # Access Control to Specific Resources + { + "objective": "Access Control to Specific Resources", + "steps": [ + f"Accessing Resource with Owner's Credentials - Send a GET request to {endpoint} using an authentication token for User A (a different user who does not own the resource). user:{account}.\n", + f"Attempting Access with Non-Owner's Credentials - Send a GET request to {endpoint} using an authentication token for User B (a different user who does not own the resource). user:{account}.\n", + f"Attempting Access without Authentication (if applicable) - Send a GET request to {endpoint} without any authentication token. Expect a 401 Unauthorized response, confirming that unauthenticated access is restricted." 
+ ], + "path": [new_endpoint, new_endpoint, new_endpoint], + "expected_response_code": [ + "403 Forbidden or 404 Not Found when accessed by a non-owner, confirming that access is appropriately restricted.", + "403 Forbidden or 404 Not Found when accessed by a non-owner, confirming that access is appropriately restricted.", + "401 Unauthorized for unauthenticated requests, ensuring that authentication is required for access." + ], + "token": [account.get("token"), self.create_random_bearer_token(), ""], + "security": [ + "Ensure robust validation of user credentials and proper enforcement of access controls. Verify that the system correctly identifies ownership and prevents unauthorized access based on user authentication and authorization levels." + ]} + + # This prompt tests access control for resources that are user-owned by verifying: + # - Only the resource owner can access the data (User A) + # - Non-owners receive a restricted response (User B) + # - Unauthenticated requests are denied access (if applicable) + ) + return prompts + + def resource_endpoints(self, prompts): + resource_endpoints = [] + for ep in self.endpoints: + if "_id" not in ep and ("posts" and "recent" in ep or "mechanic" and "workshop" in ep or "dashboard" in ep): + resource_endpoints.append(ep) + + resource_endpoints = list(set(resource_endpoints)) + + for resource_endpoint in resource_endpoints: + for account in self.accounts: + token = account.get("token") + prompts = prompts + [ + { + "objective": "Get other data", + "steps": [ + f"Send a GET request to the {resource_endpoint} with the token {token}.\n" + ], + "path": [resource_endpoint], + "token": [account.get("token")], + "expected_response_code": ["200 OK", "201 Created"], + "security": [ + "Ensure that the correct expected code is received confirming that access is restricted and no sensitive data is included in the response body."] + } + + ] + return prompts + + def rbac(self, old_endpoint, account, prompts, id=None): + endpoint = copy.deepcopy(old_endpoint) + + ids = [35] + list(range(1, 10)) + if id: + ids = [id] + ids + + endpoint_and_new_endpoint = [] + for i in ids: + token = "" if "login" not in endpoint else account.get("token") + endpoint, endpoint_of_other_user = self.replace_ids(account, endpoint, given_id=i) + if (endpoint, endpoint_of_other_user) in endpoint_and_new_endpoint: + continue + endpoint_and_new_endpoint.append((endpoint, endpoint_of_other_user)) + + prompts = prompts + [ + + # Verify Role-Based Access Control (RBAC) + + { + "objective": "Verify Role-Based Access Control (RBAC)", + "steps": [ + # f"send a GET request to {endpoint} with an 'admin' role authentication token in the Authorization header. Expect a 200 OK response, indicating access to all resources. Verify that the response data includes complete access to the endpoint, reflecting the admin’s permissions.", + f"send a GET request to {endpoint} with a 'user' role authentication token in the Authorization header. Expect a 200 OK response, but confirm that only authorized resources or data for a regular user are included in the response. Verify that any restricted admin-only resources are not accessible to the user role. user:{account}.\n", + # f"send a GET request to {endpoint} with a 'guest' role authentication token (or without an authentication token if guests are unauthenticated). 
Expect either a 403 Forbidden or a 404 Not Found response if guests are not permitted to access the endpoint, or a 200 OK with limited data if some resources are publicly accessible to guests." + ], + "path": [endpoint, ], + "expected_response_code": [ + # "200 OK for admin, confirming full access.", + "200 OK for users, confirming access is limited to non-admin resources.", + # "403 Forbidden or 404 Not Found for guests if no access is granted, or 200 OK if limited access is applicable." + ], + "token": [ # self.admin.get("token"), + token, + # self.guest.get("token") + ], + "security": [ + "Verify that the response data includes complete access to the endpoint, reflecting the admin’s permissions.", + "Verify that any restricted admin-only resources are not accessible to the user role.", + "Verify that guest role has no or limited access."], + + }, + # Access Control to Specific Resources + { + "objective": "Access Control to Specific Resources", + "steps": [ + f"Accessing Resource with Owner's Credentials - Send a GET request to {endpoint} using an authentication token for User A (the owner of the resource). Expect a 200 OK response, confirming that User A can access their own data. Verify that the response includes the resource data as expected.user:{account}.\n", + f"Attempting Access with Non-Owner's Credentials - Send a GET request to {endpoint_of_other_user} using an authentication token for User B (a different user who does not own the resource). Expect a 403 Forbidden or 404 Not Found response, indicating that access is denied to resources not owned by the authenticated user.user:{account}.\n", + f"Attempting Access without Authentication (if applicable) - Send a GET request to {endpoint} without any authentication token. Expect a 401 Unauthorized response, confirming that unauthenticated access is restricted.user:{account}.\n" + ], + "path": [endpoint, endpoint_of_other_user, endpoint], + "expected_response_code": [ + "200 OK when accessed by the owner, confirming correct owner access.", + "403 Forbidden or 404 Not Found when accessed by a non-owner, confirming that access is appropriately restricted.", + "401 Unauthorized for unauthenticated requests, ensuring that authentication is required for access." + ], + "token": [account.get("token"), self.create_random_bearer_token(), ""], + "security": [ + "Ensure robust validation of user credentials and proper enforcement of access controls. Verify that the system correctly identifies ownership and prevents unauthorized access based on user authentication and authorization levels." + ] + } + + # This prompt tests the RBAC implementation by confirming that each user role has the appropriate access level: + # - Admin has full access + # - User has restricted access + # - Guest has no or limited access. + ] + + endpoint = old_endpoint + + return prompts + + def verify_data_masking(self, endpoint, account, prompts): + prompts = prompts + [ + + { + "objective": "Verify Data Masking", + "steps": [ + f"Restricted Access Verification - Send a GET request to {endpoint} with a user who has restricted access, verifying that only permissible fields are included in the response.user:{account}.\n", + ], + "expected_response_code": [ + "200 OK for responses that correctly implement data masking, ensuring that data visibility is restricted according to user permissions." 
+ ], + "path": [endpoint], + "token": [account.get("token")], + "security": [ + " Confirm that sensitive data is either masked or excluded, ensuring that only authorized fields are visible to the user. Ensure that the system effectively masks or excludes sensitive data for users with restricted access. Verify that data masking is applied correctly and consistently across all applicable endpoints to maintain confidentiality and compliance with data protection regulations." + ] + } + + ] + return prompts + + def test_create(self, endpoint, account, prompts, post_endpoint): + post_schema = post_endpoint.get("schema") + post_schema = self.fill_schema(post_schema) + + prompts = prompts + [ # Check Permissions for CRUD Operations + + # Create Operation: + + { + "objective": "Check Permissions for CRUD Operations: Create", + "steps": [ + f"Authorized Creation - Send a POST request to {endpoint} with valid data to create a new resource like {post_schema}, using a user with creation permissions. Expect a 201 Created response if successful.", + f"Unauthorized Creation - Then, attempt to send a similar request to {endpoint} like {post_schema} with a user who lacks creation permissions, expecting a 403 Forbidden or similar error response." + ], + "path": [endpoint, endpoint], + "expected_response_code": [ + "201 Created for successful creation by authorized users, confirming that the system allows creation as expected.", + "403 Forbidden for unauthorized creation attempts, confirming that the system properly restricts creation permissions." + ], + "token": [account.get("token"), account.get("token")], + "security": [ + "Ensure that the system robustly validates user permissions before processing CRUD operations. " + "Confirm that only users with the appropriate permissions can create resources, and that unauthorized attempts are correctly identified and denied."] + } + ] + + return prompts + + def test_get(self, endpoint, account, prompts): + + get_endpoint, endpoint_of_other_user = self.replace_ids(account, endpoint) + + prompts = prompts + [ + + # Read Operation: + + { + "objective": "Check Permissions for CRUD Operations: Read", + "steps": [ + f"Authorized Read - Send a GET request to {get_endpoint} with a user who has read permissions, verifying that the correct data is returned.", + f"Unauthorized Read - Attempt the same request to {endpoint_of_other_user} with a user who lacks read permissions, and verify that the response returns a 403 Forbidden or 404 Not Found status." + ], + "path": [get_endpoint, get_endpoint], + "expected_response_code": [ + "200 OK for successful data retrieval by authorized users, confirming that the system allows reading of data as expected.", + "403 Forbidden or 404 Not Found for unauthorized read attempts, confirming that the system properly restricts reading permissions." + ], + "token": [account.get("token"), account.get("token")], + + "security": [ + "Ensure that the system robustly validates user permissions before allowing access to read operations. 
Confirm that only users with the appropriate permissions can access data, and that unauthorized attempts are correctly identified and denied, preventing data leaks."] + }] + + + return prompts + + def test_put(self, put_endoint, account, prompts): + endpoint = put_endoint.get("path") + put_endoint_schema = put_endoint.get("schema") + put_endoint, endpoint_of_other_user = self.replace_ids(account, endpoint) + prompts = prompts + [ + + # Update Operation: + + { + "objective": "Check Permissions for CRUD Operations: Update", + "steps": [ + f"Authorized Update - Send a PUT or PATCH request to {put_endoint} to update a resource using a user with update permissions. Expect the resource to be modified, indicated by a 200 OK or 204 No Content response schema:{put_endoint_schema}.", + f"Unauthorized Update - Then, repeat the request with a user to {endpoint_of_other_user}who lacks update permissions, expecting a 403 Forbidden response." + ], + "path": [put_endoint, put_endoint], + "token": [account.get("token"), account.get("token")], + + "expected_response_code": [ + "200 OK or 204 No Content for successful modification by authorized users, confirming that the system allows updates as expected.", + "403 Forbidden for unauthorized update attempts, confirming that the system properly restricts updating permissions." + ], + "security": [ + "Ensure that the system robustly validates user permissions before allowing access to update operations. Confirm that only users with the appropriate permissions can modify resources, and that unauthorized attempts are correctly identified and denied, preventing unauthorized data modification."]} + + ] + return prompts + + def test_delete(self, endpoint, account, prompts): + delete_endpoint, endpoint_of_other_user = self.replace_ids(account, endpoint) + + prompts = prompts + [ + + # Delete Operation: + + { + "objective": "Check Permissions for CRUD Operations: Delete", + "steps": [ + f"Authorized Deletion - Send a DELETE request to {delete_endpoint} with a user who has delete permissions, verifying a successful deletion with a 200 OK or 204 No Content response.", + f"Unauthorized Deletion - Then, attempt to delete the resource with a user {endpoint_of_other_user} who lacks delete permissions, expecting a 403 Forbidden response." + ], + "token": [account.get("token"), account.get("token")], + "path": [delete_endpoint, delete_endpoint], + "expected_response_code": [ + "200 OK or 204 No Content for successful deletion by authorized users, confirming that the system allows deletions as expected.", + "403 Forbidden for unauthorized deletion attempts, confirming that the system properly restricts deleting permissions." + ], + "security": [ + "Ensure that the system robustly validates user permissions before allowing access to delete operations. 
Confirm that only users with the appropriate permissions can delete resources, and that unauthorized attempts are correctly identified and denied, preventing unauthorized resource removal."] + }] + + return prompts + + def test_access_scope(self, get_endpoint, post_endpoint, put_endpoint, account, prompts): + if isinstance(get_endpoint, dict): + get_endpoint = get_endpoint.get("path") + if isinstance(post_endpoint, dict): + post_endpoint = post_endpoint.get("path") + if isinstance(put_endpoint, dict): + put_endoint = put_endpoint.get("path") + if "api" in get_endpoint and post_endpoint and put_endpoint: + if account["api"] in get_endpoint and account["api"] in post_endpoint and account[ + "api"] in put_endpoint: + id = account.get("id") + get_endpoint = get_endpoint.replace("{id}", str(account.get("id"))) + post_endoint = post_endpoint.replace("{id}", str(account.get("id"))) + put_endoint = put_endoint.replace("{id}", str(account.get("id"))) + prompts = prompts + [ + + # Access Token Scope Testing + { + "objective": "Access Token Scope Testing", + "steps": [ + f"Testing Read-Only Scope for Data Retrieval - Send a GET request to {get_endpoint} using a read-only access token. Verify that the response status is 200 OK, allowing data retrieval.", + f"Attempting Write Operation with Read-Only Token - Attempt to modify data on {put_endoint} by sending a POST, PUT, or DELETE request using the same read-only access token.", + f"Testing Full-Access Token (if applicable) - Send a POST request to {post_endoint} using a full-access token to verify that write actions are permitted." + ], + "token": [account.get("token"), self.create_random_bearer_token(), + account.get("token")], + "path": [get_endpoint, put_endoint, post_endoint], + + "expected_response_code": [ + "200 OK for successful data retrieval using a read-only token, confirming the enforcement of read-only access.", + "403 Forbidden for attempted write operations with a read-only token, confirming that the token scope correctly restricts write actions.", + "200 OK or 201 Created for successful write actions using a full-access token, confirming that full-access privileges are appropriately granted." + ], + "security": [ + "Ensure that the a A read-only access token permits data retrieval (GET request).", + "The same read-only token denies access to write operations (POST, PUT, DELETE requests).", + "A full-access token (if applicable) allows write actions, validating proper enforcement of token scopes."] + } + ] + + else: + if "id}" in get_endpoint: + get_endpoint = self.replace_placeholders_with_1(get_endpoint, account.get("id")) + prompts = prompts + [ + + # Access Token Scope Testing + { + "objective": "Access Token Scope Testing", + "steps": [ + f"Testing Read-Only Scope for Data Retrieval - Send a GET request to {get_endpoint} using a read-only access token. Verify that the response status is 200 OK, allowing data retrieval.", + f"Attempting Write Operation with Read-Only Token - Attempt to modify data on {put_endpoint} by sending a POST, PUT, or DELETE request using the same read-only access token.", + f"Testing Full-Access Token (if applicable) - Send a POST request to {post_endpoint} using a full-access token to verify that write actions are permitted." 
+ ], + "token": [account.get("token"), self.create_random_bearer_token(), + account.get("token")], + "path": [get_endpoint, put_endpoint, post_endpoint], + + "expected_response_code": [ + "200 OK for successful data retrieval using a read-only token, confirming the enforcement of read-only access.", + "403 Forbidden for attempted write operations with a read-only token, confirming that the token scope correctly restricts write actions.", + "200 OK or 201 Created for successful write actions using a full-access token, confirming that full-access privileges are appropriately granted." + ], + "security": [ + "Ensure that the a A read-only access token permits data retrieval (GET request).", + "The same read-only token denies access to write operations (POST, PUT, DELETE requests).", + "A full-access token (if applicable) allows write actions, validating proper enforcement of token scopes."] + } + ] + return prompts + + # This prompt tests that: + # - A read-only access token permits data retrieval (GET request). + # - The same read-only token denies access to write operations (POST, PUT, DELETE requests). + # - A full-access token (if applicable) allows write actions, validating proper enforcement of token scopes. + def replace_id_placeholder(self, path: str, value: str) -> str: + return re.sub(r'\{[a-zA-Z0-9_]*_id\}', value, path) + + def test_authentication(self, endpoint, account, prompts): + if "api" in endpoint: + + new_endpoint, endpoint_of_other_user = self.replace_ids(account, endpoint) + + else: # Non-API endpoints + new_endpoint, endpoint_of_other_user = self.replace_ids(account, endpoint) + prompts = prompts + [ # Check Access Without Authentication + { + "objective": "Check Access Without Authentication", + "steps": [ + f"make a GET request to the protected endpoint {new_endpoint} without including any authentication headers with user:{account}.\n"], + "expected_response_code": ["401 Unauthorized", "403 Forbidden."], + "token": [""], + "path": [new_endpoint], + "security": [ + "Ensure that the correct expected code is received confirming that access is restricted and no sensitive data is included in the response body."]}, + # Test Valid Authentication + { + "objective": "Test Valid Authentication", + "steps": [ + f"use this accounts token to send a GET request to {new_endpoint}."], + "expected_response_code": [ + "200 OK response, indicating successful authentication." 
+ ], + "path": [new_endpoint], + "token": [account.get("token")], + "security": ["Verify that the response includes the expected data.", + "Confirm that the token provides the correct access level."] + + }, + # Test only can access my user + { + "objective": "Test cannot access other user data", + "steps": [ + f"use this accounts token to send a GET request to {endpoint_of_other_user}."], + "expected_response_code": ["401 Unauthorized", "403 Forbidden."], + "token": [account.get("token")], + "path": [new_endpoint], + "security": [ + "Ensure that the correct expected code is received confirming that access is restricted and no sensitive data is included in the response body."]}, + ] + + return prompts + + def test_token(self, login_path, endpoint, account, login_schema, prompts, revocation=False): + if revocation and endpoint is not None: + + prompts.append( + + # Token Revocation + { + "objective": "Token Revocation", + "steps": [ + f"Send a GET request to {endpoint} with a valid token = {account.get('token')}.", + f"Send a POST request to the token revocation endpoint {login_path} to invalidate an active token with user credentials:{account}", + f"Test the revoked token by sending a GET request to {endpoint} with the revoked token." + ], + "expected_response_code": [ + "200 OK for the initial use of the valid token, confirming it is active.", + "200 OK or 204 No Content, Successful revocation " + "401 Unauthorized for attempts made with the revoked token, confirming the system's capability to recognize and deny access with revoked credentials." + ], + "token": [account.get("token"), "", account.get("token")], + "path": [endpoint, login_path, endpoint], + "security": + [ + "Ensure that valid tokens are correctly recognized and granted access, confirming that authentication systems are functioning as expected before revocation.", + "Verify that once tokens are revoked, the system immediately denies access, preventing any security breaches that could occur from the continued use of revoked tokens. This includes ensuring timely updates across all components of the system that might interact with token validation processes."] + + } + + ) # protected end point needed + else: + prompts = prompts + [ + { # Test Expired or Invalid Tokens + "objective": "Test Expired or Invalid Tokens", + "steps": [ + f"Send a POST request to the login endpoint {login_path} with valid login credentials ({login_schema}) to obtain a valid token with user:{account}.\n", + f"Send a GET request to {endpoint} using an expired, revoked, or otherwise invalid token." + ], + "path": [login_path, endpoint], + "expected_response_code": [ + "200 OK for successfully obtaining a valid token with correct credentials.", + "401 Unauthorized for the expired or invalid token, verifying that the system rejects tokens that are no longer valid." + ], + "token": ["", self.create_random_bearer_token()], + "security": [ + "Ensure that the system appropriately validates token expiration or invalidation before granting access.", + "Ensure that no sensitive information is included in the error responses." + ] + }, + # Basic Authentication (if applicable) + { + "objective": "Valid Basic Authentication", + "steps": [ + f"Send a GET request to {endpoint} using Basic Authentication with the valid token {self.valid_token}." + ], + "path": [endpoint], + "expected_response_code": [ + "200 OK when sending a GET request with correct credentials, confirming access is granted." 
+ ], + "token": [account.get("token")], + "security": [ + "Ensure the system returns a 200 OK response for valid credentials and does not expose sensitive data in the response body." + ] + }, + { + "objective": "Invalid Basic Authentication", + "steps": [ + f"Attempt to send a POST request to the login endpoint {login_path} using invalid login credentials {login_schema} with user:{account}.\n", + f"Send a GET request to {endpoint} using Basic Authentication with invalid credentials:{self.get_invalid_credentials(account)}" + ], + "path": [login_path, endpoint], + "expected_response_code": [ + "401 Unauthorized when attempting to authenticate with invalid credentials.", + "401 Unauthorized when attempting to access the protected endpoint with invalid credentials." + ], + "token": [account.get("token"), account.get("token")], + "security": [ + "Ensure that the system rejects invalid credentials with a 401 Unauthorized status.", + "Ensure no sensitive information is leaked in the error responses, such as specific reasons for the failure." + ] + } + , + + ] + + return prompts + + def test_refresh_token(self, refresh_post_endpoint, refresh_get_endpoint, account, prompts): + prompts = prompts + [ # Test Token Refresh (if applicable) + + { + "objective": "Test Token Refresh", + "steps": [ + f"send a GET request to {refresh_get_endpoint} with the expired token in the Authorization header. Verify that the API responds with a 401 Unauthorized status, indicating the token has expired.", + f"send a POST request to the token refresh endpoint {refresh_post_endpoint} with the valid refresh token in the request body or headers, depending on the API's token refresh requirements. Check if the API responds with a 200 OK status and includes a new access token in the response body.", + f"use the new access token to send a GET request to {refresh_get_endpoint} again. Confirm that the API responds with a 200 OK status, indicating successful access with the refreshed token, and that the old expired token is no longer valid." + ], + "path": [refresh_get_endpoint, refresh_get_endpoint, refresh_get_endpoint], + "token": [self.create_random_bearer_token(), + account.get("token"), + account.get("token")], + "expected_response_code": [ + "401 Unauthorized for the expired token use, verifying that the token has indeed expired and is recognized by the system as such.", + "200 OK upon refreshing the token, confirming that the refresh mechanism works as expected and a new token is issued correctly.", + "200 OK when using the new token, verifying that the new token grants access and the old token is invalidated." + ], + "security": [ + "Ensure that the API does not leak sensitive information in error responses and that expired tokens are promptly invalidated to prevent unauthorized use."] + } + + # This prompt tests if the API correctly handles token expiration and issues a new token upon refresh, + # while ensuring that the expired token no longer provides access to protected resources. 
+ + ] + return prompts + + def test_crud(self, endpoints, prompts): + post_endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "POST") + delete_endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "DELETE") + put_endpoints = self.get_correct_endpoints_for_method("protected_endpoint", "PUT") + + for account in self.accounts: + + if "id" in account.keys(): + id = account.get("id") + else: + id = 1 + + for post_endpoint in post_endpoints: + + if "api" in post_endpoint and len(post_endpoint.split("/")) > 0: + if account["api"] in post_endpoint: + endpoint = post_endpoint.replace("{id}", str(account.get("id"))) + prompts = self.test_create(endpoint, account, prompts, post_endpoint=post_endpoint) + + else: + prompts = self.test_create(post_endpoint.get("path"), account, prompts, + post_endpoint=post_endpoint) + else: + prompts = self.test_create(post_endpoint.get("path"), account, prompts, post_endpoint) + + for get_endpoint in endpoints: + if isinstance(get_endpoint, dict): + get_endpoint = get_endpoint.get("path") + + if "api" in get_endpoint and "id" in account.keys(): + if account["api"] in get_endpoint and isinstance(account["id"], int): + + prompts = self.test_get(get_endpoint, account, prompts) + + else: + prompts = self.test_get(get_endpoint, account, prompts) + + else: + if "{id}" in get_endpoint: + + id = 1 + if "OWASP" in self.config.get("name") and "basket" not in get_endpoint: + continue + + prompts = self.test_get(get_endpoint, account, prompts) + + for put_endoint in put_endpoints: + if "api" in put_endoint: + if account["api"] in put_endoint: + prompts = self.test_put(put_endoint, account, prompts) + else: + + prompts = self.test_put(put_endoint, account, prompts) + + for delete_endpoint in delete_endpoints: + id = 2 + if isinstance(delete_endpoint, dict): + delete_endpoint_schema = delete_endpoint.get("schema") + delete_endpoint = delete_endpoint.get("path") + if "api" in delete_endpoint and "id" in delete_endpoint: + prompts = self.test_delete(delete_endpoint, account, prompts) + else: + prompts = self.test_delete(delete_endpoint, account, prompts) + + for get_endpoint in endpoints: + for post_endpoint in post_endpoints: + for put_endpoint in put_endpoints: + prompts = self.test_access_scope(get_endpoint, post_endpoint, put_endpoint, account, prompts) + + return prompts + + def test_injection_bender_account(self, login_path, prompts): + injection_account = None + bender_account = None + if "OWASP" in self.config.get("name"): + injection_account = {"email": "' or 1=1--", "password": "a", "x": self.counter} + self.counter = self.counter + 1 + bender_account = self.config.get("bender") + bender_account["x"] = self.counter + 1 + if injection_account is not None: + prompts.append({ + "objective": "SQL Injection Attack on Login (Administrator Account)", + "steps": [ + f"Send a POST request to {login_path} with user:{injection_account}.\n" + ], + "token": [""], + "path": [login_path], + "expected_response_code": [ + "400", "404", "403" + ], + "security": [ + "Ensure that input fields like the login form are properly sanitized to prevent SQL injection attacks.", + "Implement prepared statements or parameterized queries to eliminate the possibility of SQL injection exploitation.", + "Verify that authentication systems prevent unauthorized login by enforcing strict checks on input data." 
+ ] + } + + ) + if bender_account is not None: + prompts.append({ + "objective": "SQL Injection Attack on Login (Bender Account)", + "steps": [ + f"Send a POST request to {login_path} with user:{bender_account}.\n" + ], + "token": [""], + "path": [login_path], + "expected_response_code": [ + "400", "404", "403" + ], + "security": [ + "Ensure that input fields like the login form are properly sanitized to prevent SQL injection attacks.", + "Implement prepared statements or parameterized queries to eliminate the possibility of SQL injection exploitation.", + "Verify that authentication systems prevent unauthorized login by enforcing strict checks on input data." + ] + } + ) + return prompts + + def test_brute_force(self, login_path, login_schema, prompts, number=100): + if number == 10: + + str = [ + f"Attempt Login with Invalid Credentials - Send repeated POST requests to the login endpoint {login_path} using invalid credentials (e.g., incorrect password) {self.get_credentials(login_schema, login_path)}. " + + f"Simulate a high rate of attempts (e.g., 10 requests within 1 minute) to trigger potential brute force protection mechanisms. "] * 10 + + prompts.append( # Brute Force Protection + { + "objective": "Brute Force Protection", + "steps": str, + "expected_response_code": + [[ + "429 Too Many Requests to indicate effective rate-limiting measures against brute force attacks.", + "403 Forbidden as an alternative response indicating account lockout or other security measures preventing further attempts." + ]] * 10, + "token": [""] * 10, + "path": [login_path] * 10, + "security": + ["Ensure that user cannot login with invalid credentials.\n" + + "Ensure that rate-limiting mechanisms are robust and properly configured to prevent brute force attacks by limiting the number of allowed failed attempts within a given time frame. This prevents attackers from trying a large number of combinations in a short period.\n" + + "Check that account lockout mechanisms or other access denial policies are effective in disabling further login attempts after a certain number of failures, protecting against continuous brute force attempts and securing user accounts from unauthorized access."] + + } + ) + else: + if len(self.brute_force_accounts) > 0: + for acc_b in self.brute_force_accounts: + acc_b["x"] = 100 + str = [ + f"Attempt Login with different Credentials - Send repeated POST requests to the login endpoint {login_path} using invalid credentials (e.g., incorrect password) user:{acc_b}.\n" + + f"Simulate a high rate of attempts (e.g., 10 requests within 1 minute) to trigger potential brute force protection mechanisms. "] * 10 + + prompts.append( # Brute Force Protection + { + "objective": "Brute Force Attack", + "steps": str, + "expected_response_code": + [ + "429 Too Many Requests to indicate effective rate-limiting measures against brute force attacks.", + "403 Forbidden as an alternative response indicating account lockout or other security measures preventing further attempts." + ] * 10, + "token": [""] * 10, + "path": [login_path] * 10, + "security": + ["Ensure that user cannot login with invalid credentials.\n" + + "Ensure that rate-limiting mechanisms are robust and properly configured to prevent brute force attacks by limiting the number of allowed failed attempts within a given time frame. 
This prevents attackers from trying a large number of combinations in a short period.\n" + + "Check that account lockout mechanisms or other access denial policies are effective in disabling further login attempts after a certain number of failures, protecting against continuous brute force attempts and securing user accounts from unauthorized access."] + + } + ) + + change_email = self.extract_change_email_endpoints_with_schema() + for endpoint, schema in change_email: + for account in self.accounts: + if isinstance(endpoint, dict): + schema = endpoint.get("schema") + schema = self.fill_schema(schema) + else: + schema = self.fill_schema(schema) + prompts.append( # Insecure Email Change Process + { + "objective": "Insecure Email Change Process", + "steps": [f"Send a POST request to endpoint {endpoint} with schema {schema}"], + "expected_response_code": + [ + "403 Forbidden if proper authorization checks prevent changing another user's email." + ], + "token": [account.get("token")], + "path": [endpoint], + "security": + [ + "Ensure that users cannot change the email address of other users without proper authorization.\n" + + "Verify that sensitive actions like email changes send verification tokens only to the current, verified email address of the account owner.\n" + + "Check for IDOR vulnerabilities and enforce proper ownership validation when modifying sensitive user information."] + } + ) + + otp_endpoints = [ep for ep in self.endpoints if "otp"] + reset_password = [ep for ep in self.endpoints if + "reset" or "password" or "forget-password" or "forgot" in ep] + otps = list(range(0000, 9999)) + random_otps = [8859] + random.sample(otps, 10) + otp_calls = "" + counter = 0 + otps = [] + for (otp, otp_endpoint) in zip(random_otps, otp_endpoints): + otp_request = {"email": "adam007@example.com", "otp": f'{otp}', "password": "Az123###"} + + otp_calls = otp_calls + f"Send a POST request to endpoint {otp_endpoint} with the following body {otp_request}.\n" + counter += 1 + otps.append(otp) + + for reset in reset_password: + user = {"email": "adam007@example.com"} + # triggers otp being sent + new_otp_calls = f"Send a POST request to endpoint{reset} for victim user:{user}.\n" + otp_calls + prompts.append( # OTP Bruteforce Protection + { + "objective": "OTP Bruteforce Protection", + "steps": [new_otp_calls], + "expected_response_code": + [ + "429 Too Many Requests indicating effective rate-limiting against OTP bruteforce attempts.", + "403 Forbidden indicating account lockout or blocking after multiple failed OTP attempts." + ] * counter, + "token": [""] * counter, + "path": [reset] + otps, + "security": + [ + "Ensure that OTP verification mechanisms enforce rate-limiting to prevent brute-force attacks.\n" + + "Verify that deprecated or older API versions do not bypass security controls.\n" + + "Ensure consistent security protections across all API versions to prevent attackers from exploiting outdated endpoints."] + } + ) + + return prompts + + def extract_change_email_endpoints_with_schema(self) -> list[tuple]: + """ + Extracts all endpoints containing 'change-email' and returns (path, schema) pairs. + + Returns: + list: List of tuples (path, schema) for change-email endpoints. 
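+
+        Example return value (illustrative; the actual paths and schemas depend
+        on the loaded OpenAPI specification):
+
+            [("/api/user/change-email",
+              {"type": "object", "properties": {"email": {"type": "string"}}})]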
+        """
+        result = []
+
+        for path, methods in self.endpoints.items():
+            if "change-email" in path:
+                for method, method_detail in methods.items():
+                    schema = method_detail.get("requestBody", {}).get("content", {}).get("application/json", {}).get(
+                        "schema")
+                    result.append((path, schema))
+
+        return result
+
+    def test_css(self, endpoint, prompts, schema=None):
+        if schema:
+            prompts.append(
+
+                # Cross-Site Scripting Protection (for OAuth)
+                {
+                    "objective": "Cross-Site Scripting Protection (for OAuth)",
+                    "steps": [
+                        f"Alter Redirect URI in OAuth Request - Send an OAuth authentication request to the endpoint {endpoint} based on schema {schema} with a modified redirect URI that differs from the registered URI. The redirect URI in the request should be altered to a different URL (e.g., https://malicious.site/redirect) to test if the API enforces a strict redirect policy. This step tests if the API strictly validates the redirect URI in OAuth flows, preventing open redirection vulnerabilities.",
+                    ],
+                    "token": [""],
+                    "path": [endpoint],
+                    "expected_response_code": [
+                        "400 Bad Request or 403 Forbidden to indicate successful rejection of unauthorized or mismatched redirect URIs in OAuth requests."
+                    ],
+                    "security": [
+                        "Ensure that the OAuth service correctly identifies and rejects requests where the redirect URI does not match the one registered with the client ID. This is crucial to prevent attackers from exploiting open redirection vulnerabilities."
+                        "Verify that the security configuration effectively blocks unauthorized attempts to redirect users to malicious sites through OAuth requests. This protects users from potential cross-site scripting (XSS) attacks and unauthorized data access."]
+                }
+
+            )
+        else:
+            prompts.append(
+
+                # Cross-Site Scripting Protection (for OAuth)
+                {
+                    "objective": "Cross-Site Scripting Protection (for OAuth)",
+                    "steps": [
+                        f"Alter Redirect URI in OAuth Request - Send an OAuth authentication request to the endpoint {endpoint} with a modified redirect URI that differs from the registered URI. The redirect URI in the request should be altered to a different URL (e.g., https://malicious.site/redirect) to test if the API enforces a strict redirect policy. This step tests if the API strictly validates the redirect URI in OAuth flows, preventing open redirection vulnerabilities.",
+                    ],
+                    "expected_response_code": [
+                        "400 Bad Request or 403 Forbidden to indicate successful rejection of unauthorized or mismatched redirect URIs in OAuth requests."
+                    ],
+                    "token": [""],
+                    "path": [endpoint],
+                    "security": [
+                        "Ensure that the OAuth service correctly identifies and rejects requests where the redirect URI does not match the one registered with the client ID. This is crucial to prevent attackers from exploiting open redirection vulnerabilities."
+                        "Verify that the security configuration effectively blocks unauthorized attempts to redirect users to malicious sites through OAuth requests. This protects users from potential cross-site scripting (XSS) attacks and unauthorized data access."]
+                }
+
+            )
+
+        return prompts
+
+    def test_comment(self, account, prompts):
+        comment_endpoint = self.get_comment_ep()
+        if comment_endpoint:
+            prompts = prompts + [
+                {
+                    "objective": "Excessive data exposure",
+                    "steps": [
+                        f"Send a GET request to {comment_endpoint} with a valid token = {account.get('token')}.",
+                    ],
+                    "expected_response_code": [
+                        "200 OK for the initial use of the valid token, confirming it is active.",
+                    ],
+                    "token": [account.get("token")],
+                    "path": [comment_endpoint],
+                    "security": [
+                        "Ensure that there is only one id in the response."]
+                }
+            ]
+        return prompts
+
+    def test_sql_injection(self, account, endpoint, schema, prompts):
+        params = [{"username": "' OR 1=1 #", "password": "password"},
+                  {"username": "'", "password": "password"},
+                  {"username": "admin", "password": "'"}]
+        for param in params:
+            schema = self.fill_schema(schema, param)
+            if schema:
+                if "example" in schema:
+                    schema = schema.get("example")
+                if "{{" in schema:
+                    continue
+                prompts = prompts + [
+                    # Test for Common Injection Vulnerabilities
+                    {
+                        "objective": "Test for Common Injection Vulnerabilities",
+                        "steps": [
+                            f"Send a POST request to exploit common injection vulnerabilities by sending malicious input to {endpoint} based on the schema {schema}."
+                        ],
+                        "token": [account.get("token")],
+                        "path": [endpoint],
+                        "expected_response_code": [
+                            "400 Bad Request, API detects syntactically incorrect inputs due to malformed SQL queries",
+                            "403 Forbidden, API recognizes the input as an unauthorized attempt to access or manipulate data",
+                            "422 Unprocessable Entity (if using WebDAV extensions of HTTP): This is suitable for cases where the request format is correct, but the content includes manipulative commands or data that the server refuses to process."
+                        ],
+                        "security": [
+                            "Ensure the API has measures in place to detect and neutralize injection attacks, securing the application from malicious input manipulation."
+                        ]
+                    }]
+
+        return prompts
+
+    def test_inputs(self, post_endpoint, schema, account, prompts):
+        xss_user = '{"email": "", "password": ""}'
+        example = {'field1': 'value1', 'field2': 123}
+
+        prompts = prompts + [
+
+            # Test Valid Data Input
+            {
+                "objective": "Test Valid Data Input",
+                "steps": [
+                    f"Send a POST request to {post_endpoint} with a payload that matches the valid schema {schema}. This step verifies that the API can correctly process and accept valid data as expected.",
+                ],
+                "expected_response_code": [
+                    "200 OK",
+                    "201 Created indicating that the input data is valid and successfully processed by the API."
+                ],
+                "path": [post_endpoint],
+                "token": [account.get("token")],
+                "security": [
+                    "Ensure data validation mechanisms are robust to prevent malformed or malicious data entries. This involves confirming that all input adheres to the expected schema and triggers appropriate responses, which is critical for maintaining the integrity and security of the application."
+                ]}
+            # This request tests that the API processes valid data successfully and as expected.
+            ,
+
+            # Test Invalid Data Input
+            {
+                "objective": "Test Invalid Data Input",
+                "steps": [
+                    f"Send a POST request to {post_endpoint} with a payload that violates the schema {schema}. An example of an invalid payload might be: {example}, where data types or required fields are incorrect. 
This step tests the API's ability to validate data against the schema and handle errors.", + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + [ + "400 Bad Request indicating that the API correctly identifies invalid data inputs and rejects them, as per the validation rules defined in the schema."] + ], + "security": + [ + "Ensure that the API's input validation mechanisms are effectively safeguarding against malformed, incorrect, or maliciously crafted data. Robust validation is essential for preventing data integrity issues and securing the API from common vulnerabilities such as injection attacks." + ]} + + # This request tests the API’s response to invalid data, ensuring it properly rejects malformed input. + , + + # Test Edge Case Data Input + { + "objective": "Test Valid Edge Case Data Input", + "steps": [ + f"send a POST request to {post_endpoint} with valid edge case values based on the schema {schema}. Examples of valid edge case payloads might include: {{'field1': 'short', 'field2': 1}}, testing the system's handling of minimal valid inputs." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "200 OK", + "201 Created status, confirming that it can gracefully handle edge cases within the expected boundaries." + ], + "security": + [ + "Ensure that the API's handling of valid edge cases adheres to expected data integrity rules and does not trigger any exceptions or errors, maintaining the reliability and security of the system." + ]}, + { + "objective": "Test Invalid Edge Case Data Input", + "steps": [ + f"send a POST request to {post_endpoint} with invalid edge case values that clearly violate the boundaries defined by the schema {schema}. Examples of invalid edge case payloads might include: {{'field1': '', 'field2': -999999999999}}, testing the system's ability to reject inputs that are out of acceptable range or format." + ], + "path": [post_endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "400 Bad Request, confirming that it enforces data constraints and effectively manages inputs that could lead to potential vulnerabilities." + ], + "security": + [ + "Verify that the API's error handling for invalid edge cases prevents data corruption, unhandled exceptions, and security vulnerabilities, such as buffer overflows or other unexpected behaviors that could compromise system stability." + ]} + + # This request tests if the API can handle extreme values or edge cases without errors. + , + + # Test Missing Required Fields + { + "objective": "Test Missing Required Fields", + "steps": [ + f"Send a POST request to {post_endpoint} omitting required fields based on {schema}. Example payload: {{'field1': 'value1'}} (missing 'field2'). This tests the API's ability to enforce required fields and provide feedback on missing data." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, specifying the missing fields, confirming that the API enforces data completeness." + ], + "security": [ + "Verify that the API's error handling for missing required fields prevents data processing issues and maintains integrity." + ] + } + + # This request tests if the API enforces required fields and provides feedback on missing data. 
+ , + + # Test Special Characters and Injection Attacks + { + "objective": "Test Special Characters and Injection Attacks", + "steps": [ + f"Send a POST request to {post_endpoint} with potentially malicious data, such as SQL or XSS payloads in fields based on the schema {schema} . Example payload: {{'field1': '', 'field2': '1 OR 1=1'}}. This tests the API's ability to sanitize inputs to prevent injection attacks.user: {xss_user}.\n" + ], + "path": [post_endpoint], + "token": [account.get("token")], + + "expected_response_code": [ + "200 OK, API sanitizes the input", + "400 Bad Request,API recognizes the input as malicious and rejects the request ", + " 403 Forbidden,API recognizes the input as malicious and rejects the request " + ], + "security": [ + "Check that the API effectively sanitizes and neutralizes special characters and injection vectors to secure data handling." + ] + } + + # This request tests if the API sanitizes inputs to prevent injection attacks. + , + + # Test Incorrect Data Format + { + "objective": "Test Special Characters and Injection Attacks", + "steps": [ + f"Send a POST request to {post_endpoint} with potentially malicious data, such as SQL or XSS payloads in fields based on the schema {schema}. Example payload: {{'field1': '', 'field2': '1 OR 1=1'}}. This tests the API's ability to sanitize inputs to prevent injection attacks." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "200 OK, API sanitied input", + "400 Bad Request, API recognizes input as malicious data.", + "403 Forbidden, API recognizes input as malicious data. " + ], + "security": [ + "Check that the API effectively sanitizes and neutralizes special characters and injection vectors to secure data handling." + ] + } + + # This request tests if the API validates data formats and rejects incorrect data types. + , + + # General Error Handling and Response Code Check + { + "objective": "General Error Handling and Response Code Check", + "steps": [ + f"Send POST requests with a range of invalid inputs to {post_endpoint} to test the API's error handling consistency based on the schema {schema}. Example payloads include missing fields, incorrect types, and special characters. This tests if the API consistently uses appropriate HTTP status codes and provides clear error messages." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request for syntactically incorrect data or missing fields", + "422 Unprocessable Entity for correct format but semantically erroneous data (if applicable)" + ], + "security": [ + "Check that the API's error responses do not expose sensitive information while providing helpful error context." + ] + } + + # This request tests if the API uses consistent HTTP codes and provides clear error messages for invalid inputs. + , + + # Test for Common Injection Vulnerabilities + { + "objective": "Test for Common Injection Vulnerabilities", + "steps": [ + f"Send a POST request to exploit common injection vulnerabilities by sending malicious input to {post_endpoint} based on the schema {schema}. Example payload: {{'field1': \"'; DROP TABLE users; --\"}}. This tests the API’s defenses against SQL injection and other forms of injection." 
+ ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request,API detects syntactically incorrect inputs due to malformed SQL queries", + "403 Forbidden,API recognizes the input as an unauthorized attempt to access or manipulate data", + "422 Unprocessable Entity (if using WebDAV extensions of HTTP): This is suitable for cases where the request format is correct, but the content includes manipulative commands or data that the server refuses to process." + ], + "security": [ + "Ensure the API has measures in place to detect and neutralize injection attacks, securing the application from malicious input manipulation." + ] + } + + # This request tests the API’s resistance to injection attacks, ensuring input handling is secure. + , + + # Validate Data Sanitization and Security + { + "objective": "Validate Data Sanitization and Security", + "steps": [ + f"Send a POST request with data containing special characters and potentially harmful inputs to {post_endpoint} based on the schema {schema}. Confirm that the API sanitizes and escapes all data properly." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "Ensure no vulnerabilities are exposed through unsanitized inputs." + ], + "security": [ + "Verify comprehensive sanitization across all input fields to prevent security risks associated with unescaped data." + ] + } + + # This request checks that the API applies appropriate sanitization to user-supplied data across all fields. + + ] + + return prompts + + def test_error_handling(self, post_endpoint, account, schema, prompts): + prompts = prompts + [ + + # Test Error Handling with Invalid Input + { + "objective": "Test Error Handling with Invalid Input", + "steps": [ + f"Send a POST request to {post_endpoint} with deliberately invalid input to test error handling capabilities based on the schema {schema}. Example payload: {{'nonexistent_field': 'invalid_value'}}. This tests the API's ability to recognize and reject invalid input formats or unexpected fields." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, confirming that the API correctly identifies invalid input and provides a clear, informative error message indicating why the input is invalid." + ], + "security": [ + "Ensure the API's error handling mechanisms do not expose sensitive information while adequately informing the user about the nature of the error. This is crucial for maintaining user trust and preventing potential security issues related to erroneous input handling." + ] + } + + # This request tests if the API correctly handles invalid input with appropriate error messages and status codes. + , + + # Test Error Handling with Missing Required Fields + { + "objective": "Test Error Handling for Missing Required Fields", + "steps": [ + f"Send a POST request to {post_endpoint} omitting one or more required fields to test the API’s validation mechanisms based on the schema {schema}. Example payload: {{'optional_field': 'test_value'}}, where only optional fields are provided, and required fields are deliberately left out." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, confirming that the API correctly identifies the lack of required fields and responds with a detailed error message that lists the missing required fields." 
+ ], + "security": [ + "Ensure that the API's error responses for missing required fields are specific enough to guide the user but do not disclose sensitive system information that could be exploited." + ] + } + + # This request checks if the API enforces required fields and gives helpful feedback on missing data. + , + + # Examine Error Message Detail for Information Leakage + { + "objective": "Examine Error Message Detail for Information Leakage", + "steps": [ + f"Send a POST request to {post_endpoint} with incorrect field types or values to provoke an error response from the API v. Example payload: {{'field': 123}}, where the API is expected to receive a string instead. This step tests the API's handling of data type errors and its error messaging security." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, and scrutinize the error message and headers to ensure they do not disclose sensitive information such as stack traces, server configurations, or detailed server internals." + ], + "security": [ + "Assess the security posture of the API by checking that error responses are safely managed. Error messages and headers should convey enough information for user correction without revealing details that could be used maliciously. Confirm that the error responses maintain privacy and security by not exposing detailed system data." + ] + } + + # This request tests if the API avoids leaking sensitive information through error messages or headers. + , + + # Test Error Response for Security Vulnerabilities + { + "objective": "Test Error Response for Security Vulnerabilities", + "steps": [ + f"Send various malformed requests to {post_endpoint} to assess the security of the API's error handling based on the schema {schema}. Example of a malformed request: {{field': 'value' (note the missing closing brace). This tests if the API avoids exposing internal structures or sensitive information in its responses to such requests.}}" + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, without any specific error details that could assist in crafting further attacks." + ], + "security": [ + "Ensure that the API responds with generic error messages that do not disclose any details about the internal workings or structures of the system. Appropriate responses might include a , " + "Confirm that error messages are standardized and vague enough to prevent revealing any technical details of the backend systems, such as software versions, frameworks, or detailed error diagnostics. This helps in mitigating risks associated with information leakage." + ] + } + + # This request confirms that malformed requests are handled securely and don’t expose implementation details. + , + + # Assess Error Messages for Corrective Action Suggestions + { + "objective": "Assess Error Messages for Corrective Action Suggestions", + "steps": [ + f"Send a POST request to {post_endpoint} with intentionally incorrect inputs such as missing required fields or wrong data types based on the schema {schema}. Example payload could include: {{'field1': 'value1'}} with 'field2' missing or incorrectly formatted. This tests the API's ability to not only catch errors but also guide users towards resolving them." 
+ ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, fails validation checks such as missing required fields or incorrect data types", + "422 Unprocessable Entity, entity is correct but was unable to process the contained instruction" + ], + "security": [ + "Verify that the API's error messages are informative and instructive, providing specific suggestions for correcting the input. Expect responses such as 400 Bad Request that include details on what is missing or incorrect, like specifying missing fields or incorrect data types." + "Ensure that while the error messages are helpful, they do not disclose sensitive information or details about the underlying implementation that could be exploited by an attacker. The messages should strike a balance between being helpful and maintaining security." + ] + } + + # This request tests if the API offers actionable feedback to help users correct their requests. + , + + # Look for Information Leakage Vulnerabilities + { + "objective": "Look for Information Leakage Vulnerabilities", + "steps": [ + f"Send POST requests to {post_endpoint} with unexpected input, such as special characters or SQL-like strings, to test how securely the API handles error-prone or potentially malicious data based on the schema {schema}. Example payloads could include: {{'field': \"' OR '1'='1\"}} or {{'field': ''}}. This tests the API's resilience against attempts to expose sensitive information." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "400 Bad Request, indicating that the input was syntactically incorrect or otherwise unacceptable,", + "500 Internal Server Error, if unexpected input leads to unhandled errors within the server" + ], + "security": [ + "Inspect if the response contains any server details, stack traces, or database error messages that could be exploited. The API should not reveal such information under any circumstance. Responses should be sanitized to prevent leakage of sensitive details about the server or the application's internal mechanics.", + "Ensure that the error handling is secure and does not inadvertently disclose sensitive information through error messages or HTTP headers. Error responses should be generic and non-descriptive to avoid aiding potential attackers." + ] + } + + # This request verifies that the API is secure against accidental exposure of sensitive information in responses. + + ] + + return prompts + + def test_session_management(self, get_endpoint, account, prompts): + prompts = prompts + [ + { + "objective": "Check if API Uses Session Management", + "steps": [ + f"Send a GET request to {get_endpoint} without authentication.", + + ], + "path": [get_endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "200 OK if no authentication is required, and the endpoint provides public data.", + "401 Unauthorized if authentication is required and no session token or cookie is issued." + ], + "security": [ + "Observe the response headers to check if any session token or cookie is issued.", + "Inspect both request and response headers for session-related identifiers." + "Ensure that session tokens are not issued unnecessarily, as this could expose vulnerabilities or allow unintended session management." 
+ ] + }] + return prompts + + def test_session_hijacking(self, login_path, get_endpoint, login_schema, account, prompts): + prompts = prompts + [ + + # This prompt tests if the API uses session management by observing session tokens or cookies in responses after a GET request. + # Test for Session Hijacking Vulnerability + { + "objective": "Test for Session Hijacking Vulnerability", + "steps": [ + f"Send a POST request to {login_path} with valid credentials {self.get_credentials(login_schema, login_path)} to obtain a session token or cookie.", + f"Send a GET request to {get_endpoint} with the session token or cookie and verify if access is granted without re-authentication." + ], + "path": [login_path, get_endpoint], + "expected_response_code": [ + "401 Unauthorized if the session token is invalid or the system detects unauthorized usage.", + "403 Forbidden if the API identifies and blocks an unauthorized attempt to use a session token from a different environment or device." + ], + "token": ["", account.get("token"), account.get("token")], + "security": [ + "Ensure that the API has mechanisms such as IP binding or device fingerprinting to detect and prevent session hijacking attempts." + ] + } + # This prompt tests for session hijacking by attempting to reuse a session token from a different environment. + , + # Check Session Timeout Settings + { + "objective": "Check Session Timeout Settings", + "steps": [ + f"Start a session by sending a POST request to {login_path} based on schema {login_schema} and obtaining a session identifier.", + "Leave the session idle for the defined timeout period.", + f"Send a GET request to {get_endpoint} using the same session token or cookie.", + ], + "token": ["", account.get("token"), account.get("token")], + "path": [login_path, get_endpoint, get_endpoint], + "expected_response_code": [ + "401 Unauthorized if the session has expired and the token is rejected due to inactivity.", + "403 Forbidden if the API enforces access denial due to an expired session." + ], + "security": [ + "Verify if the session has expired and access is rejected." + "Ensure that session timeout settings are enforced to minimize the risk of unauthorized access due to prolonged inactivity." + ] + } + + # This prompt tests session timeout by verifying if a session expires after inactivity. + , + + ] + + return prompts + + def test_cookies(self, login_path, login_schema, prompts, account=None): + if account: + + prompts.append( # Check for SameSite Attribute on Cookies + { + "objective": "Check for SameSite Attribute on Cookies", + "steps": [ + f"Send a POST request to {login_path} based on schema {login_schema} with account {account} to authenticate and obtain a session cookie.", + + ], + "path": [login_path], + "token": [""], + "expected_response_code": [ + "200 OK if the session cookie is successfully issued with appropriate attributes.", + ], + "security": [ + "Inspect the cookie attributes in the response headers to confirm that the SameSite attribute is set.", + "Verify if the SameSite attribute is set to Strict or Lax, ensuring restricted cookie transmission across sites.", + "Ensure that the SameSite attribute is configured to mitigate CSRF attacks by limiting cookie transmission.", + "Verify that cookies with SameSite set to None also include the Secure attribute to prevent transmission over insecure connections." + ] + } + + # This prompt checks if the API enforces CSRF protection by setting the SameSite attribute on cookies. 
+ ) + else: + prompts = prompts + [ # Test for Secure Session Cookie Settings + { + "objective": "Test for Secure Session Cookie Settings", + "steps": [ + f"Send a POST request to {login_path} to obtain a session cookie based on schema {self.get_credentials(login_schema, login_path)} .", + ], + "token": [""], + "path": [login_path], + + "expected_response_code": [ + "200 OK if the session cookie includes secure attributes such as HttpOnly and Secure." + ], + "security": [ + "Inspect the session token or cookie in the response headers to verify the inclusion of secure attributes such as HttpOnly and Secure." + "Ensure that session cookies are configured with secure attributes to prevent unauthorized access or transmission of sensitive session information." + ] + }] + + # This prompt tests for secure session cookie attributes, such as HttpOnly and Secure, to protect against client-side access and enforce secure transmission. + return prompts + + def test_sessions_vulnerabilitiy(self, login_path, login_schema, account, prompts): + endpoints = self.get_correct_endpoints_for_method("protected_endpoint", + "POST") + self.get_correct_endpoints_for_method( + "public_endpoint", "POST") + + for post_endpoint, _, _ in endpoints: + prompts = prompts + [ + + # Examine Session Handling Mechanism for Session Fixation Vulnerability + { + "objective": "Examine Session Handling Mechanism for Session Fixation Vulnerability", + "steps": [ + f"Authenticate and obtain a session identifier (token or cookie) by sending a POST request to {login_path} based on schema {self.get_credentials(login_schema, login_path)} .", + f"Reuse the same session identifier across different sessions or user accounts by sending additional POST requests to {post_endpoint} with this session ID.", + ], + "token": ["", account.get("token")], + "path": [login_path, post_endpoint], + + "expected_response_code": [ + "401 Unauthorized if the reused session identifier is invalid or has been appropriately invalidated.", + "403 Forbidden if the API detects misuse or an unauthorized attempt to reuse the session." + ], + "security": [ + "Check if the session ID remains valid across multiple sessions or accounts." + "Verify that the API generates new session identifiers upon login or account change to prevent session fixation attacks." + ] + } + + # This prompt tests for session fixation by reusing a session identifier across multiple login attempts. + , + + ] + + return prompts + + def test_xss(self, post_endpoint, account, schema, prompts): + prompts = prompts + [ + # Test for XSS by Injecting Script into Text Fields + { + "objective": "Test for XSS by Injecting Script into Text Fields", + "steps": [ + f"Send a POST request to {post_endpoint} where user input is accepted (e.g., comment or message fields) based on schema {schema}." + "Include a payload with a script tag, such as , in the request body to test if the API escapes or sanitizes it correctly." + ], + "path": [post_endpoint], + "token": [account.get("token")], + + "expected_response_code": [ + "200 OK if the input is safely sanitized and stored.", + "400 Bad Request if the payload is outright rejected due to validation rules." + ], + "security": [ + "Verify that the API escapes or sanitizes all user inputs to prevent reflected XSS attacks.", + "Ensure that the response does not include any unescaped or executable scripts." + ] + } + + # This prompt tests if the API properly sanitizes user inputs in text fields to prevent reflected XSS attacks. 
+ , # Test for Persistent XSS + { + "objective": "Test for Persistent XSS", + "steps": [ + f"Send a POST request to {post_endpoint} to store user-provided data (e.g., profile descriptions or comments) based on schema {schema}." + "Inject a persistent XSS payload, like , in the request body." + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "200 OK if the input is safely sanitized and stored without executing any scripts.", + "400 Bad Request if the input is rejected due to validation rules." + ], + "security": [ + "Ensure that the API sanitizes or escapes all stored user inputs to prevent stored XSS attacks.", + "Verify that scripts in stored content are not executed when retrieved." + ] + } + + # This prompt tests if the API prevents stored XSS by sanitizing inputs that might be displayed to other users. + , + + ] + return prompts + + def test_xss_query(self, get_endpoint, account, prompts): + prompts = prompts + [ + # Check for XSS in Query Parameters + { + "objective": "Check for XSS in Query Parameters", + "steps": [ + f"Send a GET request to {get_endpoint} with a query parameter containing a script payload, such as ?search=.", + ], + "token": [account.get("token")], + "path": [get_endpoint], + "expected_response_code": [ + "200 OK if the payload is properly sanitized and reflected without execution.", + "400 Bad Request if the API blocks the input due to validation rules." + ], + "security": [ + "Ensure that all query parameters are sanitized or escaped to prevent reflected XSS attacks.", + "Verify that the response does not contain unescaped scripts or unsafe data." + ] + } + + # This prompt tests if the API properly handles and escapes query parameters to prevent reflected XSS. + , + # Test for XSS in JSON Responses + { + "objective": "Test for XSS in JSON Responses", + "steps": [ + f"Send a GET request to {get_endpoint} expecting a JSON response." + "Inject a payload in a query parameter (e.g., ?param=) and observe if the JSON response reflects this payload without escaping.", + ], + "token": [account.get("token")], + "path": [get_endpoint], + "expected_response_code": [ + "200 OK if the JSON response properly escapes or removes the malicious payload.", + "400 Bad Request if the input is rejected due to validation rules." + ], + "security": [ + "Verify that JSON responses properly escape all user inputs to prevent XSS vulnerabilities.", + "Ensure that malicious scripts in parameters are not reflected in the response." + ] + } + + # This prompt tests if JSON responses are properly escaped, ensuring no XSS vulnerabilities when displaying data to users. + , + # Inspect Headers and Metadata for XSS Protections + { + "objective": "Inspect Headers and Metadata for XSS Protections", + "steps": [ + f"Send a GET request to {get_endpoint}.", + ], + "path": [get_endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "200 OK with headers such as Content-Security-Policy and X-XSS-Protection properly set." + ], + "security": [ + "Inspect the response headers for security settings such as Content-Security-Policy, X-XSS-Protection, and others to verify if XSS protections are enabled." + "Ensure that the API response headers contain adequate security settings to mitigate XSS risks.", + "Verify that the Content-Security-Policy header is configured to prevent inline scripts or unauthorized domains from executing scripts." 
+ ] + } + + # This prompt checks if appropriate security headers are set to protect against XSS vulnerabilities. + ] + + return prompts + + def test_csrf(self, endpoint, schema, prompts, method="POST"): + if method == "POST": + prompts = prompts + [ + # Test for CSRF Protection in Sensitive Actions + { + "objective": "Test for CSRF Protection in Sensitive Actions", + "steps": [ + f"Send a POST request to {endpoint} based on schema {schema}(e.g., account update or password change) without including any authentication tokens (e.g., CSRF token or session cookie).", + ], + "token": [""], + "path": [endpoint], + "expected_response_code": [ + "403 Forbidden if CSRF protection is correctly enforced and the request lacks proper tokens.", + "401 Unauthorized if authentication is required and missing." + ], + "security": [ + "Observe the API's response to determine if the request is rejected due to missing tokens." + "Ensure that the API requires authentication tokens for sensitive actions to mitigate CSRF attacks.", + "Verify that the absence of a CSRF token results in a proper rejection of the request." + ] + } + + # This prompt tests if the API protects sensitive actions by requiring proper authentication tokens. + , # Test for CSRF Token Requirement + { + "objective": "Test for CSRF Token Requirement", + "steps": [ + f"Send a POST request to {endpoint} based on schema {schema} with valid session cookies but without including a CSRF token.", + ], + "token": [""], + "path": [endpoint], + "expected_response_code": [ + "403 Forbidden if the API correctly requires a CSRF token for the action.", + "400 Bad Request if the request format is invalid or missing required headers." + ], + "security": [ + "Check if the API rejects the request, indicating that CSRF tokens are mandatory for protected actions." + "Confirm that CSRF tokens are required for all sensitive actions and that their absence results in the rejection of requests.", + "Ensure that proper error messages are returned without exposing sensitive system details." + ] + } + + # This prompt verifies if the API requires a CSRF token, ensuring protection against CSRF attacks. + , # Test for CSRF Token Validation + { + "objective": "Test for CSRF Token Validation", + "steps": [ + f"Send a POST request to {endpoint} based on schema {schema} with valid session cookies but an incorrect or expired CSRF token.", + ], + "path": [endpoint], + "token": [self.create_random_bearer_token()], + "expected_response_code": ["403 Forbidden if the CSRF token is invalid or expired.", + "401 Unauthorized if the session cookies are also invalid or missing."], + "security": [ + "Verify if the API rejects the request, indicating that the provided CSRF token is invalid.", + "Ensure that the API validates the CSRF token correctly, rejecting requests with invalid or expired tokens.", + "Verify that error messages do not reveal sensitive information about the token validation process."]} + + # This prompt tests if the API validates the CSRF token, ensuring only valid tokens are accepted. 
+ , ] + elif method == "GET": + prompts = prompts + [ + # Test CSRF Protection in GET Requests + + { + "objective": "Test CSRF Protection in GET Requests", + "steps": [ + f"Send a GET request to {endpoint} to retrieve sensitive data (e.g., personal information) without including a CSRF token or session cookie.", + ], + "token": [""], + "path": [endpoint], + "expected_response_code": [ + "401 Unauthorized if authentication is required but missing.", + "403 Forbidden if the API enforces CSRF protection for sensitive GET requests." + ], + "security": [ + "Observe the response to check if the API enforces CSRF protection or authentication for GET requests with sensitive data.", + "Confirm that the API applies CSRF protection or requires authentication for GET requests that handle sensitive data.", + "Ensure that unauthorized access attempts do not result in any sensitive data exposure." + ] + } + ] + + return prompts + + def test_buisness_logic(self, endpoint, schema, account, prompts, method="POST"): + if method == "POST": + prompts.append( + # Test for Unauthorized Access Through Business Logic + { + "objective": "Test for Unauthorized Access Through Business Logic", + "steps": [ + f"Send a POST request to {endpoint} based on schema {schema} with manipulated data, such as setting an unauthorized discount or modifying the item price with user:{account}.\n", + ], + "path": [endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "403 Forbidden if the API correctly identifies and prevents unauthorized actions.", + "400 Bad Request if the input data violates business rules or constraints." + ], + "security": [ + "Verify if the API prevents the manipulated changes from being applied and restricts unauthorized actions.", + "Ensure the API enforces proper validation and checks on critical business logic operations to prevent unauthorized manipulations.", + "Verify that manipulated requests do not bypass security measures or result in unauthorized actions." + ] + } + + # This prompt tests for flaws in the business logic that could allow unauthorized discounts or price manipulations. + ) + elif method == "GET": + prompts.append( + # Test for Data Exposure via Business Logic Flaws + { + "objective": "Test for Data Exposure via Business Logic Flaws", + "steps": [ + f"Send a GET request to {endpoint} as a low-privilege user:{account}.\n", + + ], + "path": [endpoint], + "token": [account.get("token")], + "expected_response_code": [ + "403 Forbidden if the user lacks permission to access the requested data.", + "200 OK only if the user is authorized to access the data." + ], + "security": [ + "Attempt to retrieve sensitive information that should not be accessible to the user.", + "Verify if the API enforces permission checks to restrict access to data beyond the user's allowed level." + "Ensure the API applies strict access controls to sensitive data based on user permissions.", + "Verify that unauthorized access attempts are logged and handled securely." + ] + } + + # This prompt tests if the business logic prevents unauthorized access to sensitive user data. 
+ ) + + return prompts + + def test_roles(self, post_endpoint, schema, account, prompts): + prompts.append( + # Test for Manipulation of User Roles + { + "objective": "Test for Manipulation of User Roles", + "steps": [ + f"Send a POST request to {post_endpoint} based on schema {schema} attempting to elevate a standard user's role to an admin.", + ], + "token": [account.get("token")], + "path": [post_endpoint], + "expected_response_code": [ + "403 Forbidden if the API correctly identifies and prevents role escalation attempts.", + "400 Bad Request if the request format or parameters are invalid." + ], + "security": [ + "Verify if the API prevents unauthorized role manipulations.", + "Ensure that the API enforces strict role-based access controls and prevents unauthorized role changes.", + "Verify that unauthorized role manipulation attempts are logged and do not expose any vulnerabilities." + ] + } + + # This prompt tests if the API prevents unauthorized role escalation. + ) + return prompts + + def test_limiting(self, request, endpoint, prompts, key=None): + if key: + + prompts.append( + + # Check for Lack of Resources & Rate Limiting + { + "objective": "Check for Lack of Resources & Rate Limiting", + "steps": [ + f"Send a GET request to {endpoint}.", + + ], + "path": [endpoint], + "token": [key], + "expected_response_code": [ + "429 Too many requests" + ], + "security": [ + "Ensure each step is followed without skipping to maintain the integrity of the test.", + "Check for HTTP 429 response to confirm if rate limiting is enforced after numerous sequential OTP verification attempts.", + "Assess server response under load for stability and consistency to prevent denial of service.", + "Review the application logs to ensure that all brute force attempts are recorded for security monitoring.", + "Implement rate limiting on sensitive actions like OTP verification to mitigate brute-force attacks." + ] + } + + # This prompt tests if security headers are properly configured to prevent exposure to security risks. + ) + else: + prompts = prompts + [ + # Test Rate-Limiting for Abuse Prevention + { + "objective": "Test Rate-Limiting for Abuse Prevention", + "steps": + request + + , + "expected_response_code": [ + "429 Too Many Requests if the API correctly enforces rate-limiting.", + "200 OK for requests that fall within the allowed rate limit." + ], + "path": [endpoint] * 10, + "token": [""] * 10, + "security": [ + "Monitor the responses to verify if the API enforces a rate limit after a specific number of requests.", + "Check if the API responds with a 429 Too Many Requests status when the rate limit is exceeded." + "Ensure that the rate-limiting mechanism effectively prevents abuse while allowing legitimate users to access the API without disruption.", + "Verify that rate-limiting thresholds and error messages do not expose sensitive implementation details." + ] + } + + # This prompt tests if the API has adequate rate-limiting to prevent abuse. 
+ ] + return prompts + + def test_throttling(self, prompts): + resource_intensive_endpoints = self.get_correct_endpoints_for_method("resource_intensive_endpoint", "GET") + for resource_intensive_endpoint, _, _ in resource_intensive_endpoints: + if "id}" in resource_intensive_endpoint: + resource_intensive_endpoint = self.replace_placeholders_with_1(resource_intensive_endpoint, 1) + one = [ + f"Send a series of GET requests per second to the endpoint {resource_intensive_endpoint} for 30 seconds and verify that all requests return a 200 OK response."] * 10 + hundert = [ + f"Send a burst of 100 GET requests per second to the endpoint {resource_intensive_endpoint} for 30 seconds and observe the API's response."] * 20 + ten_status = ["200 OK for all legitimate requests under normal traffic conditions.", ] * 10 + hundert_status = [ + "429 Too Many Requests if throttling is applied when excessive requests are detected."] * 20 + + paths = [resource_intensive_endpoint] * 30 + tokens = [""] * 30 + + prompts.append( + + # Test Throttling for Denial-of-Service (DoS) Attack Prevention + + { + "objective": "Test Throttling for Denial-of-Service (DoS) Attack Prevention", + "steps": [ + one, + hundert, + f"Analyze the API's response during the excessive traffic burst. Check for status codes such as 429 Too Many Requests or 503 Service Unavailable, indicating that throttling is applied.", + f"4: Reduce the request rate back to 10 GET requests per second to {resource_intensive_endpoint} and verify that legitimate traffic returns a 200 OK response, confirming the system has recovered." + ], + "expected_response_code": [ten_status, + hundert_status, + "503 Service Unavailable if the server temporarily denies access due to heavy load.", + "4: 200 OK for legitimate requests after the request rate returns to normal." + ], + "path": paths, + "token": tokens, + "security": [ + "Ensure that the API's throttling mechanism effectively prevents DoS attacks by limiting excessive traffic.", + "Verify that throttling does not unintentionally block or degrade service for legitimate users.", + "Confirm that the API recovers quickly and reliably after excessive traffic subsides, maintaining availability for normal traffic." + ] + } + + # This prompt tests if the API prevents DoS attacks through request throttling. 
+ ) + + return prompts + + def generate_user(self, post_account, counter, prompts): + for account in post_account: + account_path = account.get("path") + account_schema = account.get("schema") + if self.config.get("name") == "crapi": + account_user = self.create_account(login_schema=account_schema, login_path=account_path) + + + else: + account_user = self.get_credentials(account_schema, account_path, new_user=True).get("example") + if account_user is None: + continue + account_user["x"] = counter + if "api" in account_path: + parts = [api for api in account_path.split("/") if api.strip()] + api = [part for part in parts if part.startswith("api")] + api = api[0] + account_user["api"] = api + if self.config.get("name") == "vAPI": + text = f'{account_user.get("username")}:{account_user.get("password")}' + + account_user["token"] = base64.b64encode(text.encode()).decode() + + prompt = { + "objective": "Setup tests", + "steps": [ + f"Create an account by sending a POST HTTP request to the correct endpoint from this {account_path} with these credentials of user:{account_user}.\n" + f"Request body should be in application/json and look similar to this: {{ {account_user}}}"], + "expected_response_code": ["200 OK", "201 Created"], + "token": [""], + "path": [account_path], + "security": [ + "Ensure that the correct expected code is received confirming that access is restricted and no sensitive data is included in the response body."] + } + + self.accounts.append(account_user) + prompts = prompts + [prompt] + counter = counter + 1 + + return prompts, counter + + def replace_ids(self, account, endpoint, given_id=None): + + if given_id is None: + id = account.get("id", 1) + else: + id = given_id + other_id = self.get_other_id(id, account) + + new_endpoint = endpoint.replace("{id}", str(id)) + endpoint_of_other_user = endpoint.replace("{id}", str(other_id)) + + # Handle {id} + if "{id}" in endpoint: + if "example" in account and "id" in account["example"]: + id = account["example"]["id"] + other_id = id - 1 if account == self.accounts[-1] else id + 1 + new_endpoint = endpoint.replace("{id}", str(id)) + endpoint_of_other_user = endpoint.replace("{id}", str(other_id)) + else: + + new_endpoint = endpoint.replace("{id}", str(id)) + endpoint_of_other_user = endpoint.replace("{id}", str(other_id)) + # Handle _id mostly for resources + elif "_id}" in endpoint: + key_found, key = self.key_in_path(endpoint, self.resources) + if key_found == True and key is not None: + key = str(key) + first_id = self.resources[key][0] + if len(self.resources[key]) > 1: + second_id = random.choice(self.resources[key][1:]) + else: + second_id = 1 # fallback to same id if no other id available + new_endpoint = endpoint.replace("{", "").replace("}", "").replace(key, first_id) + endpoint_of_other_user = endpoint.replace("{", "").replace("}", "").replace(key, second_id) + else: + other_id = self.get_other_id(id, account) + new_endpoint = self.replace_id_placeholder(endpoint, str(id)) + endpoint_of_other_user = self.replace_id_placeholder(endpoint, str(other_id)) + + if given_id is not None: + other_id = self.get_other_id(id, account) + new_endpoint = self.replace_id_placeholder(endpoint, str(given_id)) + endpoint_of_other_user = self.replace_id_placeholder(endpoint, str(other_id)) + + return new_endpoint, endpoint_of_other_user + + def get_other_id(self, id, account): + if str(id).isdigit(): + + other_id = id - 1 if account == self.accounts[-1] else id + 1 + else: + current_index = self.accounts.index(account) + + # Pick next 
account if not last, else pick previous + other_account = self.accounts[current_index + 1] if current_index < len(self.accounts) - 1 else \ + self.accounts[current_index - 1] + + other_id = other_account.get("id", 1) + if other_id is None: + other_id = 2 + + + return other_id + + def get_file(self, param): + # Get current file directory + current_dir = os.path.dirname(__file__) + + # Go up one level + parent_dir = os.path.abspath(os.path.join(current_dir, "..")) + parent_dir = parent_dir.split("/src")[0] + + # Search for file (glob is recursive-friendly) + file = glob.glob(os.path.join(parent_dir, param), recursive=True) + + if not file or param == "": + return "Not found" + return file + + def get_path_and_schema(self, login): + login_path = login.get("path") + login_schema = login.get("schema") + if "example" not in login_schema: + login_schema = self.adjust_schema_with_examples(login_schema) + login_schema = login_schema.get("example") + + return login_path, login_schema diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py b/src/hackingBuddyGPT/utils/prompt_generation/information/prompt_information.py similarity index 74% rename from src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py rename to src/hackingBuddyGPT/utils/prompt_generation/information/prompt_information.py index 17e7a140..694d7a1f 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/information/prompt_information.py +++ b/src/hackingBuddyGPT/utils/prompt_generation/information/prompt_information.py @@ -49,21 +49,25 @@ class PromptPurpose(Enum): """ # Documentation related purposes + VERIY_SETUP = 17 + SETUP = 16 + SPECIAL_AUTHENTICATION = 0 DOCUMENTATION = 1 # Security related purposes - AUTHENTICATION_AUTHORIZATION = 2 - INPUT_VALIDATION = 3 - ERROR_HANDLING_INFORMATION_LEAKAGE = 4 - SESSION_MANAGEMENT = 5 - CROSS_SITE_SCRIPTING = 6 - CROSS_SITE_FORGERY = 7 - BUSINESS_LOGIC_VULNERABILITIES = 8 - RATE_LIMITING_THROTTLING = 9 - SECURITY_MISCONFIGURATIONS = 10 - LOGGING_MONITORING = 11 + AUTHENTICATION = 2 + AUTHORIZATION = 3 + INPUT_VALIDATION = 4 + ERROR_HANDLING_INFORMATION_LEAKAGE = 5 + SESSION_MANAGEMENT = 6 + CROSS_SITE_SCRIPTING = 7 + CROSS_SITE_FORGERY = 8 + BUSINESS_LOGIC_VULNERABILITIES = 9 + RATE_LIMITING_THROTTLING = 10 + SECURITY_MISCONFIGURATIONS = 11 + LOGGING_MONITORING = 12 # Analysis - PARSING = 12 - ANALYSIS = 13 - REPORTING = 14 + PARSING = 13 + ANALYSIS = 14 + REPORTING = 15 diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompt_engineer.py b/src/hackingBuddyGPT/utils/prompt_generation/prompt_engineer.py new file mode 100644 index 00000000..57f7aaa5 --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompt_engineer.py @@ -0,0 +1,128 @@ +from typing import Any + +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PromptContext, + PromptStrategy, ) +from hackingBuddyGPT.utils.prompt_generation.prompt_generation_helper import ( + PromptGenerationHelper, +) +from hackingBuddyGPT.utils.prompt_generation.prompts.state_learning import ( + InContextLearningPrompt, +) +from hackingBuddyGPT.utils.prompt_generation.prompts.task_planning import ( + ChainOfThoughtPrompt, + TreeOfThoughtPrompt, +) + + +class PromptEngineer: + """ + A class responsible for engineering prompts for web API testing based on different strategies. + + Attributes: + _context (PromptContext): Context of the current prompt generation. 
+ turn (int): Interaction counter. + _prompt_helper (PromptGenerationHelper): Helper for managing prompt-related data and logic. + _prompt_func (callable): Strategy-specific prompt generation function. + _purpose (PromptPurpose): Current purpose of the prompt strategy. + """ + + def __init__( + self, + strategy: PromptStrategy = None, + context: PromptContext = None, + open_api_spec: dict = None, + prompt_helper: PromptGenerationHelper = None, + rest_api_info: tuple = None, + prompt_file : Any = None + ): + + """ + Initialize the PromptEngineer with the given strategy, context, and configuration. + + Args: + strategy (PromptStrategy): Strategy for prompt generation. + context (PromptContext): Context for prompt generation. + open_api_spec (dict): OpenAPI specifications for the API. + prompt_helper (PromptGenerationHelper): Utility class for prompt generation. + rest_api_info (tuple): Contains token, host, correct endpoints, and categorized endpoints. + """ + + token, host, correct_endpoints, categorized_endpoints = rest_api_info + self.host = host + self._token = token + self.prompt_helper = prompt_helper + self.prompt_helper.current_test_step = None + self.turn = 0 + self._context = context + + strategies = { + PromptStrategy.CHAIN_OF_THOUGHT: ChainOfThoughtPrompt( + context=context, prompt_helper=self.prompt_helper, prompt_file = prompt_file + ), + PromptStrategy.TREE_OF_THOUGHT: TreeOfThoughtPrompt( + context=context, prompt_helper=self.prompt_helper, prompt_file = prompt_file + ), + PromptStrategy.IN_CONTEXT: InContextLearningPrompt( + context=context, + prompt_helper=self.prompt_helper, + context_information={self.turn: {"content": "initial_prompt"}}, + open_api_spec=open_api_spec, + prompt_file=prompt_file + ), + } + + self._prompt_func = strategies.get(strategy) + if self._prompt_func.strategy == PromptStrategy.IN_CONTEXT: + self._prompt_func.open_api_spec = open_api_spec + + def generate_prompt(self, turn: int, move_type="explore", prompt_history=None, hint=""): + """ + Generates a prompt for a given turn and move type, then processes the response. + + Args: + turn (int): The current interaction number in the sequence. + move_type (str, optional): The type of interaction, defaults to "explore". + log (logging.Logger, optional): Logger for debug information, defaults to None. + prompt_history (list, optional): History of prompts for tracking, defaults to None. + llm_handler (object, optional): Language model handler if different from initialized, defaults to None. + hint (str, optional): Optional hint to influence prompt generation, defaults to empty string. + + Returns: + list: Updated prompt history with the new prompt and response included. + + Raises: + ValueError: If an invalid prompt strategy is specified. 
+ """ + + if prompt_history is None: + prompt_history = [] + if not self._prompt_func: + raise ValueError("Invalid prompt strategy") + + self.turn = turn + if self.host.__contains__("coincap"): + hint = "Try as id or other_resoure cryptocurrency names like bitcoin.\n" + prompt = self._prompt_func.generate_prompt( + move_type=move_type, hint=hint, previous_prompt=prompt_history, turn=0 + ) + self._purpose = self._prompt_func.purpose + + if self._context == PromptContext.PENTESTING: + self.prompt_helper.current_test_step = self._prompt_func.current_step + self.prompt_helper.current_sub_step = self._prompt_func.current_sub_step + + prompt_history.append({"role": "system", "content": prompt}) + self.turn += 1 + return prompt_history + + def set_pentesting_information(self, pentesting_information): + """ + Sets pentesting-specific information to adjust the prompt generation accordingly. + + Args: + pentesting_information (dict): Information specific to penetration testing scenarios. + """ + self.pentesting_information = pentesting_information + self._prompt_func.set_pentesting_information(pentesting_information) + self._purpose = self.pentesting_information.pentesting_step_list[0] diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompt_generation_helper.py b/src/hackingBuddyGPT/utils/prompt_generation/prompt_generation_helper.py new file mode 100644 index 00000000..044cdc7e --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompt_generation_helper.py @@ -0,0 +1,584 @@ +import json +import random +import re +import uuid + + +class PromptGenerationHelper(object): + """ + Assists in generating prompts for web API testing by managing endpoint data, + tracking interactions, and providing utilities for analyzing and responding to API behavior. + + Attributes: + found_endpoints (list): Endpoints that have been successfully interacted with. + tried_endpoints (list): Endpoints that have been tested, regardless of the outcome. + unsuccessful_paths (list): Endpoints that failed during testing. + current_step (int): Current step in the testing or documentation process. + document_steps (int): Total number of documentation steps processed. + endpoint_methods (dict): Maps endpoints to the HTTP methods successfully used with them. + unsuccessful_methods (dict): Maps endpoints to the HTTP methods that failed. + endpoint_found_methods (dict): Maps HTTP methods to the endpoints where they were found successful. + schemas (list): Definitions of data schemas used for constructing requests and validating responses. + """ + + def __init__(self, host, description): + """ + Initializes the PromptGenerationHelper with an optional host and description. 
+ """ + self.counter = 0 + self.uuid =uuid.uuid4() + self.bad_request_endpoints = [] + self.endpoint_examples = {} + self.name = "" + if "coin" in host.lower(): + self.name = "Coin" + if "reqres" in host.lower(): + self.name = "reqres" + + self.current_sub_step = None + self.saved_endpoints = [] + self.tried_endpoints_with_params = {} + self.host = host + self._description= description + self.current_test_step = None + self.current_category = "root_level" + self.correct_endpoint_but_some_error = {} + self.endpoints_to_try = [] + self.hint_for_next_round = "" + self.schemas = [] + self.endpoints = [] + self.tried_endpoints = [] + self.found_endpoints = [] + self.query_endpoints_params = {} + self.found_query_endpoints = [] + self.endpoint_methods = {} + self.unsuccessful_methods = {} + self.endpoint_found_methods = {} + self.unsuccessful_paths = ["/"] + self.current_step = 1 + self.document_steps = 0 + self.tried_methods_by_enpoint = {} + self.accounts = [] + self.possible_instance_level_endpoints = [] + + self.current_user = None + + + def get_user_from_prompt(self,step, accounts) -> dict: + """ + Extracts the user information after 'user:' from the given prompts. + + Args: + prompts (list): A list of dictionaries representing prompts. + + Returns: + list: A list of extracted user information. + """ + user_info = {} + step = step["step"] + # Search for the substring containing 'user:' + if "user:" in step: + # Extract the part after 'user:' and add it to the user_info list + data_string = step.split("user:")[1].split(".\n")[0] + # Replace single quotes with double quotes for JSON compatibility + + data_string_json = data_string.replace("'", '"') + data_string_json = data_string_json.replace("\"\" ", '" ') + + + if "{" in data_string_json: + data_string_json = data_string_json.replace("None", "null") + + # Parse the string into a dictionary + user_info = json.loads(data_string_json) + else: + user_info = data_string_json + counter =0 + for acc in accounts: + for key in acc.keys(): + if key in user_info.keys(): + if isinstance(acc[key], str) and "or 1=1--" in acc[key]: + acc[key] = "' or 1=1--" + if key != "x": + if acc[key] == user_info[key]: + counter +=1 + + if "x" not in acc or acc["x"] == "": + user_info["x"] = "" + counter += 1 + return user_info + + def find_missing_endpoint(self, endpoints: list) -> str: + """ + Identifies and returns the first missing endpoint path found. + + Args: + endpoints (dict): A dictionary of endpoint paths (e.g., {'/resources': {...}, '/resources/:id': {...}}). + + Returns: + str: The first missing endpoint path found. 
+ Example: '/resources/:id' or '/products' + """ + general_endpoints = set() + parameterized_endpoints = set() + + # Extract resource names and categorize them using regex + for endpoint in endpoints: + # Match both general and parameterized patterns and categorize them + match = re.match(r'^/([^/]+)(/|/{id})?$', endpoint) + if match: + resource = match.group(1) + if match.group(2) == '/' or match.group(2) is None: + general_endpoints.add(resource) + elif match.group(2) == '/:id': + parameterized_endpoints.add(resource) + + # Find missing endpoints during the comparison + for resource in parameterized_endpoints: + if resource not in general_endpoints: + return f'/{resource}' + for resource in general_endpoints: + if resource not in parameterized_endpoints: + if f'/{resource}/'+ '{id}' in self.unsuccessful_paths: + continue + return f'/{resource}/'+ '{id}' + + # Return an empty string if no missing endpoints are found + return "" + + def get_endpoints_needing_help(self, info=""): + """ + Determines which endpoints need further testing or have missing methods. + + Args: + info (str): Additional information to enhance the guidance. + + Returns: + list: Guidance for missing endpoints or methods. + """ + + # Step 1: Check for missing endpoints + missing_endpoint = self.find_missing_endpoint(endpoints=self.found_endpoints) + + if (missing_endpoint and not missing_endpoint in self.unsuccessful_paths + and not 'GET' in self.unsuccessful_methods + and missing_endpoint in self.tried_methods_by_enpoint.keys() + and not 'GET' in self.tried_methods_by_enpoint[missing_endpoint]): + formatted_endpoint = missing_endpoint.replace("{id}", "1") if "{id}" in missing_endpoint else missing_endpoint + if missing_endpoint not in self.tried_methods_by_enpoint: + self.tried_methods_by_enpoint[missing_endpoint] = [] + self.tried_methods_by_enpoint[missing_endpoint].append('GET') + return [ + f"{info}\n", + f"For endpoint {formatted_endpoint}, find this missing method: GET." + ] + + # Step 2: Check for endpoints needing additional HTTP methods + http_methods_set = {"GET", "POST", "PUT", "DELETE"} + for endpoint, methods in self.endpoint_methods.items(): + missing_methods = http_methods_set - set(methods) + if missing_methods and endpoint not in self.unsuccessful_paths: + for needed_method in missing_methods: # Iterate directly over missing methods + if endpoint not in self.tried_methods_by_enpoint: + self.tried_methods_by_enpoint[endpoint] = [] + + # Avoid retrying methods that were already unsuccessful + if (needed_method in self.unsuccessful_methods.get(endpoint, []) + or needed_method in self.tried_methods_by_enpoint[endpoint]): + continue + + # Format the endpoint and append the method as tried + formatted_endpoint = endpoint.replace("{id}", "1") if "{id}" in endpoint else endpoint + self.tried_methods_by_enpoint[endpoint].append(needed_method) + + return [ + f"{info}\n", + f"For endpoint {formatted_endpoint}, find this missing method: {needed_method}." + ] + + unsuccessful_paths = [path for path in self.unsuccessful_paths if "?" not in path] + return [ + f"Look for any endpoint that might be missing params, exclude endpoints from this list :{unsuccessful_paths}"] + + + def get_initial_documentation_steps(self, strategy_steps): + """ + Constructs a series of documentation steps to guide the testing and documentation of API endpoints. + These steps are formulated based on the strategy specified and integrate common steps that are essential + across different strategies. 
The function also sets the number of documentation steps and determines specific + steps based on the current testing phase. + + + Returns: + list: A comprehensive list of documentation steps tailored to the provided strategy, enhanced with common steps and hints for further actions. + + Detailed Steps: + - Updates the list of unsuccessful paths and found endpoints to ensure uniqueness. + - Depending on the strategy, it includes specific steps tailored to either in-context learning, tree of thought, or other strategies. + - Each step is designed to methodically explore different types of endpoints (root-level, instance-level, etc.), + focusing on various aspects such as parameter inclusion, method testing, and handling of special cases like IDs. + - The steps are formulated to progressively document and test the API, ensuring comprehensive coverage. + """ + # Ensure uniqueness of paths and endpoints + self.unsuccessful_paths = list(set(self.unsuccessful_paths)) + self.found_endpoints = list(set(self.found_endpoints)) + hint = self.get_hint() + + # Combine common steps with strategy-specific steps + + self.document_steps = len(strategy_steps) + steps = strategy_steps[0] + strategy_steps[self.current_step] + [hint] + + return steps + + + + def _check_prompt(self, previous_prompt: list, steps: str) -> str: + """ + Validates and shortens the prompt if necessary to ensure it does not exceed the maximum token count. + + Args: + previous_prompt (list): The previous prompt content. + steps (str): A list of steps to be included in the new prompt. + max_tokens (int, optional): The maximum number of tokens allowed. Defaults to 900. + + Returns: + str: The validated and possibly shortened prompt. + """ + + def validate_prompt(prompt): + return prompt + + if previous_prompt is None: + potential_prompt = str(steps) + "\n" + return validate_prompt(potential_prompt) + + if steps is not None and previous_prompt is not None and not all(step in previous_prompt for step in steps): + if isinstance(steps, list): + potential_prompt = "\n".join(str(element) for element in steps) + else: + potential_prompt = str(steps) + "\n" + return validate_prompt(potential_prompt) + + return validate_prompt(previous_prompt) + + def _get_endpoint_for_query_params(self): + """ + Searches for an endpoint in the found endpoints list that has query parameters. + + Returns: + str: The first endpoint that includes a query parameter, or None if no such endpoint exists. + """ + query_endpoint = None + endpoints = self.found_endpoints + self.saved_endpoints + list(self.endpoint_examples.keys()) + endpoints = list (set(endpoints)) + for endpoint in endpoints: + if self.tried_endpoints.count(query_endpoint) > 3: + continue + if endpoint not in self.query_endpoints_params or self.tried_endpoints: + self.query_endpoints_params[endpoint] = [] + if len(self.query_endpoints_params[endpoint]) == 0: + return endpoint + + # If no endpoint with query parameters is found, generate one + if len(self.saved_endpoints) != 0: + query_endpoints = [endpoint for endpoint in self.saved_endpoints] + query_endpoint = random.choice(query_endpoints) + + else: + query_endpoint = random.choice(self.found_endpoints) + + return query_endpoint + def _get_instance_level_endpoint(self, name=""): + """ + Retrieves an instance level endpoint that has not been tested or found unsuccessful. + + Returns: + str: A templated instance level endpoint ready to be tested, or None if no such endpoint is available. 
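+
+        Example (illustrative): if "/users" is a known root-level endpoint, a candidate such as
+        "/users/1" may be returned, provided it has not already been found or marked unsuccessful.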
+        """
+        instance_level_endpoints = self._get_instance_level_endpoints(name)
+        for endpoint in instance_level_endpoints:
+            endpoint = endpoint.replace("//", "/")
+            id = self.get_possible_id_for_instance_level_ep(endpoint)
+            templated_endpoint = endpoint.replace(f"{id}", "{id}")
+            if (endpoint not in self.found_endpoints and templated_endpoint not in self.found_endpoints
+                    and endpoint.replace("1", "{id}") not in self.unsuccessful_paths
+                    and endpoint not in self.unsuccessful_paths
+                    and templated_endpoint != "/1/1"):
+                return endpoint
+        return None
+
+    def _get_instance_level_endpoints(self, name):
+        """
+        Generates a list of instance-level endpoints from the root-level endpoints by appending '/1'.
+
+        Returns:
+            list: A list of potentially testable instance-level endpoints derived from root-level endpoints.
+        """
+        instance_level_endpoints = []
+        for endpoint in self._get_root_level_endpoints():
+            new_endpoint = (endpoint + "/1").replace("//", "/")
+            # Special case: the season averages endpoint expects a category sub-path instead of an id.
+            if endpoint.lstrip("/") in ("season_averages", "seasons_average"):
+                new_endpoint = "/season_averages/general"
+            if new_endpoint != "/1/1" and (
+                    endpoint + "/{id}" not in self.found_endpoints and
+                    endpoint + "/1" not in self.unsuccessful_paths and
+                    new_endpoint not in self.unsuccessful_paths and
+                    new_endpoint not in self.found_endpoints
+            ):
+                id = self.get_possible_id_for_instance_level_ep(endpoint)
+                if id:
+                    new_endpoint = new_endpoint.replace("1", f"{id}")
+                if new_endpoint not in self.unsuccessful_paths and new_endpoint not in self.found_endpoints:
+                    if new_endpoint in self.bad_request_endpoints:
+                        # Fall back to a random UUID when numeric ids produced bad requests for this endpoint.
+                        id = str(self.uuid)
+                        new_endpoint = endpoint + f"/{id}"
+                        instance_level_endpoints.append(new_endpoint)
+                    else:
+                        instance_level_endpoints.append(new_endpoint)
+                        self.possible_instance_level_endpoints.append(new_endpoint)
+
+        return instance_level_endpoints
+
+    def get_hint(self):
+        """
+        Generates a hint based on the current step in the testing process, incorporating specific checks and conditions.
+
+        Returns:
+            str: A tailored hint that provides guidance based on the current testing phase and identified needs.
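+
+        Example (illustrative): at step 2 the hint may read
+        "Create a GET request for this endpoint: /users/1" when an untested instance-level endpoint exists.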
+ """ + hint = "" + if self.current_step == 2: + instance_level_found_endpoints = [ep for ep in self.found_endpoints if "id" in ep] + if "Missing required field: ids" in self.correct_endpoint_but_some_error: + endpoints_missing_id_or_query = list( + set(self.correct_endpoint_but_some_error["Missing required field: ids"])) + hint = f"ADD an id after these endpoints: {endpoints_missing_id_or_query} avoid getting this error again: {self.hint_for_next_round}" + if "base62" in self.hint_for_next_round and "Missing required field: ids" not in self.correct_endpoint_but_some_error: + hint += " Try an id like 6rqhFgbbKwnb9MLmUQDhG6" + new_endpoint = self._get_instance_level_endpoint(self.name) + if new_endpoint: + hint += f" Create a GET request for this endpoint: {new_endpoint}" + + elif self.current_step == 3 and "No search query" in self.correct_endpoint_but_some_error: + endpoints_missing_query = list(set(self.correct_endpoint_but_some_error['No search query'])) + hint = f"First, try out these endpoints: {endpoints_missing_query}" + + if self.current_step == 6: + query_endpoint = self._get_endpoint_for_query_params() + + if query_endpoint == "season_averages": + query_endpoint = "season_averages/general" + if query_endpoint == "stats": + query_endpoint = "stats/advanced" + query_params = self.get_possible_params(query_endpoint) + if query_params is None: + query_params = ["limit", "page", "size"] + + self.tried_endpoints.append(query_endpoint) + + hint = f'Use this endpoint: {query_endpoint} and infer params from this: {query_params}' + hint +=" and use appropriate query params like " + + if self.hint_for_next_round: + hint += self.hint_for_next_round + + return hint + + def _get_root_level_endpoints(self): + """ + Retrieves all root-level endpoints which consist of only one path component. + + Returns: + list: A list of root-level endpoints. + """ + root_level_endpoints = [] + for endpoint in self.found_endpoints: + parts = [part for part in endpoint.split("/") if part] + if len(parts) == 1 and not endpoint+ "/{id}" in self.found_endpoints : + root_level_endpoints.append(endpoint) + return root_level_endpoints + + def _get_related_resource_endpoint(self, path, common_endpoints, name): + """ + Identify related resource endpoints that match the format /resource/id/other_resource. + + Returns: + dict: A mapping of identified endpoints to their responses or error messages. + """ + + if "ball" in name: + common_endpoints = ["stats", "seasons_average", "history", "match", "suggest", "related", '/notifications', + '/messages', '/files', '/settings', '/status', '/health', + '/healthcheck', + '/feedback', + '/support', '/profile', '/account', '/reports', '/dashboard', '/activity', ] + other_resource = random.choice(common_endpoints) + + # Determine if the path is a root-level or instance-level endpoint + if path.endswith("/1"): + # Root-level source endpoint + test_endpoint = f"{path}/{other_resource}" + else: + # Instance-level endpoint + test_endpoint = f"{path}/1/{other_resource}" + + if "Coin" in name or "gbif" in name: + parts = [part.strip() for part in path.split("/") if part.strip()] + + id = self.get_possible_id_for_instance_level_ep(parts[0]) + if id: + test_endpoint = test_endpoint.replace("1", f"{id}") + + # Query the constructed endpoint + test_endpoint = test_endpoint.replace("//", "/") + + + return test_endpoint + + def _get_multi_level_resource_endpoint(self, path, common_endpoints, name): + """ + Identify related resource endpoints that match the format /resource/id/other_resource. 
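+        For example (illustrative), a single-component path such as "/users" may be expanded into a
+        candidate like "/users/search/related", built from two of the supplied common endpoint names.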
+ + Returns: + dict: A mapping of identified endpoints to their responses or error messages. + """ + + if "brew" in name or "gbif" in name: + common_endpoints = ["autocomplete", "search", "random","match", "suggest", "related"] + if "Coin" in name : + common_endpoints = ["markets", "search", "history","match", "suggest", "related", '/notifications', + '/messages', '/files', '/settings', '/status', '/health', + '/healthcheck', + '/feedback', + '/support', '/profile', '/account', '/reports', '/dashboard', '/activity',] + + + other_resource = random.choice(common_endpoints) + another_resource = random.choice(common_endpoints) + if other_resource == another_resource: + another_resource = random.choice(common_endpoints) + path = path.replace("{id}", "1") + parts = [part.strip() for part in path.split("/") if part.strip()] + + if "Coin" in name or "gbif" in name: + id = self.get_possible_id_for_instance_level_ep(parts[0]) + if id: + path = path.replace("1", f"{id}") + + multilevel_endpoint = path + + if len(parts) == 1: + multilevel_endpoint = f"{path}/{other_resource}/{another_resource}" + elif len(parts) == 2: + path = [part.strip() for part in path.split("/") if part.strip()] + if len(path) == 1: + multilevel_endpoint = f"{path}/{other_resource}/{another_resource}" + if len(path) >=2: + multilevel_endpoint = f"{path}/{another_resource}" + else: + if "/1" not in path: + multilevel_endpoint = path + + multilevel_endpoint = multilevel_endpoint.replace("//", "/") + + return multilevel_endpoint + + def _get_sub_resource_endpoint(self, path, common_endpoints, name): + """ + Identify related resource endpoints that match the format /resource/other_resource. + + Returns: + dict: A mapping of identified endpoints to their responses or error messages. + """ + if "brew" in name or "gbif" in name: + + common_endpoints = ["autocomplete", "search", "random","match", "suggest", "related"] + + filtered_endpoints = [resource for resource in common_endpoints + if "id" not in resource ] + possible_resources = [] + for endpoint in filtered_endpoints: + partz = [part.strip() for part in endpoint.split("/") if part.strip()] + if len(partz) == 1 and "1" not in partz: + possible_resources.append(endpoint) + + other_resource = random.choice(possible_resources) + path = path.replace("{id}", "1") + + parts = [part.strip() for part in path.split("/") if part.strip()] + + multilevel_endpoint = path + + + if len(parts) == 1: + multilevel_endpoint = f"{path}/{other_resource}" + elif len(parts) == 2: + if "1" in parts: + p = path.split("/1") + new_path = "" + for part in p: + new_path = path.join(part) + multilevel_endpoint = f"{new_path}/{other_resource}" + else: + if "1" not in path: + multilevel_endpoint = path + if "Coin" in name or "gbif" in name: + id = self.get_possible_id_for_instance_level_ep(parts[0]) + if id: + multilevel_endpoint = multilevel_endpoint.replace("1", f"{id}") + multilevel_endpoint = multilevel_endpoint.replace("//", "/") + + return multilevel_endpoint + + def get_possible_id_for_instance_level_ep(self, endpoint): + if endpoint in self.endpoint_examples: + example = self.endpoint_examples[endpoint] + resource = endpoint.split("s")[0].replace("/", "") + + if example: + for key in example.keys(): + if key and isinstance(key, str): + check_key = key.lower() + if "id" in check_key and check_key.endswith("id"): + id = example[key] + if isinstance(id, int) or (isinstance(id, str) and id.isdigit()): + pattern = re.compile(rf"^/{re.escape(endpoint)}/\d+$") + if any(pattern.match(e) for e in 
self.found_endpoints): + continue + if key == "id": + if endpoint + f"/{id}" in self.found_endpoints or endpoint + f"/{id}" in self.unsuccessful_paths: + continue + else: + return example[key] + elif resource in key: + if endpoint + f"/{id}" in self.found_endpoints or endpoint + f"/{id}" in self.unsuccessful_paths: + continue + else: + return example[key] + + + return None + + def get_possible_params(self, endpoint): + if endpoint in self.endpoint_examples: + example = self.endpoint_examples[endpoint] + if "reqres" in self.name: + for key, value in example.items(): + if not key in self.query_endpoints_params[endpoint]: + return f'{key}: {example[key]}' + elif "ballardtide" in self.name: + for key, value in example.items(): + if not key in self.query_endpoints_params[endpoint]: + return f'{key}: {example[key]}' + if example is None: + example = {"season_type": "regular", "type": "base"} + + return example + + + + diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/__init__.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/__init__.py similarity index 100% rename from src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/__init__.py rename to src/hackingBuddyGPT/utils/prompt_generation/prompts/__init__.py diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompts/basic_prompt.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/basic_prompt.py new file mode 100644 index 00000000..ba1fb629 --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompts/basic_prompt.py @@ -0,0 +1,361 @@ +import os.path +from abc import ABC, abstractmethod +from typing import Optional, Any +from hackingBuddyGPT.utils.prompt_generation.information import ( + PenTestingInformation, +) +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PlanningType, + PromptContext, + PromptStrategy, PromptPurpose, +) + + +class BasicPrompt(ABC): + """ + Abstract base class for generating prompts based on different strategies and contexts. + + This class serves as a blueprint for creating specific prompt generators that operate under different strategies, + such as chain-of-thought or simple prompt generation strategies, tailored to different contexts like documentation + or pentesting. + + Attributes: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + strategy (PromptStrategy): The strategy used for prompt generation. + pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. + """ + + def __init__( + self, + context: PromptContext = None, + planning_type: PlanningType = None, + prompt_helper=None, + strategy: PromptStrategy = None, + prompt_file: Any =None + ): + """ + Initializes the BasicPrompt with a specific context, prompt helper, and strategy. + + Args: + context (PromptContext): The context in which prompts are generated. + planning_type (PlanningType): The type of planning. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + strategy (PromptStrategy): The strategy used for prompt generation. 
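+            prompt_file (Any): Optional path to a text file with prompt blocks; it is only stored when
+                the file exists and no context is given.
+
+        Example (illustrative): a concrete subclass typically forwards its arguments, e.g.
+            super().__init__(context=context, planning_type=PlanningType.TASK_PLANNING,
+                             prompt_helper=prompt_helper, strategy=PromptStrategy.CHAIN_OF_THOUGHT,
+                             prompt_file=prompt_file)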
+ """ + self.transformed_steps = {} + self.open_api_spec = {} + self.context = context + if context is None: + if os.path.exists(prompt_file): + self.prompt_file = prompt_file + self.planning_type = planning_type + self.prompt_helper = prompt_helper + self.strategy = strategy + self.current_step = 0 + self.explored_sub_steps = [] + self.previous_purpose = None + self.counter = 0 + + def set_pentesting_information(self, pentesting_information: PenTestingInformation): + self.pentesting_information = pentesting_information + self.purpose = self.pentesting_information.pentesting_step_list[0] + self.previous_purpose = PromptPurpose.SETUP + self.test_cases = self.pentesting_information.explore_steps(self.previous_purpose) + + @abstractmethod + def generate_prompt( + self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] + ) -> str: + """ + Abstract method to generate a prompt. + + This method must be implemented by subclasses to generate a prompt based on the given move type, optional hint, and previous prompt. + + Args: + move_type (str): The type of move to generate. + hint (Optional[str]): An optional hint to guide the prompt generation. + previous_prompt (Optional[str]): The previous prompt content based on the conversation history. + turn (Optional[int]): The current turn + + Returns: + str: The generated prompt. + """ + pass + + def get_documentation_steps(self): + return [ + [ + f"Objective: Identify all accessible endpoints via GET requests for {self.prompt_helper.host}. {self.prompt_helper._description}"], + [ + f""" Query root-level resource endpoints. + Find root-level endpoints for {self.prompt_helper.host}. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths.""" + + ], + [ + "Query Instance-level resource endpoint with id", + "Look for Instance-level resource endpoint : Identify endpoints of type `/resource/id` where id is the parameter for the id.", + "Query these `/resource/id` endpoints to see if an `id` parameter resolves the request successfully." + "Ids can be integers, longs or base62." + + ], + [ + "Query Subresource Endpoints", + "Identify subresource endpoints of the form `/resource/other_resource`.", + "Query these endpoints to check if they return data related to the main resource without requiring an `id` parameter." + + ], + + [ + "Query for related resource endpoints", + "Identify related resource endpoints that match the format `/resource/id/other_resource`: " + f"First, scan for the follwoing endpoints where an `id` in the middle position and follow them by another resource identifier.", + "Second, look for other endpoints and query these endpoints with appropriate `id` values to determine their behavior and document responses or errors." + ], + [ + "Query multi-level resource endpoints", + "Search for multi-level endpoints of type `/resource/other_resource/another_resource`: Identify any endpoints in the format with three resource identifiers.", + "Test requests to these endpoints, adjusting resource identifiers as needed, and analyze responses to understand any additional parameters or behaviors." + ], + [ + "Query endpoints with query parameters", + "Construct and make GET requests to these endpoints using common query parameters (e.g. 
`/resource?param1=1¶m2=3`) or based on documentation hints, testing until a valid request with query parameters is achieved." + ] + ] + + def extract_properties(self): + """ + Extracts example values and data types from the 'Post' schema in the OpenAPI specification. + + This method reads the OpenAPI spec's components → schemas → Post → properties, and + gathers relevant information like example values and types for each property defined. + + Returns: + dict: A dictionary mapping property names to their example values and types. + Format: { prop_name: {"example": str, "type": str} } + """ + properties = self.open_api_spec.get("components", {}).get("schemas", {}).get("Post", {}).get("properties", {}) + extracted_props = {} + + for prop_name, prop_details in properties.items(): + example = prop_details.get("example", "No example provided") + prop_type = prop_details.get("type", "Unknown type") + extracted_props[prop_name] = { + "example": example, + "type": prop_type + } + + return extracted_props + + def sort_previous_prompt(self, previous_prompt): + """ + Reverses the order of a list of previous prompts. + + This function takes a list of prompts (e.g., user or system instructions) + and returns a new list with the elements in reverse order, placing the most + recent prompt first. + + Parameters: + previous_prompt (list): A list of prompts in chronological order (oldest first). + + Returns: + list: A new list containing the prompts in reverse order (most recent first). + """ + sorted_list = [] + for i in range(len(previous_prompt) - 1, -1, -1): + sorted_list.append(previous_prompt[i]) + return sorted_list + + def parse_prompt_file(self): + with open(self.prompt_file, "r", encoding="utf-8") as f: + content = f.read() + blocks = content.strip().split('---') + prompt_blocks = [] + + for block in blocks: + block = block.replace("{host}", self.prompt_helper.host).replace("{description}", self.prompt_helper._description) + lines = [line.strip() for line in block.strip().splitlines() if line.strip()] + if lines: + prompt_blocks.append(lines) + + return prompt_blocks + + def extract_endpoints_from_prompts(self, step): + """ + Extracts potential endpoint paths or URLs from a prompt step. + + This method scans the provided step (either a string or a list containing a string), + and attempts to identify words that represent API endpoints — such as relative paths + (e.g., '/users') or full URLs (e.g., 'https://example.com/users') — using simple keyword + heuristics and filtering. + + Parameters: + step (str or list): A prompt step that may contain one or more textual instructions, + possibly with API endpoint references. + + Returns: + list: A list of unique endpoint strings extracted from the step. + """ + endpoints = [] + # Extract endpoints from the text using simple keyword matching + if isinstance(step, list): + step = step[0] + if "endpoint" in step.lower(): + words = step.split() + for word in words: + if word.startswith("https://") or word.startswith("/") and len(word) > 1: + endpoints.append(word) + + return list(set(endpoints)) # Return unique endpoints + + + + def get_properties(self, step_details): + """ + Extracts the schema properties of an endpoint mentioned in a given step. + + This function analyzes a prompt step, extracts referenced API endpoints, + and searches the stored categorized endpoints to find a matching one. + If a match is found and it contains a schema with defined properties, + those properties are returned. 
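+        Example (illustrative): for a step that mentions "/users", the returned value may look like
+        {"name": {"type": "string"}, "email": {"type": "string"}} when the matched endpoint defines such a schema.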
+ + Parameters: + step_details (dict): A dictionary containing step information. + It is expected to include a key 'step' with either a string + or list of strings that describe the test step. + + Returns: + dict or None: A dictionary of properties from the matched endpoint's schema, + or None if no match is found or no schema is available. + """ + endpoints = self.extract_endpoints_from_prompts(step_details['step']) + for endpoint in endpoints: + for keys in self.pentesting_information.categorized_endpoints: + for ep in self.pentesting_information.categorized_endpoints[keys]: + print(f'ep:{ep}') + + if ep["path"] == endpoint: + print(f'ep:{ep}') + print(f' endpoint: {endpoint}') + schema = ep.get('schema', {}) + if schema != None and schema != {}: + properties = schema.get('properties', {}) + else: + properties = None + return properties + + def next_purpose(self, step, icl_steps, purpose): + """ + Updates the current pentesting purpose based on the progress of ICL steps. + + If the current purpose has no test cases left (`icl_steps` is None), it is removed from + the list of remaining purposes. Otherwise, if the current `step` matches the last explored + step, it also considers the current purpose complete and advances to the next one. + + Parameters: + step (dict or None): The current step being evaluated. + icl_steps (list or None): A list of previously explored steps. + purpose (str): The current pentesting purpose associated with the step. + + Returns: + None + """ + # Process the step and return its result + if icl_steps is None: + self.pentesting_information.pentesting_step_list.remove(purpose) + self.purpose = self.pentesting_information.pentesting_step_list[0] + self.counter = 0 # Reset counter + return + last_item = icl_steps[-1] + if self.check_if_step_is_same(last_item, step) or step is None: + # If it's the last step, remove the purpose and update self.purpose + if purpose in self.pentesting_information.pentesting_step_list: + self.pentesting_information.pentesting_step_list.remove(purpose) + if self.pentesting_information.pentesting_step_list: + self.purpose = self.pentesting_information.pentesting_step_list[0] + + self.counter = 0 # Reset counter + + def check_if_step_is_same(self, step1, step2): + """ + Compares two step dictionaries to determine if they represent the same step. + + Specifically checks if the first item in the 'steps' list of `step1` is equal to + the 'step' value of the first item in the 'steps' list of `step2`. + + Parameters: + step1 (dict): The first step to compare. + step2 (dict): The second step to compare. + + Returns: + bool: True if both steps are considered the same, False otherwise. + """ + # Check if 'steps' and 'path' are identical + steps_same = (step1.get('steps', [])[0] == step2.get('steps', [])[0].get("step")) + + return steps_same + def all_substeps_explored(self, icl_steps): + + """ + Checks whether all substeps in the provided ICL step block have already been explored. + + Compares the list of substeps in `icl_steps` against the `explored_sub_steps` list + to determine if they were previously processed. + + Parameters: + icl_steps (dict): A dictionary containing a list of steps under the 'steps' key. + + Returns: + bool: True if all substeps were explored, False otherwise. 
+ """ + all_steps = [] + for step in icl_steps.get("steps") : + all_steps.append(step) + + if all_steps in self.explored_sub_steps: + return True + else: + return False + + + def reset_accounts(self): + self.prompt_helper.accounts = [acc for acc in self.prompt_helper.accounts if "x" in acc and acc["x"] != ""] + + def get_test_cases(self, test_cases): + """ + Attempts to retrieve a valid list of test cases. + + This method first checks if the input `test_cases` is an empty list. + If so, it iterates through the pentesting step list and attempts to fetch + non-empty test cases using `get_steps_of_phase`, skipping any already transformed steps. + + If no valid test cases are found or if `test_cases` is None, it will repeatedly call + `next_purpose()` and use `explore_steps()` until it retrieves a non-None result. + + Parameters: + test_cases (list or None): An initial set of test cases to validate or replace. + + Returns: + list or None: A valid list of test cases or None if none could be retrieved. + """ + # If test_cases is an empty list, try to find a new non-empty list from other phases + while isinstance(test_cases, list) and len(test_cases) == 0: + for purpose in self.pentesting_information.pentesting_step_list: + if purpose in self.transformed_steps.keys(): + continue + else: + test_cases = self.pentesting_information.get_steps_of_phase(purpose) + if test_cases is not None: + if len(test_cases) != 0: + return test_cases + + # If test_cases is None, keep trying next_purpose and explore_steps until something is found + if test_cases is None: + while test_cases is None: + self.next_purpose(None, test_cases, self.purpose) + test_cases = self.pentesting_information.explore_steps(self.purpose) + + return test_cases diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/__init__.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/__init__.py similarity index 100% rename from src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/state_learning/__init__.py rename to src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/__init__.py diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/in_context_learning_prompt.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/in_context_learning_prompt.py new file mode 100644 index 00000000..9570359d --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/in_context_learning_prompt.py @@ -0,0 +1,427 @@ +import json +from typing import Dict, Optional, Any, List +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PromptContext, + PromptPurpose, + PromptStrategy, +) +from hackingBuddyGPT.utils.prompt_generation.prompts.state_learning.state_planning_prompt import ( + StatePlanningPrompt, +) + + +class InContextLearningPrompt(StatePlanningPrompt): + """ + A class that generates prompts using the in-context learning strategy. + + This class extends the BasicPrompt abstract base class and implements + the generate_prompt method for creating prompts based on the + in-context learning strategy. + + Attributes: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + prompt (Dict[int, Dict[str, str]]): A dictionary containing the prompts for each round. + turn (int): The round number for which the prompt is being generated. 
+ purpose (Optional[PromptPurpose]): The purpose of the prompt generation, which can be set during the process. + open_api_spec (Any) : Samples including the context. + """ + + def __init__(self, context: PromptContext, prompt_helper, context_information: Dict[int, Dict[str, str]], + open_api_spec: Any, prompt_file : Any=None) -> None: + """ + Initializes the InContextLearningPrompt with a specific context, prompt helper, and initial prompt. + + Args: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + context_information (Dict[int, Dict[str, str]]): A dictionary containing the prompts for each round. + """ + super().__init__(context=context, prompt_helper=prompt_helper, strategy=PromptStrategy.IN_CONTEXT, prompt_file=prompt_file) + self.prompt: Dict[int, Dict[str, str]] = context_information + self.purpose: Optional[PromptPurpose] = None + self.open_api_spec = open_api_spec + self.response_history = { + } + + + + def generate_prompt( + self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] + ) -> str: + """ + Generates a prompt using the in-context learning strategy. + + Args: + move_type (str): The type of move to generate. + hint (Optional[str]): An optional hint to guide the prompt generation. + previous_prompt (List[Dict[str, str]]): A list of previous prompt entries, each containing a "content" key. + turn (Optional[int]): Current turn. + + Returns: + str: The generated prompt. + """ + if self.context == PromptContext.DOCUMENTATION: + steps = self._get_documentation_steps(move_type=move_type, previous_prompt=previous_prompt, doc_steps=self.get_documentation_steps()) + elif self.context == PromptContext.PENTESTING: + steps = self._get_pentesting_steps(move_type=move_type) + else: + steps = self.parse_prompt_file() + steps = self._get_documentation_steps(move_type=move_type, previous_prompt=previous_prompt, + doc_steps=steps) + + + + if hint: + steps = steps + [hint] + + return self.prompt_helper._check_prompt(previous_prompt=previous_prompt, steps=steps) + + def _get_documentation_steps(self, move_type: str, previous_prompt, doc_steps: Any) -> List[str]: + """ + Generates documentation steps based on the current API specification, previous prompts, + and the intended move type. + + Args: + move_type (str): Determines the strategy to apply. Accepted values: + - "explore": Generates initial documentation steps for exploration. + - Any other value: Triggers identification of endpoints needing more help. + previous_prompt (Any): A history of previously generated prompts used to determine + which endpoints have already been addressed. + doc_steps (Any): Existing documentation steps that are modified or expanded based on + the selected move_type. + + Returns: + List[str]: A list of documentation prompts tailored to the move_type and current context. 
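+
+        Example (illustrative): with move_type="explore", the first documentation step is prefixed with the
+        generated in-context example ("Based on this information :\n<icl_prompt>\n..."), while any other
+        move_type delegates to prompt_helper.get_endpoints_needing_help().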
+ """ + # Extract properties and example response + if "endpoints" in self.open_api_spec: + properties = self.extract_properties() + example_response = {} + endpoint = "" + endpoints = [endpoint for endpoint in self.open_api_spec["endpoints"]] + if len(endpoints) > 0: + previous_prompt = self.sort_previous_prompt(previous_prompt) + for prompt in previous_prompt: + if isinstance(prompt, dict) and prompt["role"] == "system": + if endpoints[0] not in prompt["content"]: + endpoint = endpoints[0] + else: + for ep in endpoints: + if ep not in prompt["content"]: + endpoint = ep + + break + + # if endpoint != "": break + method_example_response = self.extract_example_response(self.open_api_spec["endpoints"], + endpoint=endpoint) + + icl_prompt = self.generate_icl_prompt(properties, method_example_response, endpoint) + else: + icl_prompt = "" + else: + icl_prompt = "" + + if move_type == "explore": + icl = [[f"Based on this information :\n{icl_prompt}\n" + doc_steps[0][0]]] + # if self.current_step == 0: + # self.current_step == 1 + doc_steps = icl + doc_steps[1:] + # self.current_step += 1 + return self.prompt_helper.get_initial_documentation_steps( + strategy_steps=doc_steps) + else: + return self.prompt_helper.get_endpoints_needing_help( + info=f"Based on this information :\n{icl_prompt}\n Do the following: ") + + + def extract_example_response(self, api_paths, endpoint, method="get"): + """ + Extracts a representative example response for a specified API endpoint and method + from an OpenAPI specification. + Args: + api_paths (dict): A dictionary representing the paths section of the OpenAPI spec, + typically `self.open_api_spec["endpoints"]`. + endpoint (str): The specific API endpoint to extract the example from (e.g., "/users"). + method (str, optional): The HTTP method to consider (e.g., "get", "post"). + Defaults to "get". + + Returns: + dict: A dictionary with the HTTP method as the key and the extracted example + response as the value. If no suitable example is found, returns an empty dict. + Format: { "get": { "exampleName": exampleData } } + """ + example_method = {} + example_response = {} + # Ensure that the provided endpoint and method exist in the schema + if endpoint in api_paths and method in api_paths[endpoint]: + responses = api_paths[endpoint][method].get("responses", {}) + + # Check for response code 200 and application/json content type + if '200' in responses: + content = responses['200'].get("content", {}) + if "application/json" in content: + examples = content["application/json"].get("examples", {}) + + # Extract example responses + for example_name, example_details in examples.items(): + if len(example_response) == 1: + break + if isinstance(example_details, dict): + + example_value = example_details.get("value", {}) + data = example_value.get("data", []) + + else: + print(f'example_details: {example_details}') + example_value = example_details + data = example_details + + if isinstance(data, list) and data != []: + data = data[0] + example_response[example_name] = data + + example_method[method] = example_response + + return example_method + + # Function to generate the prompt for In-Context Learning + def generate_icl_prompt(self, properties, example_response, endpoint): + """ + Generates an in-context learning (ICL) prompt to guide a language model in understanding + and documenting a REST API endpoint. + + Args: + properties (dict): A dictionary of property names to their types and example values. 
+ Format: { "property_name": {"type": "string", "example": "value"} } + example_response (dict): A dictionary containing example API responses, typically extracted + using `extract_example_response`. Format: { "get": { ...example... } } + endpoint (str): The API endpoint path (e.g., "/users"). + + Returns: + str: A formatted prompt string containing API metadata, property descriptions, + and a JSON-formatted example response. + """ + # Core information about API + if len(example_response.keys()) > 0: + prompt = f"# REST API: {list(example_response.keys())[0].upper()} {endpoint}\n\n" + else: + prompt = f"# REST API: {endpoint}\n\n" + + + # Add properties to the prompt + counter = 0 + if len(properties) == 0: + properties = self.extract_properties_with_examples(example_response) + for prop, details in properties.items(): + if counter == 0: + prompt += "This API retrieves objects with the following properties:\n" + prompt += f"- {prop}:{details['type']} (e.g., {details['example']})\n" + counter += 1 + + # Add an example response to the prompt + prompt += "\nExample Response:\n`" + if example_response != {}: + example_key = list(example_response.keys())[0] # Take the first example for simplicity + example_json = json.dumps(example_response[example_key], indent=2) + prompt += example_json + + return prompt + + def extract_properties_with_examples(self, data): + """ + Extracts and flattens properties from a nested dictionary or list of dictionaries, + producing a dictionary of property names along with their inferred types and example values. + + Args: + data (dict or list): The input data, usually an example API response. This can be: + - A single dictionary (representing a single API object). + - A list of dictionaries (representing a collection of API objects). + - A special-case dict with a single `None` key, which is unwrapped. + + Returns: + dict: A dictionary mapping property names to a dictionary with keys: + - "type": The inferred data type (e.g., "string", "integer"). + - "example": A sample value for the property. + Format: { "property_name": {"type": "string", "example": "value"} } + """ + + # Handle nested dictionaries, return flattened properties + + if isinstance(data, dict) and len(data) == 1 and list(data.keys())[0] is None: + data = list(data.values())[0] + + result = {} + if isinstance(data, list): + for item in data: + result = self.get_props(item, result) + + + else: + result = self.get_props(data, result) + + return result + + + def transform_into_prompt_structure_with_previous_examples(self, test_case, purpose): + """ + Transforms a single test case into a In context learning structure. + + The transformation emphasizes breaking tasks into hierarchical phases and embedding conditional logic + to adaptively handle outcomes, inspired by strategies in recent research on structured reasoning. + + Args: + test_case (dict): A dictionary representing a single test case with fields like 'objective', 'steps', and 'security'. + + Returns: + dict: A transformed test case structured hierarchically and conditionally. 
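+
+        Example (illustrative) of the returned structure:
+            {
+                "phase_title": "Phase: <objective>",
+                "steps": [...],
+                "assessments": [...],
+                "path": <path taken from the test case>
+            }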
+ """ + + # Initialize the transformed test case + + transformed_case = { + "phase_title": f"Phase: {test_case['objective']}", + "steps": [], + "assessments": [], + "path": test_case.get("path") + } + + + # Process steps in the test case + counter = 0 + for step in test_case["steps"]: + if counter < len(test_case["security"]): + security = test_case["security"][counter] + else: + security = test_case["security"][0] + + if len(test_case["steps"]) > 1: + if counter 1: + if self.counter == idx: + result.append(f" {step_details['step']}\n") + result.append(f"Example: {self.get_properties(step_details)}") + else: + result.append(f" {step_details['step']}\n") + result.append(f"Example: {self.get_properties(step_details)}") + + + # Add phase assessments + if character == "assessments": + result.append("\nAssessments:\n") + for assessment in test_case["assessments"]: + result.append(f" - {assessment}\n") + + # Add the final assessment if applicable + if character == "final_assessment": + if "final_assessment" in test_case: + result.append(f"\nFinal Assessment:\n {test_case['final_assessment']}\n") + + return ''.join(result) + + def get_props(self, data:dict, result:dict ): + """ + Recursively extracts properties from a dictionary, including nested dictionaries and lists, + and appends them to the result dictionary with their inferred data types and example values. + + Returns: + dict: The updated result dictionary containing all extracted properties, including those + found in nested dictionaries or lists. + """ + + for key, value in data.items(): + + if isinstance(value, dict): + + # Recursively extract properties from nested dictionaries + + nested_properties = self.extract_properties_with_examples(value) + + result.update(nested_properties) + + elif isinstance(value, list): + + if value: + + example_value = value[0] + + result[key] = {"type": "list", "example": example_value} + + else: + + result[key] = {"type": "list", "example": "[]"} + else: + + result[key] = {"type": type(value).__name__, "example": value} + + return result + + diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/state_planning_prompt.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/state_planning_prompt.py new file mode 100644 index 00000000..b89af36e --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompts/state_learning/state_planning_prompt.py @@ -0,0 +1,135 @@ +from abc import abstractmethod +from typing import List, Any + +from hackingBuddyGPT.utils.prompt_generation.information import PenTestingInformation +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PlanningType, + PromptContext, + PromptStrategy, PromptPurpose, +) +from hackingBuddyGPT.utils.prompt_generation.prompts import ( + BasicPrompt, +) + + +class StatePlanningPrompt(BasicPrompt): + """ + A class for generating state planning prompts, including strategies like In-Context Learning (ICL). + + This class extends BasicPrompt to provide specific implementations for state planning strategies, focusing on + adapting prompts based on the current context or state of information provided. + + Attributes: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + strategy (PromptStrategy): The strategy used for prompt generation, typically state-oriented like ICL. 
+ pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. + """ + + def __init__(self, context: PromptContext, prompt_helper, strategy: PromptStrategy, prompt_file: Any=None): + """ + Initializes the StatePlanningPrompt with a specific context, prompt helper, and strategy. + + Args: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + strategy (PromptStrategy): The state planning strategy used for prompt generation. + """ + super().__init__( + context=context, + planning_type=PlanningType.STATE_PLANNING, + prompt_helper=prompt_helper, + strategy=strategy, + prompt_file=prompt_file + ) + self.explored_steps: List[str] = [] + self.transformed_steps ={} + + def set_pentesting_information(self, pentesting_information: PenTestingInformation): + self.pentesting_information = pentesting_information + self.purpose = self.pentesting_information.pentesting_step_list[0] + self.pentesting_information.next_testing_endpoint() + + + def _get_pentesting_steps(self, move_type: str) -> List[str]: + """ + Provides the steps for the chain-of-thought strategy when the context is pentesting. + + Args: + move_type (str): The type of move to generate. + common_step (Optional[str]): A common step prefix to apply to each generated step. + + Returns: + List[str]: A list of steps for the chain-of-thought strategy in the pentesting context. + """ + if self.previous_purpose != self.purpose: + self.previous_purpose = self.purpose + self.reset_accounts() + self.test_cases = self.pentesting_information.explore_steps(self.purpose) + if self.purpose == PromptPurpose.SETUP: + if self.counter == 0: + self.prompt_helper.accounts = self.pentesting_information.accounts + else: + self.pentesting_information.accounts = self.prompt_helper.accounts + + else: + + self.prompt_helper.accounts = self.pentesting_information.accounts + purpose = self.purpose + + if move_type == "explore": + test_cases = self.get_test_cases(self.test_cases) + for test_case in test_cases: + if purpose not in self.transformed_steps.keys(): + self.transformed_steps[purpose] = [] + # Transform steps into icl based on purpose + self.transformed_steps[purpose].append( + self.transform_into_prompt_structure_with_previous_examples(test_case, purpose) + ) + + # Extract the CoT for the current purpose + icl_steps = self.transformed_steps[purpose] + + # Process steps one by one, with memory of explored steps and conditional handling + for icl_test_case in icl_steps: + if icl_test_case not in self.explored_steps and not self.all_substeps_explored(icl_test_case): + self.current_step = icl_test_case + # single step test case + if len(icl_test_case.get("steps")) == 1: + self.current_sub_step = icl_test_case.get("steps")[0] + self.current_sub_step["path"] = icl_test_case.get("path")[0] + else: + if self.counter < len(icl_test_case.get("steps")): + # multi-step test case + self.current_sub_step = icl_test_case.get("steps")[self.counter] + if len(icl_test_case.get("path")) > 1: + self.current_sub_step["path"] = icl_test_case.get("path")[self.counter] + self.explored_sub_steps.append(self.current_sub_step) + self.explored_steps.append(icl_test_case) + + self.prompt_helper.current_user = self.prompt_helper.get_user_from_prompt(self.current_sub_step, self.pentesting_information.accounts) + self.prompt_helper.counter = self.counter + + + + step = 
self.transform_test_case_to_string(self.current_step, "steps") + if self.prompt_helper.current_user is not None or isinstance(self.prompt_helper.current_user, + dict): + if "token" in self.prompt_helper.current_user and "'{{token}}'" in step: + step = step.replace("'{{token}}'", self.prompt_helper.current_user.get("token")) + self.counter += 1 + # if last step of exploration, change purpose to next + self.next_purpose(icl_test_case,test_cases, purpose) + + return [step] + + # Default steps if none match + return ["Look for exploits."] + + + @abstractmethod + def transform_into_prompt_structure_with_previous_examples(self, test_case, purpose): + pass + @abstractmethod + def transform_test_case_to_string(self, current_step, param): + pass \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/__init__.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/__init__.py similarity index 100% rename from src/hackingBuddyGPT/usecases/web_api_testing/prompt_generation/prompts/task_planning/__init__.py rename to src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/__init__.py diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py new file mode 100644 index 00000000..0a3c4dcf --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/chain_of_thought_prompt.py @@ -0,0 +1,249 @@ +from typing import List, Optional +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PromptContext, + PromptPurpose, + PromptStrategy, +) +from hackingBuddyGPT.utils.prompt_generation.prompts.task_planning.task_planning_prompt import ( + TaskPlanningPrompt, +) + + +class ChainOfThoughtPrompt(TaskPlanningPrompt): + """ + A class that generates prompts using the chain-of-thought strategy. + + This class extends the BasicPrompt abstract base class and implements + the generate_prompt method for creating prompts based on the + chain-of-thought strategy. + + Attributes: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + explored_steps (List[str]): A list of steps that have already been explored in the chain-of-thought strategy. + """ + + def __init__(self, context: PromptContext, prompt_helper, prompt_file): + """ + Initializes the ChainOfThoughtPrompt with a specific context and prompt helper. + + Args: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + """ + super().__init__(context=context, prompt_helper=prompt_helper, strategy=PromptStrategy.CHAIN_OF_THOUGHT, prompt_file= prompt_file) + self.counter = 0 + + def generate_prompt( + self, move_type: str, hint: Optional[str], previous_prompt: Optional[str], turn: Optional[int] + ) -> str: + """ + Generates a prompt using the chain-of-thought strategy. Provides the steps for the chain-of-thought strategy based on the current context. + Args: + move_type (str): The type of move to generate. + hint (Optional[str]): An optional hint to guide the prompt generation. + previous_prompt (Optional[str]): The previous prompt content based on the conversation history. + turn (Optional[int]): The current turn of the chain-of-thought strategy. 
+ + Returns: + str: The generated prompt. + """ + if self.context == PromptContext.DOCUMENTATION: + self.purpose = PromptPurpose.DOCUMENTATION + chain_of_thought_steps = self._get_documentation_steps([],move_type, self.get_documentation_steps()) + chain_of_thought_steps = [chain_of_thought_steps[0]] + [ + "Let's think step by step"] + chain_of_thought_steps[1:] + + elif self.context == PromptContext.PENTESTING: + chain_of_thought_steps = self._get_pentesting_steps(move_type,"") + else: + steps = self.parse_prompt_file() + chain_of_thought_steps = self._get_documentation_steps([],move_type, steps) + + chain_of_thought_steps = [chain_of_thought_steps[0]] + [ + "Let's think step by step"] + chain_of_thought_steps[1:] + if hint: + chain_of_thought_steps.append(hint) + + return self.prompt_helper._check_prompt(previous_prompt=previous_prompt, steps=chain_of_thought_steps) + + def get_documentation_steps(self): + """ + Returns a predefined list of endpoint exploration steps based on the target API host. + + These steps are used to guide automated documentation of a web API by progressively + discovering and querying endpoints using GET requests. The process follows a structured + hierarchy from root-level endpoints to more complex nested endpoints and those with query parameters. + + Returns: + List[List[str]]: A list of steps, each step being a list of instruction strings. + """ + + # Define specific documentation steps based on the given strategy + + return [ + [f"Objective: Identify all accessible endpoints via GET requests for {self.prompt_helper.host}. {self.prompt_helper._description}"], + [ + f""" Query root-level resource endpoints. + Find root-level endpoints for {self.prompt_helper.host}. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths.""" + + ], + [ + "Query Instance-level resource endpoint with id", + "Look for Instance-level resource endpoint : Identify endpoints of type `/resource/id` where id is the parameter for the id.", + "Query these `/resource/id` endpoints to see if an `id` parameter resolves the request successfully." + "Ids can be integers, longs or base62." + + ], + [ + "Query Subresource Endpoints", + "Identify subresource endpoints of the form `/resource/other_resource`.", + "Query these endpoints to check if they return data related to the main resource without requiring an `id` parameter." + + + ], + + [ + "Query for related resource endpoints", + "Identify related resource endpoints that match the format `/resource/id/other_resource`: " + f"First, scan for the follwoing endpoints where an `id` in the middle position and follow them by another resource identifier.", + "Second, look for other endpoints and query these endpoints with appropriate `id` values to determine their behavior and document responses or errors." + ], + [ + "Query multi-level resource endpoints", + "Search for multi-level endpoints of type `/resource/other_resource/another_resource`: Identify any endpoints in the format with three resource identifiers.", + "Test requests to these endpoints, adjusting resource identifiers as needed, and analyze responses to understand any additional parameters or behaviors." 
+ ], + [ + "Query endpoints with query parameters", + "Construct and make GET requests to these endpoints using common query parameters (e.g. `/resource?param1=1¶m2=3`) or based on documentation hints, testing until a valid request with query parameters is achieved." + ] + ] + + + def transform_into_prompt_structure(self, test_case, purpose): + """ + Transforms a single test case into a Hierarchical-Conditional Hybrid Chain-of-Prompt structure. + + The transformation emphasizes breaking tasks into hierarchical phases and embedding conditional logic + to adaptively handle outcomes, inspired by strategies in recent research on structured reasoning. + + Args: + test_case (dict): A dictionary representing a single test case with fields like 'objective', 'steps', and 'security'. + + Returns: + dict: A transformed test case structured hierarchically and conditionally. + """ + + # Initialize the transformed test case + + transformed_case = { + "phase_title": f"Phase: {test_case['objective']}", + "steps": [], + "assessments": [], + "path": test_case.get("path") + } + + # Process steps in the test case + counter = 0 + #print(f' test case:{test_case}') + for step in test_case["steps"]: + if counter < len(test_case["security"]): + security = test_case["security"][counter] + else: + security = test_case["security"][0] + + if len(test_case["steps"]) > 1: + if counter < len(test_case["expected_response_code"]): + expected_response_code = test_case["expected_response_code"][counter] + + else: + expected_response_code = test_case["expected_response_code"] + + token = test_case["token"][counter] + path = test_case["path"][counter] + else: + expected_response_code = test_case["expected_response_code"] + token = test_case["token"][0] + path = test_case["path"][0] + + step_details = { + "purpose": purpose, + "step": step, + "expected_response_code": expected_response_code, + "security": security, + "conditions": { + "if_successful": "No Vulnerability found.", + "if_unsuccessful": "Vulnerability found." + }, + "token": token, + "path": path + } + counter += 1 + transformed_case["steps"].append(step_details) + + + return transformed_case + + def transform_test_case_to_string(self, test_case, character): + """ + Transforms a single test case into a formatted string representation. + + Args: + test_case (dict): A dictionary representing a single test case transformed into a hierarchical structure. + + Returns: + str: A formatted string representation of the test case. + """ + # Initialize the result string + result = [] + + # Add the phase title + result.append(f"{test_case['phase_title']}\n") + + # Add each step with conditions + if character == "steps": + result.append("Let's think step by step.") + result.append("Steps:\n") + for idx, step_details in enumerate(test_case["steps"], start=1): + result.append(f" Step {idx}:\n") + result.append(f" {step_details['step']}\n") + + # Add phase assessments + if character == "assessments": + result.append("\nAssessments:\n") + for assessment in test_case["assessments"]: + result.append(f" - {assessment}\n") + + # Add the final assessment if applicable + if character == "final_assessment": + if "final_assessment" in test_case: + result.append(f"\nFinal Assessment:\n {test_case['final_assessment']}\n") + + return ''.join(result) + + def generate_documentation_steps(self, steps) -> list: + """ + Creates a chain of thought prompt to guide the model through the API documentation process. + + Args: + steps (list): A list of steps, where each step is a list. 
The first element + of each inner list is the step title, followed by its sub-steps or details. + + Returns: + list: A transformed list where each step (except the first) is prefixed with + "Step X:" headers and includes its associated sub-steps. + """ + + transformed_steps = [steps[0]] + + for index, steps in enumerate(steps[1:], start=1): + step_header = f"Step {index}: {steps[0]}" + detailed_steps = steps[1:] + transformed_step = [step_header] + detailed_steps + transformed_steps.append(transformed_step) + + return transformed_steps diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/task_planning_prompt.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/task_planning_prompt.py new file mode 100644 index 00000000..a6a6f84e --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/task_planning_prompt.py @@ -0,0 +1,215 @@ +from abc import abstractmethod + +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PlanningType, + PromptContext, + PromptStrategy, + PromptPurpose, +) +from hackingBuddyGPT.utils.prompt_generation.prompts import ( + BasicPrompt, +) + +from typing import List, Optional, Any + + +class TaskPlanningPrompt(BasicPrompt): + """ + A class for generating task planning prompts, including strategies like Chain-of-Thought (CoT) and Tree-of-Thought (ToT). + + This class extends BasicPrompt to provide specific implementations for task planning strategies, allowing for + detailed step-by-step reasoning or exploration of multiple potential reasoning paths. + + Attributes: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + strategy (PromptStrategy): The strategy used for prompt generation, which could be CoT, ToT, etc. + pentesting_information (Optional[PenTestingInformation]): Contains information relevant to pentesting when the context is pentesting. + """ + + def __init__(self, context: PromptContext, prompt_helper, strategy: PromptStrategy, prompt_file : Any=None): + """ + Initializes the TaskPlanningPrompt with a specific context, prompt helper, and strategy. + + Args: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + strategy (PromptStrategy): The task planning strategy used for prompt generation. + """ + super().__init__( + context=context, + planning_type=PlanningType.TASK_PLANNING, + prompt_helper=prompt_helper, + strategy=strategy, + prompt_file= prompt_file + ) + self.explored_steps: List[str] = [] + self.purpose: Optional[PromptPurpose] = None + self.phase = None + self.transformed_steps = {} + self.pentest_steps = None + + def _get_documentation_steps(self, common_steps: List[str], move_type: str, steps: Any) -> List[str]: + """ + Provides the steps for the task learning prompt when the context is documentation. + + Args: + common_steps (List[str]): A list of common steps for generating prompts. + move_type (str): The type of move to generate. + steps (Any): steps that are transformed into task planning prompt + + Returns: + List[str]: A list of steps for the chain-of-thought strategy in the documentation context. 
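+
+        Example (illustrative sketch; the exact strings come from the PromptHelper and the
+        subclass implementation of ``generate_documentation_steps``):
+            >>> prompt._get_documentation_steps([], "explore", doc_steps)  # doctest: +SKIP
+            ['Objective: Identify all accessible endpoints ...', 'Step 1: Query root-level resource endpoints. ...']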
+ """ + if move_type == "explore": + doc_steps = self.generate_documentation_steps(steps) + return self.prompt_helper.get_initial_documentation_steps( + strategy_steps= doc_steps) + else: + return self.prompt_helper.get_endpoints_needing_help() + + def _get_pentesting_steps(self, move_type: str, common_step: Optional[str] = "") -> Any: + """ + Provides the steps for the chain-of-thought strategy when the context is pentesting. + + Args: + move_type (str): The type of move to generate. + common_step (Optional[str]): A list of common steps for generating prompts. + + Returns: + List[str]: A list of steps for the chain-of-thought strategy in the pentesting context. + """ + if self.previous_purpose != self.purpose: + self.previous_purpose = self.purpose + self.reset_accounts() + self.test_cases = self.pentesting_information.explore_steps(self.purpose) + if self.purpose == PromptPurpose.SETUP: + if self.counter == 0: + self.prompt_helper.accounts = self.pentesting_information.accounts + + else: + self.pentesting_information.accounts = self.prompt_helper.accounts + + else: + + self.prompt_helper.accounts = self.pentesting_information.accounts + + purpose = self.purpose + + if move_type == "explore": + test_cases = self.get_test_cases(self.test_cases) + + for test_case in test_cases: + + if purpose not in self.transformed_steps.keys(): + self.transformed_steps[purpose] = [] + # Transform steps into icl based on purpose + self.transformed_steps[purpose].append( + self.transform_into_prompt_structure(test_case, purpose) + ) + + # Extract the Task planning test cases for the current purpose + task_planning_test_cases = self.transformed_steps[purpose] + + # Process steps one by one, with memory of explored steps and conditional handling + for task_planning_test_case in task_planning_test_cases: + if task_planning_test_case not in self.explored_steps and not self.all_substeps_explored(task_planning_test_case): + self.current_step = task_planning_test_case + # single step test case + if len(task_planning_test_case.get("steps")) == 1: + self.current_sub_step = task_planning_test_case.get("steps")[0] + self.current_sub_step["path"] = task_planning_test_case.get("path")[0] + else: + if self.counter < len(task_planning_test_case.get("steps")): + # multi-step test case + self.current_sub_step = task_planning_test_case.get("steps")[self.counter] + if len(task_planning_test_case.get("path")) > 1: + self.current_sub_step["path"] = task_planning_test_case.get("path")[self.counter] + self.explored_sub_steps.append(self.current_sub_step) + self.explored_steps.append(task_planning_test_case) + + + + self.prompt_helper.current_user = self.prompt_helper.get_user_from_prompt(self.current_sub_step, + self.pentesting_information.accounts) + + self.prompt_helper.counter = self.counter + + step = self.transform_test_case_to_string(self.current_step, "steps") + + if self.prompt_helper.current_user is not None or isinstance(self.prompt_helper.current_user,dict): + if "token" in self.prompt_helper.current_user and "'{{token}}'" in step: + step = step.replace("'{{token}}'", self.prompt_helper.current_user.get("token")) + + self.counter += 1 + # if last step of exploration, change purpose to next + self.next_purpose(task_planning_test_case, test_cases, purpose) + + return [step] + + # Default steps if none match + return ["Look for exploits."] + + + def _get_common_steps(self) -> List[str]: + """ + Provides a list of common steps for generating prompts. + + Returns: + List[str]: A list of common steps for generating prompts. 
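+
+        Raises:
+            TypeError: If ``self.strategy`` is neither CHAIN_OF_THOUGHT nor TREE_OF_THOUGHT.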
+ + """ + if self.strategy == PromptStrategy.CHAIN_OF_THOUGHT: + if self.context == PromptContext.DOCUMENTATION: + return [ + "Identify common data structures returned by various endpoints and define them as reusable schemas. " + "Determine the type of each field (e.g., integer, string, array) and define common response structures as components that can be referenced in multiple endpoint definitions.", + "Create an OpenAPI document including metadata such as API title, version, and description, define the base URL of the API, list all endpoints, methods, parameters, and responses, and define reusable schemas, response types, and parameters.", + "Ensure the correctness and completeness of the OpenAPI specification by validating the syntax and completeness of the document using tools like Swagger Editor, and ensure the specification matches the actual behavior of the API.", + "Refine the document based on feedback and additional testing, share the draft with others, gather feedback, and make necessary adjustments. Regularly update the specification as the API evolves.", + "Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes.", + ] + else: + return [ + "Identify common data structures returned by various endpoints and define them as reusable schemas, specifying field types like integer, string, and array.", + "Create an OpenAPI document that includes API metadata (title, version, description), the base URL, endpoints, methods, parameters, and responses.", + "Ensure the document's correctness and completeness using tools like Swagger Editor, and verify it matches the API's behavior. Refine the document based on feedback, share drafts for review, and update it regularly as the API evolves.", + "Make the specification available to developers through the API documentation site, keeping it current with any API changes.", + ] + elif self.strategy == PromptStrategy.TREE_OF_THOUGHT: + if self.context == PromptContext.DOCUMENTATION: + return [ + "Imagine three different OpenAPI specification specialists.\n" + "All experts will write down one step of their thinking,\n" + "then share it with the group.\n" + "After that, all remaining specialists will proceed to the next step, and so on.\n" + "If any specialist realizes they're wrong at any point, they will leave.\n" + f"The question is: " + + ] + else: + return [ + "Imagine three different Pentest experts are answering this question.\n" + "All experts will write down one step of their thinking,\n" + "then share it with the group.\n" + "After that, all experts will proceed to the next step, and so on.\n" + "If any expert realizes they're wrong at any point, they will leave.\n" + f"The question is: " + ] + + else: + raise TypeError(f"There exists no PromptStrategy of the type {self.strategy}") + + @abstractmethod + def generate_documentation_steps(self, steps: List[str]) -> List[str] : + pass + + + @abstractmethod + def transform_test_case_to_string(self, current_step, param): + pass + + @abstractmethod + def transform_into_prompt_structure(self, test_case, purpose): + pass + diff --git a/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py new file mode 100644 index 00000000..0944b614 --- /dev/null +++ b/src/hackingBuddyGPT/utils/prompt_generation/prompts/task_planning/tree_of_thought_prompt.py @@ -0,0 +1,312 @@ 
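+# Note: Tree-of-Thought prompting lets the model explore several candidate reasoning
+# branches in parallel and drop the ones that turn out to be wrong; the class below
+# applies that idea to both API documentation and pentesting prompt generation.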
+from typing import Optional, List, Dict + +from hackingBuddyGPT.utils.prompt_generation.information.prompt_information import ( + PromptContext, + PromptPurpose, + PromptStrategy, +) +from hackingBuddyGPT.utils.prompt_generation.prompts.task_planning import ( + TaskPlanningPrompt, +) +from hackingBuddyGPT.usecases.web_api_testing.utils.custom_datatypes import Prompt + + +class TreeOfThoughtPrompt(TaskPlanningPrompt): + """ + A class that generates prompts using the tree-of-thought strategy. + + This class extends the BasicPrompt abstract base class and implements + the generate_prompt method for creating prompts based on the + tree-of-thought strategy. + + Attributes: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + rest_api (str): The REST API endpoint for which prompts are generated. + round (int): The round number for the prompt generation process. + purpose (Optional[PromptPurpose]): The purpose of the prompt generation, which can be set during the process. + """ + + def __init__(self, context: PromptContext, prompt_helper, prompt_file) -> None: + """ + Initializes the TreeOfThoughtPrompt with a specific context and prompt helper. + + Args: + context (PromptContext): The context in which prompts are generated. + prompt_helper (PromptHelper): A helper object for managing and generating prompts. + round (int): The round number for the prompt generation process. + """ + super().__init__(context=context, prompt_helper=prompt_helper, strategy=PromptStrategy.TREE_OF_THOUGHT, prompt_file=prompt_file) + + def generate_prompt(self, move_type: str, hint: Optional[str], previous_prompt: Prompt, turn: Optional[int]) -> str: + """ + Generates a prompt using the tree-of-thought strategy. + + Args: + move_type (str): The type of move to generate. + hint (Optional[str]): An optional hint to guide the prompt generation. + previous_prompt (List[Union[ChatCompletionMessage, ChatCompletionMessageParam]]): A list of previous prompt entries, each containing a "content" key. + turn (Optional[int]): The current turn or step in the conversation. + + Returns: + str: The generated prompt. + """ + common_steps = self._get_common_steps() + if self.context == PromptContext.DOCUMENTATION: + self.purpose = PromptPurpose.DOCUMENTATION + tree_of_thought_steps = self._get_documentation_steps(common_steps, move_type, self.get_documentation_steps()) + tree_of_thought_steps = [ + "Imagine three experts each proposing one step at a time. If an expert realizes their step was incorrect, they leave. The question is:"] + tree_of_thought_steps + + elif self.context == PromptContext.PENTESTING: + tree_of_thought_steps = self._get_pentesting_steps(move_type) + else: + steps = self.parse_prompt_file() + + tree_of_thought_steps = self._get_documentation_steps(common_steps, move_type, steps) + + + tree_of_thought_steps = ([ + "Imagine three experts each proposing one step at a time. If an expert realizes their step was incorrect, they leave. The question is:"] + + tree_of_thought_steps) + if hint: + tree_of_thought_steps.append(hint) + + + return self.prompt_helper._check_prompt(previous_prompt=previous_prompt, steps=tree_of_thought_steps) + + + def transform_into_prompt_structure(self, test_case, purpose): + """ + Transforms a single test case into a Tree-of-Thought structure. 
+ + The transformation incorporates branching reasoning paths, self-evaluation at decision points, + and backtracking to enable deliberate problem-solving. + + Args: + test_case (dict): A dictionary representing a single test case with fields like 'objective', 'steps', + 'security', and 'expected_response_code'. + purpose (str): The overarching purpose of the test case. + + Returns: + dict: A transformed test case structured as a Tree-of-Thought process. + """ + + # Initialize the root of the tree + transformed_case = { + "purpose": purpose, + "root": f"Objective: {test_case['objective']}", + "steps": [], + "assessments": [], + "path": test_case.get("path") + } + counter = 0 + # Process steps in the test case as potential steps + for i, step in enumerate(test_case["steps"]): + if counter < len(test_case["security"]): + security = test_case["security"][counter] + else: + security = test_case["security"][0] + + if len(test_case["steps"]) > 1: + if counter < len(test_case["expected_response_code"]): + expected_response_code = test_case["expected_response_code"][counter] + + else: + expected_response_code = test_case["expected_response_code"] + + print(f'COunter: {counter}') + token = test_case["token"][counter] + path = test_case["path"][counter] + else: + expected_response_code = test_case["expected_response_code"] + token = test_case["token"][0] + path = test_case["path"][0] + + + step = """Imagine three different experts are answering this question. + All experts will write down 1 step of their thinking, + then share it with the group. + Then all experts will go on to the next step, etc. + If any expert realises they're wrong at any point then they leave. + The question is : """ + step + + + # Define a branch representing a single reasoning path + branch = { + "purpose": purpose, + "step": step, + "security": security, + "expected_response_code": expected_response_code, + "conditions": { + "if_successful": "No Vulnerability found.", + "if_unsuccessful": "Vulnerability found." + }, + "token": token, + "path": path + } + # Add branch to the tree + transformed_case["steps"].append(branch) + + + + return transformed_case + + + def transform_test_case_to_string(self, tree_of_thought, character): + """ + Transforms a Tree-of-Thought structured test case into a formatted string representation. + + Args: + tree_of_thought (dict): The output from the `transform_to_tree_of_thought` function, representing + a tree-structured test case. + character (str): The focus of the transformation, which could be 'steps', 'assessments', or 'final_assessment'. + + Returns: + str: A formatted string representation of the Tree-of-Thought structure. 
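+
+        Example (illustrative output for ``character="steps"``; endpoint, scheme, and code are placeholders):
+            Root Objective: Objective: Verify authentication on /user/me
+
+            Tree of Thought:
+              Branch 1:
+                Step: Imagine three different experts are answering this question. ... Send a GET request to /user/me without a token.
+                Security: bearerAuth
+                Expected Response Code: 401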
+ """ + # Initialize the result string + result = [] + + # Add the root objective + result.append(f"Root Objective: {tree_of_thought['root']}\n\n") + + # Handle steps + if character == "steps": + result.append("Tree of Thought:\n") + for idx, branch in enumerate(tree_of_thought["steps"], start=1): + result.append(f" Branch {idx}:\n") + result.append(f" Step: {branch['step']}\n") + result.append(f" Security: {branch['security']}\n") + result.append(f" Expected Response Code: {branch['expected_response_code']}\n") + result.append("\n") + + # Handle assessments + if character == "assessments": + result.append("\nAssessments:\n") + for assessment in tree_of_thought["assessments"]: + result.append(f" - {assessment['phase_review']}\n") + + # Handle final assessment + if character == "final_assessment": + if "final_assessment" in tree_of_thought: + final_assessment = tree_of_thought["final_assessment"] + result.append(f"\nFinal Assessment:\n") + result.append(f" Criteria: {final_assessment['criteria']}\n") + result.append(f" Next Action: {final_assessment['next_action']}\n") + + return ''.join(result) + + def transform_to_tree_of_thoughtx(self, prompts: Dict[str, List[List[str]]]) -> Dict[str, List[str]]: + """ + Transforms prompts into a "Tree of Thought" (ToT) format with branching paths, checkpoints, + and conditional steps for flexible, iterative problem-solving as per Tree of Thoughts methodology. + Explanation and Justification + + This implementation aligns closely with the Tree of Thought (ToT) principles outlined by Xie et al. (2023): + + Iterative Evaluation: Each step incorporates assessment points to check if the outcome meets expectations, partially succeeds, or fails, facilitating iterative refinement. + + Dynamic Branching: Conditional steps allow for the creation of alternative paths ("sub-steps") based on intermediate outcomes. This enables the prompt to pivot when initial strategies don’t fully succeed. + + Decision Nodes: Decision nodes evaluate whether to proceed, retry, or backtrack, supporting a flexible problem-solving strategy. This approach mirrors the tree-based structure proposed in ToT, where decisions at each node guide the overall trajectory. + + Progress Checkpoints: Regular checkpoints ensure that each level’s insights are documented and assessed for readiness to proceed. This helps manage complex tasks by breaking down the process into comprehensible phases, similar to how ToT manages complexity in problem-solving. + + Hierarchical Structure: Each level in the hierarchy deepens the model's understanding, allowing for more detailed exploration at higher levels, a core concept in ToT’s approach to handling multi-step tasks. + + Args: + prompts (Dict[str, List[List[str]]]): Dictionary of initial steps for various purposes. + + Returns: + Dict[str, List[str]]: A dictionary where each purpose maps to a structured list of transformed steps in the ToT format. 
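+
+        Example (minimal sketch; the purpose key and step text are made up for illustration):
+            >>> prompts = {"authentication": [["Probe the /login endpoint"]]}
+            >>> tot = self.transform_to_tree_of_thoughtx(prompts)  # doctest: +SKIP
+            >>> tot["authentication"][0]
+            'Level 1 - Main Step: Probe the /login endpoint'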
+ """ + tot_prompts = {} + + for purpose, steps_list in prompts.items(): + tree_steps = [] + current_level = 1 + + for steps in steps_list: + # Iterate through each step in the current level of the tree + for step in steps: + # Main step execution path + tree_steps.append(f"Level {current_level} - Main Step: {step}") + tree_steps.append(" - Document initial observations.") + tree_steps.append(" - Assess: Is the goal partially or fully achieved?") + + # Conditional branching for flexible responses + tree_steps.append(" - If fully achieved, proceed to the next main step.") + tree_steps.append( + " - If partially achieved, identify areas that need refinement and retry with adjusted parameters.") + tree_steps.append(" - If unsuccessful, branch out to explore alternative strategies.") + + # Add sub-branch for alternative exploration + tree_steps.append( + f"Sub-Branch at Level {current_level}: Retry with alternative strategy for Step: {step}") + tree_steps.append(" - Note adjustments and compare outcomes with previous attempts.") + tree_steps.append(" - If successful, integrate findings back into the main path.") + + # Decision node for evaluating continuation or backtracking + tree_steps.append("Decision Node:") + tree_steps.append(" - Assess: Should we continue on this path, backtrack, or end this branch?") + tree_steps.append(" - If major issues persist, consider redefining prerequisites or conditions.") + + # Checkpoint for overall progress assessment at each level + tree_steps.append( + f"Progress Checkpoint at Level {current_level}: Review progress, document insights, and confirm readiness to advance.") + + # Increment to deeper level in the hierarchy for next step + current_level += 1 + + # Conclude steps for this level, reset for new purpose-specific path + tree_steps.append( + f"End of Level {current_level - 1}: Consolidate all insights before moving to the next logical phase.") + current_level = 1 # Reset level for subsequent purposes + + # Add the structured Tree of Thought with steps and checkpoints to the final prompts dictionary + tot_prompts[purpose] = tree_steps + + return tot_prompts + + def get_documentation_steps(self): + return [ + [ + f"Objective: Identify all accessible endpoints via GET requests for {self.prompt_helper.host}. {self.prompt_helper._description}"], + [ + "Start by querying root-level resource endpoints.", + "Focus on sending GET requests only to those endpoints that consist of a single path component directly following the root.", + "For instance, paths should look like '/users' or '/products', with each representing a distinct resource type.", + "Ensure to explore new paths that haven't been previously tested to maximize coverage.", + ], + [ + "Next, move to instance-level resource endpoints.", + "Identify and list endpoints formatted as `/resource/id`, where 'id' represents a dynamic parameter.", + "Attempt to query these endpoints to validate whether the 'id' parameter correctly retrieves individual resource instances.", + "Consider testing with various ID formats, such as integers, longs, or base62 encodings like '6rqhFgbbKwnb9MLmUQDhG6'." + ], + ["Now, move to query Subresource Endpoints.", + "Identify subresource endpoints of the form `/resource/other_resource`.", + "Query these endpoints to check if they return data related to the main resource without requiring an `id` parameter." 
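+             # e.g. (hypothetical) '/users/carts' or '/posts/comments' would be subresource
+             # endpoints of this form; the concrete names depend on the API under test.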
+ ], + [ + "Proceed to analyze related resource endpoints.", + "Identify patterns where a resource is associated with another through an 'id', formatted as `/resource/id/other_resource`.", + "Start by cataloging endpoints that fit this pattern, particularly noting the position of 'id' between two resource identifiers.", + "Then, methodically test these endpoints, using appropriate 'id' values, to explore their responses and document any anomalies or significant behaviors." + ], + [ + "Explore multi-level resource endpoints next.", + "Look for endpoints that connect multiple resources in a sequence, such as `/resource/other_resource/another_resource`.", + "Catalog each discovered endpoint that follows this structure, focusing on their hierarchical relationship.", + "Systematically test these endpoints by adjusting identifiers as necessary, analyzing the response details to decode complex relationships or additional parameters." + ], + [ + "Finally, assess endpoints that utilize query parameters.", + "Construct GET requests for endpoints by incorporating commonly used query parameters or those suggested in documentation.", + "Persistently test these configurations to confirm that each query parameter effectively modifies the response, aiming to finalize the functionality of query parameters." + ] + ] + + def generate_documentation_steps(self, steps): + return self.get_documentation_steps() diff --git a/tests/test_files/fakeapi_config.json b/tests/test_files/fakeapi_config.json new file mode 100644 index 00000000..79119311 --- /dev/null +++ b/tests/test_files/fakeapi_config.json @@ -0,0 +1,33 @@ +{ + "token": "your_api_token_here", + "name": "fake_api", + "host": "https://dummyjson.com", + "description": "API for managing users, including auth, filtering, sorting, and relations like carts/posts/todos.", + "correct_endpoints": [ + "/users", + "/users/{id}", + "/users/search", + "/users/filter", + "/user/login", + "/user/me", + "/users/add" + ], + "query_params": { + "/users": [ + "limit", + "skip", + "select", + "sortBy", + "order" + ], + "/users/search": [ + "q" + ], + "/users/filter": [ + "key", + "value" + ] + }, + "password_file": "", + "csv_file": "" +} \ No newline at end of file diff --git a/tests/test_files/oas/fakeapi_oas.json b/tests/test_files/oas/fakeapi_oas.json new file mode 100644 index 00000000..d02ebe2b --- /dev/null +++ b/tests/test_files/oas/fakeapi_oas.json @@ -0,0 +1,390 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "DummyJSON Users API", + "version": "1.0.0", + "description": "API for managing users, including auth, filtering, sorting, and relations like carts/posts/todos." 
+ }, + "servers": [ + { + "url": "https://dummyjson.com" + } + ], + "paths": { + "/users": { + "get": { + "summary": "Get all users", + "parameters": [ + { + "name": "limit", + "in": "query", + "schema": { + "type": "integer" + } + }, + { + "name": "skip", + "in": "query", + "schema": { + "type": "integer" + } + }, + { + "name": "select", + "in": "query", + "schema": { + "type": "string" + } + }, + { + "name": "sortBy", + "in": "query", + "schema": { + "type": "string" + } + }, + { + "name": "order", + "in": "query", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "List of users", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "users": { + "type": "array", + "items": { + "$ref": "#/components/schemas/User" + } + }, + "total": { + "type": "integer" + }, + "skip": { + "type": "integer" + }, + "limit": { + "type": "integer" + } + } + } + } + } + } + } + }, + "post": { + "summary": "Add a user", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + }, + "responses": { + "200": { + "description": "Created user", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + } + }, + "/users/{id}": { + "get": { + "summary": "Get a single user", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "integer" + } + } + ], + "responses": { + "200": { + "description": "User data", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + }, + "put": { + "summary": "Update a user", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "integer" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + }, + "responses": { + "200": { + "description": "Updated user", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + }, + "delete": { + "summary": "Delete a user", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "schema": { + "type": "integer" + } + } + ], + "responses": { + "200": { + "description": "Deleted user", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + } + }, + "/users/search": { + "get": { + "summary": "Search users", + "parameters": [ + { + "name": "q", + "in": "query", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Search results", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + } + } + }, + "/users/filter": { + "get": { + "summary": "Filter users", + "parameters": [ + { + "name": "key", + "in": "query", + "schema": { + "type": "string" + } + }, + { + "name": "value", + "in": "query", + "schema": { + "type": "string" + } + } + ], + "responses": { + "200": { + "description": "Filtered results", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + } + } + }, + "/user/login": { + "post": { + "summary": "Login user and get tokens", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "username": { + "type": "string" + }, + "password": { + "type": "string" + }, + "expiresInMins": { + 
"type": "integer" + } + } + } + } + } + }, + "responses": { + "200": { + "description": "JWT tokens and user data", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + } + } + }, + "/user/me": { + "get": { + "summary": "Get current authenticated user", + "security": [ + { + "bearerAuth": [] + } + ], + "responses": { + "200": { + "description": "Authenticated user", + "content": { + "application/json": { + "schema": { + "type": "object" + } + } + } + } + } + } + }, + "/users/add": { + "post": { + "summary": "Add a new user (simulation)", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + }, + "responses": { + "200": { + "description": "Simulated created user", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/User" + } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "User": { + "type": "object", + "properties": { + "id": { + "type": "integer" + }, + "firstName": { + "type": "string" + }, + "lastName": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "gender": { + "type": "string" + }, + "email": { + "type": "string" + }, + "username": { + "type": "string" + }, + "password": { + "type": "string" + }, + "birthDate": { + "type": "string" + }, + "role": { + "type": "string" + } + } + } + }, + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT" + } + } + } +} \ No newline at end of file diff --git a/tests/test_files/oas/test_oas.json b/tests/test_files/oas/test_oas.json new file mode 100644 index 00000000..0d369300 --- /dev/null +++ b/tests/test_files/oas/test_oas.json @@ -0,0 +1,96 @@ +{ + "openapi": "3.0.0", + "info": { + "version": "1.0.0", + "title": "JSON Placeholder API", + "description": "See https://jsonplaceholder.typicode.com/" + }, + "servers": [ + { + "url": "https://jsonplaceholder.typicode.com/" + } + ], + "paths": { + "/posts": { + "get": { + "description": "Returns all posts", + "tags": ["Posts"], + "operationId": "getPosts", + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/PostsList" + } + } + } + } + } + } + }, + "/posts/{id}": { + "get": { + "description": "Returns a post by id", + "tags": ["Posts"], + "operationId": "getPost", + "parameters": [ + { + "name": "id", + "in": "path", + "required": true, + "description": "The user id.", + "schema": { + "type": "integer", + "format": "int64" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Post" + } + } + } + }, + "404": { + "description": "Post not found" + } + } + } + } + }, + "components": { + "schemas": { + "PostsList": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Post" + } + }, + "Post": { + "type": "object", + "required": ["id", "userId", "title", "completed"], + "properties": { + "id": { + "type": "integer" + }, + "userId": { + "type": "integer" + }, + "title": { + "type": "string" + }, + "completed": { + "type": "string" + } + } + } + } + } +} diff --git a/tests/test_files/test_config.json b/tests/test_files/test_config.json new file mode 100644 index 00000000..0bf2ad6b --- /dev/null +++ b/tests/test_files/test_config.json @@ -0,0 +1,13 @@ +{ + "name": "test", + "token": "your_api_token_here", + "host": "https://jsonplaceholder.typicode.com/", 
+ "description": "See https://jsonplaceholder.typicode.com/", + "correct_endpoints": [ + "/posts", + "/posts/{id}" + ], + "query_params": {}, + "password_file": "", + "csv_file": "" +} \ No newline at end of file diff --git a/tests/test_llm_handler.py b/tests/test_llm_handler.py index 9e1447ad..6a3ab57d 100644 --- a/tests/test_llm_handler.py +++ b/tests/test_llm_handler.py @@ -33,7 +33,7 @@ def test_add_created_object(self): created_object = MagicMock() object_type = "test_type" - self.llm_handler.add_created_object(created_object, object_type) + self.llm_handler._add_created_object(created_object, object_type) self.assertIn(object_type, self.llm_handler.created_objects) self.assertIn(created_object, self.llm_handler.created_objects[object_type]) @@ -43,16 +43,16 @@ def test_add_created_object_limit(self): object_type = "test_type" for _ in range(8): # Exceed the limit of 7 objects - self.llm_handler.add_created_object(created_object, object_type) + self.llm_handler._add_created_object(created_object, object_type) self.assertEqual(len(self.llm_handler.created_objects[object_type]), 7) def test_get_created_objects(self): created_object = MagicMock() object_type = "test_type" - self.llm_handler.add_created_object(created_object, object_type) + self.llm_handler._add_created_object(created_object, object_type) - created_objects = self.llm_handler.get_created_objects() + created_objects = self.llm_handler._get_created_objects() self.assertIn(object_type, created_objects) self.assertIn(created_object, created_objects[object_type]) diff --git a/tests/test_openAPI_specification_manager.py b/tests/test_openAPI_specification_manager.py index e6088c00..c5b1c96d 100644 --- a/tests/test_openAPI_specification_manager.py +++ b/tests/test_openAPI_specification_manager.py @@ -1,52 +1,174 @@ +import os import unittest -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock -from hackingBuddyGPT.capabilities.http_request import HTTPRequest -from hackingBuddyGPT.usecases.web_api_testing.documentation.openapi_specification_handler import ( - OpenAPISpecificationHandler, -) +from hackingBuddyGPT.usecases.web_api_testing.documentation import OpenAPISpecificationHandler +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PromptStrategy, PromptContext +from hackingBuddyGPT.usecases.web_api_testing.response_processing import ResponseHandler +from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler -class TestSpecificationHandler(unittest.TestCase): +class TestOpenAPISpecificationHandler(unittest.TestCase): def setUp(self): - self.llm_handler = MagicMock() - self.response_handler = MagicMock() - self.doc_handler = OpenAPISpecificationHandler(self.llm_handler, self.response_handler) - - @patch("os.makedirs") - @patch("builtins.open") - def test_write_openapi_to_yaml(self, mock_open, mock_makedirs): - self.doc_handler.write_openapi_to_yaml() - mock_makedirs.assert_called_once_with(self.doc_handler.file_path, exist_ok=True) - mock_open.assert_called_once_with(self.doc_handler.file, "w") - - # Create a mock HTTPRequest object - response_mock = MagicMock() - response_mock.action = HTTPRequest( - host="https://jsonplaceholder.typicode.com", follow_redirects=False, use_cookie_jar=True + self.llm_handler = MagicMock(spec=LLMHandler) + self.llm_handler_mock = MagicMock() + self.response_handler = 
MagicMock(spec=ResponseHandler) + self.strategy = PromptStrategy.IN_CONTEXT + self.url = "https://jsonplaceholder.typicode.com/" + self.description = "JSON Placeholder API" + self.name = "JSON Placeholder API" + self.llm_handler_mock = MagicMock(spec=LLMHandler) + self.config_path = os.path.join(os.path.dirname(__file__), "test_files", "test_config.json") + self.configuration_handler = ConfigurationHandler(self.config_path) + self.config = self.configuration_handler._load_config(self.config_path) + self.host = "https://jsonplaceholder.typicode.com/" + self.description = "JSON Placeholder API" + self.prompt_helper = PromptGenerationHelper(self.host, self.description) + self.response_handler = ResponseHandler(self.llm_handler_mock, PromptContext.DOCUMENTATION, self.config, + self.prompt_helper, None) + self.openapi_handler = OpenAPISpecificationHandler( + llm_handler=self.llm_handler, + response_handler=self.response_handler, + strategy=self.strategy, + url=self.url, + description=self.description, + name=self.name, ) - response_mock.action.method = "GET" - response_mock.action.path = "/test" - result = '{"key": "value"}' + def test_update_openapi_spec_success(self): + # Mock HTTP Request object + mock_request = MagicMock() + mock_request.__class__.__name__ = "HTTPRequest" + mock_request.path = "/users" + mock_request.method = "GET" - self.response_handler.parse_http_response_to_openapi_example = MagicMock( - return_value=({}, "#/components/schemas/TestSchema", self.doc_handler.openapi_spec) - ) + # Mock Response object + mock_resp = MagicMock() + mock_resp.action = mock_request + + result = ( + "HTTP/1.1 200 OK\n" + "Date: Wed, 17 Apr 2025 12:00:00 GMT\n" + "Content-Type: application/json; charset=utf-8\n" + "Content-Length: 85\n" + "Connection: keep-alive\n" + "X-Powered-By: Express\n" + "Strict-Transport-Security: max-age=31536000; includeSubDomains\n" + "Cache-Control: no-store\n" + "Set-Cookie: sessionId=abc123; HttpOnly; Secure; Path=/\r\n\r\n" + "\n" + "{\n" + ' "id": 1,\n' + ' "username": "alice@example.com",\n' + ' "role": "user",\n' + ' "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."\n' + "}" +) + + prompt_engineer = MagicMock() + prompt_engineer.prompt_helper.current_step = 1 # Needed for replace_id_with_placeholder - endpoints = self.doc_handler.update_openapi_spec(response_mock, result) + updated_endpoints = self.openapi_handler.update_openapi_spec(mock_resp, result, prompt_engineer) - self.assertIn("/test", self.doc_handler.openapi_spec["endpoints"]) - self.assertIn("get", self.doc_handler.openapi_spec["endpoints"]["/test"]) - self.assertEqual( - self.doc_handler.openapi_spec["endpoints"]["/test"]["get"]["summary"], "GET operation on /test" + self.assertIn("/users", updated_endpoints) + self.assertIn("GET", self.openapi_handler.endpoint_methods["/users"]) + self.assertEqual(self.openapi_handler.openapi_spec["endpoints"]["/users"]["get"]["summary"], "GET operation on /users") + + def test_update_openapi_spec_unsuccessful(self): + mock_request = MagicMock() + mock_request.__class__.__name__ = "HTTPRequest" + mock_request.path = "/invalid" + mock_request.method = "POST" + + mock_resp = MagicMock() + mock_resp.action = mock_request + + result = ( + "HTTP/1.1 404 Not Found\n" + "Date: Wed, 17 Apr 2025 12:00:00 GMT\n" + "Content-Type: application/json; charset=utf-8\n" + "Content-Length: 85\n" + "Connection: keep-alive\n" + "X-Powered-By: Express\n" + "Strict-Transport-Security: max-age=31536000; includeSubDomains\n" + "Cache-Control: no-store\n" + "Set-Cookie: sessionId=abc123; 
HttpOnly; Secure; Path=/\r\n\r\n" + "\n" + "{\n" + ' "msg": "error not found"' + "}" ) - self.assertEqual(endpoints, ["/test"]) + prompt_engineer = MagicMock() + prompt_engineer.prompt_helper.current_step = 1 + self.openapi_handler.openapi_spec = { + "endpoints": { + "/invalid": { + "get": { + "id": "id" + } + } + } + } + updated_endpoints = self.openapi_handler.update_openapi_spec(mock_resp, result, prompt_engineer) + + self.assertIn("/invalid", self.openapi_handler.unsuccessful_paths) + self.assertIn("/invalid", updated_endpoints) + + def test_extract_status_code_and_message_valid(self): + result = "HTTP/1.1 200 OK\nContent-Type: application/json" + code, message = self.openapi_handler.extract_status_code_and_message(result) + self.assertEqual(code, "200") + self.assertEqual(message, "OK") + + def test_extract_status_code_and_message_invalid(self): + result = "Not an HTTP header" + code, message = self.openapi_handler.extract_status_code_and_message(result) + self.assertIsNone(code) + self.assertIsNone(message) + + def test_get_type_integer(self): + self.assertEqual(self.openapi_handler.get_type("123"), "integer") + + def test_get_type_double(self): + self.assertEqual(self.openapi_handler.get_type("3.14"), "double") + + def test_get_type_string(self): + self.assertEqual(self.openapi_handler.get_type("hello"), "string") + + def test_replace_crypto_with_id_found(self): + path = "/currency/bitcoin/prices" + replaced = self.openapi_handler.replace_crypto_with_id(path) + self.assertIn("{id}", replaced) + + def test_replace_crypto_with_id_not_found(self): + path = "/currency/euro/prices" + replaced = self.openapi_handler.replace_crypto_with_id(path) + self.assertEqual(replaced, path) + + def test_replace_id_with_placeholder_basic(self): + path = "/user/1/orders" + mock_prompt_engineer = MagicMock() + mock_prompt_engineer.prompt_helper.current_step = 1 + result = self.openapi_handler.replace_id_with_placeholder(path, mock_prompt_engineer) + self.assertIn("{id}", result) + + def test_replace_id_with_placeholder_current_step_2(self): + path = "/user/1234/orders" + mock_prompt_engineer = MagicMock() + mock_prompt_engineer.prompt_helper.current_step = 2 + result = self.openapi_handler.replace_id_with_placeholder(path, mock_prompt_engineer) + self.assertTrue(result.startswith("user")) + + def test_is_partial_match_true(self): + self.assertTrue(self.openapi_handler.is_partial_match("/users/1", ["/users/{id}"])) + + def test_is_partial_match_false(self): + self.assertFalse(self.openapi_handler.is_partial_match("/admin", ["/users/{id}", "/posts"])) - def test_partial_match(self): - string_list = ["test_endpoint", "another_endpoint"] - self.assertTrue(self.doc_handler.is_partial_match("test", string_list)) - self.assertFalse(self.doc_handler.is_partial_match("not_in_list", string_list)) + if __name__ == "__main__": + unittest.main() if __name__ == "__main__": diff --git a/tests/test_openapi_parser.py b/tests/test_openapi_parser.py index a4f73443..fca34a47 100644 --- a/tests/test_openapi_parser.py +++ b/tests/test_openapi_parser.py @@ -1,3 +1,4 @@ +import os import unittest from unittest.mock import mock_open, patch @@ -10,251 +11,50 @@ class TestOpenAPISpecificationParser(unittest.TestCase): def setUp(self): - self.filepath = "dummy_path.yaml" - self.yaml_content = """ - openapi: 3.0.0 - info: - title: Sample API - version: 1.0.0 - servers: - - url: https://api.example.com - - url: https://staging.api.example.com - paths: - /pets: - get: - summary: List all pets - responses: - '200': - description: A 
paged array of pets - post: - summary: Create a pet - responses: - '200': - description: Pet created - /pets/{petId}: - get: - summary: Info for a specific pet - responses: - '200': - description: Expected response to a valid request - """ - - @patch("builtins.open", new_callable=mock_open, read_data="") - @patch( - "yaml.safe_load", - return_value=yaml.safe_load( - """ - openapi: 3.0.0 - info: - title: Sample API - version: 1.0.0 - servers: - - url: https://api.example.com - - url: https://staging.api.example.com - paths: - /pets: - get: - summary: List all pets - responses: - '200': - description: A paged array of pets - post: - summary: Create a pet - responses: - '200': - description: Pet created - /pets/{petId}: - get: - summary: Info for a specific pet - responses: - '200': - description: Expected response to a valid request - """ - ), - ) - def test_load_yaml(self, mock_yaml_load, mock_open_file): - parser = OpenAPISpecificationParser(self.filepath) - self.assertEqual(parser.api_data["info"]["title"], "Sample API") - self.assertEqual(parser.api_data["info"]["version"], "1.0.0") - self.assertEqual(len(parser.api_data["servers"]), 2) - - @patch("builtins.open", new_callable=mock_open, read_data="") - @patch( - "yaml.safe_load", - return_value=yaml.safe_load( - """ - openapi: 3.0.0 - info: - title: Sample API - version: 1.0.0 - servers: - - url: https://api.example.com - - url: https://staging.api.example.com - paths: - /pets: - get: - summary: List all pets - responses: - '200': - description: A paged array of pets - post: - summary: Create a pet - responses: - '200': - description: Pet created - /pets/{petId}: - get: - summary: Info for a specific pet - responses: - '200': - description: Expected response to a valid request - """ - ), - ) - def test_get_servers(self, mock_yaml_load, mock_open_file): - parser = OpenAPISpecificationParser(self.filepath) - servers = parser._get_servers() - self.assertEqual(servers, ["https://api.example.com", "https://staging.api.example.com"]) - - @patch("builtins.open", new_callable=mock_open, read_data="") - @patch( - "yaml.safe_load", - return_value=yaml.safe_load( - """ - openapi: 3.0.0 - info: - title: Sample API - version: 1.0.0 - servers: - - url: https://api.example.com - - url: https://staging.api.example.com - paths: - /pets: - get: - summary: List all pets - responses: - '200': - description: A paged array of pets - post: - summary: Create a pet - responses: - '200': - description: Pet created - /pets/{petId}: - get: - summary: Info for a specific pet - responses: - '200': - description: Expected response to a valid request - """ - ), - ) - def test_get_paths(self, mock_yaml_load, mock_open_file): - parser = OpenAPISpecificationParser(self.filepath) - paths = parser.get_paths() - expected_paths = { - "/pets": { - "get": { - "summary": "List all pets", - "responses": {"200": {"description": "A paged array of pets"}}, - }, - "post": {"summary": "Create a pet", "responses": {"200": {"description": "Pet created"}}}, - }, - "/pets/{petId}": { - "get": { - "summary": "Info for a specific pet", - "responses": {"200": {"description": "Expected response to a valid request"}}, - } - }, - } - self.assertEqual(paths, expected_paths) - - @patch("builtins.open", new_callable=mock_open, read_data="") - @patch( - "yaml.safe_load", - return_value=yaml.safe_load( - """ - openapi: 3.0.0 - info: - title: Sample API - version: 1.0.0 - servers: - - url: https://api.example.com - - url: https://staging.api.example.com - paths: - /pets: - get: - summary: List all 
pets - responses: - '200': - description: A paged array of pets - post: - summary: Create a pet - responses: - '200': - description: Pet created - /pets/{petId}: - get: - summary: Info for a specific pet - responses: - '200': - description: Expected response to a valid request - """ - ), - ) - def test_get_operations(self, mock_yaml_load, mock_open_file): - parser = OpenAPISpecificationParser(self.filepath) - operations = parser._get_operations("/pets") - expected_operations = { - "get": { - "summary": "List all pets", - "responses": {"200": {"description": "A paged array of pets"}}, - }, - "post": {"summary": "Create a pet", "responses": {"200": {"description": "Pet created"}}}, - } + self.filepath = os.path.join(os.path.dirname(__file__), "test_files", "test_config.json") + self.parser = OpenAPISpecificationParser(self.filepath) + + + def test_get_servers(self): + servers = self.parser._get_servers() + self.assertEqual(["https://jsonplaceholder.typicode.com/"], servers) + + + def test_get_paths(self): + paths = self.parser.get_endpoints() + expected_paths = {'/posts': {'get': {'description': 'Returns all posts', + 'operationId': 'getPosts', + 'responses': {'200': {'content': {'application/json': {'schema': {'$ref': '#/components/schemas/PostsList'}}}, + 'description': 'Successful ' + 'response'}}, + 'tags': ['Posts']}}, + '/posts/{id}': {'get': {'description': 'Returns a post by id', + 'operationId': 'getPost', + 'parameters': [{'description': 'The user id.', + 'in': 'path', + 'name': 'id', + 'required': True, + 'schema': {'format': 'int64', + 'type': 'integer'}}], + 'responses': {'200': {'content': {'application/json': {'schema': {'$ref': '#/components/schemas/Post'}}}, + 'description': 'Successful ' + 'response'}, + '404': {'description': 'Post not ' + 'found'}}, + 'tags': ['Posts']}}} + self.assertEqual(expected_paths, paths) + + + def test_get_operations(self): + operations = self.parser._get_operations("/posts") + expected_operations = {'get': {'description': 'Returns all posts', + 'operationId': 'getPosts', + 'responses': {'200': {'content': {'application/json': {'schema': {'$ref': '#/components/schemas/PostsList'}}}, + 'description': 'Successful response'}}, + 'tags': ['Posts']}} self.assertEqual(operations, expected_operations) - @patch("builtins.open", new_callable=mock_open, read_data="") - @patch( - "yaml.safe_load", - return_value=yaml.safe_load( - """ - openapi: 3.0.0 - info: - title: Sample API - version: 1.0.0 - servers: - - url: https://api.example.com - - url: https://staging.api.example.com - paths: - /pets: - get: - summary: List all pets - responses: - '200': - description: A paged array of pets - post: - summary: Create a pet - responses: - '200': - description: Pet created - /pets/{petId}: - get: - summary: Info for a specific pet - responses: - '200': - description: Expected response to a valid request - """ - ), - ) - def test_print_api_details(self, mock_yaml_load, mock_open_file): - parser = OpenAPISpecificationParser(self.filepath) - with patch("builtins.print") as mocked_print: - parser._print_api_details() - mocked_print.assert_any_call("API Title:", "Sample API") - mocked_print.assert_any_call("API Version:", "1.0.0") - mocked_print.assert_any_call("Servers:", ["https://api.example.com", "https://staging.api.example.com"]) - mocked_print.assert_any_call("\nAvailable Paths and Operations:") + if __name__ == "__main__": diff --git a/tests/test_pentesting_information.py b/tests/test_pentesting_information.py new file mode 100644 index 00000000..a1818388 --- 
/dev/null +++ b/tests/test_pentesting_information.py @@ -0,0 +1,86 @@ +import os +import unittest +from unittest.mock import MagicMock + +from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser +from hackingBuddyGPT.utils.prompt_generation.information import PenTestingInformation +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler + + +class TestPenTestingInformation(unittest.TestCase): + + def setUp(self): + self.response_handler = MagicMock() + self.config_path = os.path.join(os.path.dirname(__file__), "test_files","fakeapi_config.json") + self.configuration_handler = ConfigurationHandler(self.config_path) + self.config = self.configuration_handler._load_config(self.config_path) + self._openapi_specification_parser = OpenAPISpecificationParser(self.config_path) + self._openapi_specification = self._openapi_specification_parser.api_data + + + + def test_assign_endpoint_categories(self): + self.pentesting_information = self.generate_pentesting_information("icl") + + creation_paths = [ep.get("path") for ep in self.pentesting_information.categorized_endpoints.get("account_creation")] + protected_endpoints = [ep.get("path") for ep in self.pentesting_information.categorized_endpoints.get("protected_endpoint")] + + self.assertIn('/users', creation_paths) + self.assertIn('/users/{id}',protected_endpoints) + + + def test_key_in_path(self): + self.pentesting_information = self.generate_pentesting_information("icl") + self.pentesting_information.resources = {"user": ["1", "2"]} + found, key = self.pentesting_information.key_in_path("/api/v1/user/1", self.pentesting_information.resources) + self.assertTrue(found) + self.assertEqual(key, "user") + + def test_generate_authentication_prompts(self): + self.pentesting_information = self.generate_pentesting_information("icl") + result = self.pentesting_information.generate_authentication_prompts() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + def test_generate_input_validation_prompts(self): + self.pentesting_information = self.generate_pentesting_information("icl") + result = self.pentesting_information.generate_input_validation_prompts() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + def test_generate_authorization_prompts(self): + self.pentesting_information = self.generate_pentesting_information("icl") + + result = self.pentesting_information.generate_authorization_prompts() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + def test_generate_error_handling_prompts(self): + self.pentesting_information = self.generate_pentesting_information("icl") + + result = self.pentesting_information.generate_error_handling_prompts() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + def test_generate_session_management_prompts(self): + self.pentesting_information = self.generate_pentesting_information("icl") + + result = self.pentesting_information.generate_session_management_prompts() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + def test_generate_xss_prompts(self): + self.pentesting_information = self.generate_pentesting_information("icl") + + result = self.pentesting_information.generate_xss_prompts() + self.assertIsInstance(result, list) + self.assertEqual(len(result), 0) + + def generate_pentesting_information(self, param): + config, strategy = self.configuration_handler.load(param) + self.pentesting_information = 
PenTestingInformation(self._openapi_specification_parser, config) + return self.pentesting_information + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_prompt_engineer_documentation.py b/tests/test_prompt_engineer_documentation.py index daeedbbd..26f0e578 100644 --- a/tests/test_prompt_engineer_documentation.py +++ b/tests/test_prompt_engineer_documentation.py @@ -1,65 +1,106 @@ +import os import unittest from unittest.mock import MagicMock from openai.types.chat import ChatCompletionMessage -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( +from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PenTestingInformation +from hackingBuddyGPT.utils.prompt_generation.information import ( PromptContext, ) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import ( - PromptEngineer, - PromptStrategy, +from hackingBuddyGPT.utils.prompt_generation.prompt_engineer import ( + PromptEngineer ) +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler class TestPromptEngineer(unittest.TestCase): def setUp(self): - self.strategy = PromptStrategy.IN_CONTEXT self.llm_handler = MagicMock() self.history = [{"content": "initial_prompt", "role": "system"}] self.schemas = MagicMock() self.response_handler = MagicMock() - self.prompt_engineer = PromptEngineer( - strategy=self.strategy, - handlers=(self.llm_handler, self.response_handler), - history=self.history, - context=PromptContext.DOCUMENTATION, - ) + self.config_path = os.path.join(os.path.dirname(__file__), "test_files/test_config.json") + self.configuration_handler = ConfigurationHandler(self.config_path) + self.config = self.configuration_handler._load_config(self.config_path) + self._openapi_specification_parser = OpenAPISpecificationParser(self.config_path) + self._openapi_specification = self._openapi_specification_parser.api_data + + self.token, self.host, self.description, self.correct_endpoints, self.query_params = self.configuration_handler._extract_config_values( + self.config) + self.categorized_endpoints = self._openapi_specification_parser.categorize_endpoints(self.correct_endpoints, + self.query_params) + self.prompt_helper = PromptGenerationHelper(self.host, self.description) def test_in_context_learning_no_hint(self): - self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT - expected_prompt = "initial_prompt\ninitial_prompt" - actual_prompt = self.prompt_engineer.generate_prompt(hint="", turn=1) - self.assertEqual(expected_prompt, actual_prompt[1]["content"]) + prompt_engineer = self.generate_prompt_engineer("icl") + + expected_prompt = ('Based on this information :\n' + '\n' + 'Objective: Identify all accessible endpoints via GET requests for ' + 'https://jsonplaceholder.typicode.com/. See ' + 'https://jsonplaceholder.typicode.com/\n' + ' Query root-level resource endpoints.\n' + ' Find root-level endpoints for ' + 'https://jsonplaceholder.typicode.com/.\n' + ' Only send GET requests to root-level ' + 'endpoints with a single path component after the root. This means each path ' + "should have exactly one '/' followed by a single word (e.g., '/users', " + "'/products'). \n" + ' 1. Send GET requests to new paths ' + 'only, avoiding any in the lists above.\n' + ' 2. 
Do not reuse previously tested ' + 'paths.\n') + actual_prompt = prompt_engineer.generate_prompt(hint="", turn=1) + + + self.assertEqual(actual_prompt[0].get("content"), expected_prompt) def test_in_context_learning_with_hint(self): - self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT + prompt_engineer = self.generate_prompt_engineer("icl") + expected_prompt = """Based on this information : + +Objective: Identify all accessible endpoints via GET requests for https://jsonplaceholder.typicode.com/. See https://jsonplaceholder.typicode.com/ + Query root-level resource endpoints. + Find root-level endpoints for https://jsonplaceholder.typicode.com/. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths. +""" hint = "This is a hint." - expected_prompt = "initial_prompt\ninitial_prompt\nThis is a hint." - actual_prompt = self.prompt_engineer.generate_prompt(hint=hint, turn=1) - self.assertEqual(expected_prompt, actual_prompt[1]["content"]) + actual_prompt = prompt_engineer.generate_prompt(hint=hint, turn=1) + self.assertIn(hint, actual_prompt[0].get("content")) def test_in_context_learning_with_doc_and_hint(self): - self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT + prompt_engineer = self.generate_prompt_engineer("icl") hint = "This is another hint." - expected_prompt = "initial_prompt\ninitial_prompt\nThis is another hint." - actual_prompt = self.prompt_engineer.generate_prompt(hint=hint, turn=1) - self.assertEqual(expected_prompt, actual_prompt[1]["content"]) + expected_prompt = """Objective: Identify all accessible endpoints via GET requests for 'https://jsonplaceholder.typicode.com/ provided.. See https://jsonplaceholder.typicode.com/ + Query root-level resource endpoints. + Find root-level endpoints for 'https://jsonplaceholder.typicode.com/ provided.. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths. 
+ +This is another hint.""" + actual_prompt = prompt_engineer.generate_prompt(hint=hint, turn=1) + self.assertIn(hint,actual_prompt[0].get("content")) def test_generate_prompt_chain_of_thought(self): - self.prompt_engineer.strategy = PromptStrategy.CHAIN_OF_THOUGHT + prompt_engineer = self.generate_prompt_engineer("cot") self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text") - self.prompt_engineer.evaluate_response = MagicMock(return_value=True) + prompt_engineer.evaluate_response = MagicMock(return_value=True) - prompt_history = self.prompt_engineer.generate_prompt(turn=1) + prompt_history = prompt_engineer.generate_prompt(turn=1) - self.assertEqual(2, len(prompt_history)) + self.assertEqual(1, len(prompt_history)) def test_generate_prompt_tree_of_thought(self): - # Set the strategy to TREE_OF_THOUGHT - self.prompt_engineer.strategy = PromptStrategy.TREE_OF_THOUGHT + prompt_engineer = self.generate_prompt_engineer("tot") self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text") - self.prompt_engineer.evaluate_response = MagicMock(return_value=True) + prompt_engineer.evaluate_response = MagicMock(return_value=True) # Create mock previous prompts with valid roles previous_prompts = [ @@ -68,13 +109,26 @@ def test_generate_prompt_tree_of_thought(self): ] # Assign the previous prompts to prompt_engineer._prompt_history - self.prompt_engineer._prompt_history = previous_prompts + prompt_engineer._prompt_history = previous_prompts # Generate the prompt - prompt_history = self.prompt_engineer.generate_prompt(turn=1) + prompt_history = prompt_engineer.generate_prompt(turn=1) # Check if the prompt history length is as expected - self.assertEqual(len(prompt_history), 3) # Adjust to 3 if previous prompt exists + new prompt + self.assertEqual(1, len(prompt_history)) # Adjust to 3 if previous prompt exists + new prompt + + def generate_prompt_engineer(self, param): + config, strategy = self.configuration_handler.load(param) + self.pentesting_information = PenTestingInformation(self._openapi_specification_parser, config) + prompt_engineer = PromptEngineer( + strategy=strategy, + prompt_helper=self.prompt_helper, + context=PromptContext.DOCUMENTATION, + open_api_spec=self._openapi_specification, + rest_api_info=(self.token, self.host, self.correct_endpoints, self.categorized_endpoints), + ) + prompt_engineer.set_pentesting_information(pentesting_information=self.pentesting_information) + return prompt_engineer if __name__ == "__main__": diff --git a/tests/test_prompt_engineer_testing.py b/tests/test_prompt_engineer_testing.py index 198bbbc6..8ada4801 100644 --- a/tests/test_prompt_engineer_testing.py +++ b/tests/test_prompt_engineer_testing.py @@ -1,65 +1,95 @@ +import os import unittest from unittest.mock import MagicMock from openai.types.chat import ChatCompletionMessage -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptContext, +from hackingBuddyGPT.usecases.web_api_testing.documentation.parsing import OpenAPISpecificationParser +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PenTestingInformation +from hackingBuddyGPT.utils.prompt_generation.information import ( + PromptContext, PromptPurpose, ) -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_engineer import ( - PromptEngineer, - PromptStrategy, +from hackingBuddyGPT.utils.prompt_generation.prompt_engineer 
import ( + PromptEngineer ) +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler class TestPromptEngineer(unittest.TestCase): def setUp(self): - self.strategy = PromptStrategy.IN_CONTEXT self.llm_handler = MagicMock() self.history = [{"content": "initial_prompt", "role": "system"}] self.schemas = MagicMock() self.response_handler = MagicMock() - self.prompt_engineer = PromptEngineer( - strategy=self.strategy, - handlers=(self.llm_handler, self.response_handler), - history=self.history, - context=PromptContext.PENTESTING, - ) + self.config_path = os.path.join(os.path.dirname(__file__), "test_files","fakeapi_config.json") + self.configuration_handler = ConfigurationHandler(self.config_path) + self.config = self.configuration_handler._load_config(self.config_path) + self._openapi_specification_parser = OpenAPISpecificationParser(self.config_path) + self._openapi_specification = self._openapi_specification_parser.api_data - def test_in_context_learning_no_hint(self): - self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT - expected_prompt = "initial_prompt\ninitial_prompt" - actual_prompt = self.prompt_engineer.generate_prompt(hint="", turn=1) - self.assertEqual(expected_prompt, actual_prompt[1]["content"]) + self.token, self.host, self.description, self.correct_endpoints, self.query_params = self.configuration_handler._extract_config_values( + self.config) + self.categorized_endpoints = self._openapi_specification_parser.classify_endpoints() + self.prompt_helper = PromptGenerationHelper(self.host, self.description) + + def test_in_context_learning_no_hint(self): + prompt_engineer = self.generate_prompt_engineer("icl") + + expected_prompt = """Based on this information : + +Objective: Identify all accessible endpoints via GET requests for No host URL provided.. See https://jsonplaceholder.typicode.com/ + Query root-level resource endpoints. + Find root-level endpoints for No host URL provided.. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths. +""" + actual_prompt = prompt_engineer.generate_prompt(hint="", turn=1) + self.assertIn(" Create an account by sending a POST HTTP request to the correct endpoint from this /users with these credentials of user:", actual_prompt[0].get("content")) def test_in_context_learning_with_hint(self): - self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT + prompt_engineer = self.generate_prompt_engineer("icl") + expected_prompt = """Based on this information : + + Objective: Identify all accessible endpoints via GET requests for No host URL provided.. See https://jsonplaceholder.typicode.com/ + Query root-level resource endpoints. + Find root-level endpoints for No host URL provided.. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths. + """ hint = "This is a hint." - expected_prompt = "initial_prompt\ninitial_prompt\nThis is a hint." 
- actual_prompt = self.prompt_engineer.generate_prompt(hint=hint, turn=1) - self.assertEqual(expected_prompt, actual_prompt[1]["content"]) + actual_prompt = prompt_engineer.generate_prompt(hint=hint, turn=1) + self.assertIn(hint, actual_prompt[0].get("content"), ) def test_in_context_learning_with_doc_and_hint(self): - self.prompt_engineer.strategy = PromptStrategy.IN_CONTEXT + prompt_engineer = self.generate_prompt_engineer("icl") hint = "This is another hint." - expected_prompt = "initial_prompt\ninitial_prompt\nThis is another hint." - actual_prompt = self.prompt_engineer.generate_prompt(hint=hint, turn=1) - self.assertEqual(expected_prompt, actual_prompt[1]["content"]) + expected_prompt = """Objective: Identify all accessible endpoints via GET requests for No host URL provided.. See https://jsonplaceholder.typicode.com/ + Query root-level resource endpoints. + Find root-level endpoints for No host URL provided.. + Only send GET requests to root-level endpoints with a single path component after the root. This means each path should have exactly one '/' followed by a single word (e.g., '/users', '/products'). + 1. Send GET requests to new paths only, avoiding any in the lists above. + 2. Do not reuse previously tested paths. + +This is another hint.""" + actual_prompt = prompt_engineer.generate_prompt(hint=hint, turn=1) + self.assertIn(hint,actual_prompt[0].get("content")) def test_generate_prompt_chain_of_thought(self): - self.prompt_engineer.strategy = PromptStrategy.CHAIN_OF_THOUGHT + prompt_engineer = self.generate_prompt_engineer("cot") self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text") - self.prompt_engineer.evaluate_response = MagicMock(return_value=True) + prompt_engineer.evaluate_response = MagicMock(return_value=True) - prompt_history = self.prompt_engineer.generate_prompt(turn=1) + prompt_history = prompt_engineer.generate_prompt(turn=1) - self.assertEqual(2, len(prompt_history)) + self.assertEqual(1, len(prompt_history)) def test_generate_prompt_tree_of_thought(self): - # Set the strategy to TREE_OF_THOUGHT - self.prompt_engineer.strategy = PromptStrategy.TREE_OF_THOUGHT + prompt_engineer = self.generate_prompt_engineer("tot") self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text") - self.prompt_engineer.evaluate_response = MagicMock(return_value=True) + prompt_engineer.evaluate_response = MagicMock(return_value=True) # Create mock previous prompts with valid roles previous_prompts = [ @@ -68,13 +98,32 @@ def test_generate_prompt_tree_of_thought(self): ] # Assign the previous prompts to prompt_engineer._prompt_history - self.prompt_engineer._prompt_history = previous_prompts + prompt_engineer._prompt_history = previous_prompts # Generate the prompt - prompt_history = self.prompt_engineer.generate_prompt(turn=1) + prompt_history = prompt_engineer.generate_prompt(turn=1) # Check if the prompt history length is as expected - self.assertEqual(len(prompt_history), 3) # Adjust to 3 if previous prompt exists + new prompt + self.assertEqual(1, len(prompt_history)) # Adjust to 3 if previous prompt exists + new prompt + + def generate_prompt_engineer(self, param): + config, strategy = self.configuration_handler.load(param) + self.pentesting_information = PenTestingInformation(self._openapi_specification_parser, config) + + prompt_engineer = PromptEngineer( + strategy=strategy, + prompt_helper=self.prompt_helper, + context=PromptContext.PENTESTING, + open_api_spec=self._openapi_specification, + 
rest_api_info=(self.token, self.description, self.correct_endpoints, self.categorized_endpoints), + ) + self.pentesting_information.pentesting_step_list = [ + PromptPurpose.SETUP, + PromptPurpose.VERIY_SETUP + ] + prompt_engineer.set_pentesting_information(pentesting_information=self.pentesting_information) + + return prompt_engineer if __name__ == "__main__": diff --git a/tests/test_prompt_generation_helper.py b/tests/test_prompt_generation_helper.py index 06aca3b4..c51ba8f5 100644 --- a/tests/test_prompt_generation_helper.py +++ b/tests/test_prompt_generation_helper.py @@ -1,24 +1,42 @@ import unittest -from unittest.mock import MagicMock +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.prompt_generation_helper import ( - PromptGenerationHelper, -) - -class TestPromptHelper(unittest.TestCase): +class TestPromptGenerationHelper(unittest.TestCase): def setUp(self): - self.response_handler = MagicMock() - self.prompt_helper = PromptGenerationHelper(self.response_handler) - - def test_check_prompt(self): - self.response_handler.get_response_for_prompt = MagicMock(return_value="shortened_prompt") - prompt = self.prompt_helper.check_prompt( - previous_prompt="previous_prompt", - steps=["step1", "step2", "step3", "step4", "step5", "step6"], - max_tokens=2, - ) - self.assertEqual("shortened_prompt", prompt) + self.host = "https://reqres.in" + self.description = "Fake API" + self.prompt_helper = PromptGenerationHelper(self.host, self.description) + + def test_get_user_from_prompt(self): + step = { + "step": "Create a new user with user: {'email': 'eve.holt@reqres.in', 'password': 'pistol'}.\n" + } + accounts = [ + {"email": "eve.holt@reqres.in", "password": "pistol"} + ] + + user_info = self.prompt_helper.get_user_from_prompt(step, accounts) + + self.assertEqual(user_info["email"], "eve.holt@reqres.in") + self.assertEqual(user_info["password"], "pistol") + self.assertIn("x", user_info) + self.assertEqual(user_info["x"], "") + + def test_get_user_from_prompt_with_sql_injection(self): + step = { + "step": "Create user with user: {'email': \"' or 1=1--\", 'password': 'pistol'}.\n" + } + accounts = [ + {"email": "' or 1=1--", "password": "pistol"} + ] + + user_info = self.prompt_helper.get_user_from_prompt(step, accounts) + + self.assertEqual(user_info["email"], " or 1=1--") + self.assertEqual(user_info["password"], "pistol") + self.assertIn("x", user_info) + self.assertEqual(user_info["x"], "") if __name__ == "__main__": diff --git a/tests/test_response_analyzer.py b/tests/test_response_analyzer.py index 0c621bcf..ebd266b6 100644 --- a/tests/test_response_analyzer.py +++ b/tests/test_response_analyzer.py @@ -1,68 +1,111 @@ import unittest -from unittest.mock import patch - -from hackingBuddyGPT.usecases.web_api_testing.prompt_generation.information.prompt_information import ( - PromptPurpose, -) -from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer import ( - ResponseAnalyzer, -) +from hackingBuddyGPT.utils.prompt_generation.information import PromptPurpose +from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer import ResponseAnalyzer class TestResponseAnalyzer(unittest.TestCase): + def setUp(self): - # Example HTTP response to use in tests - self.raw_http_response = """HTTP/1.1 404 Not Found - Date: Fri, 16 Aug 2024 10:01:19 GMT - Content-Type: application/json; charset=utf-8 - Content-Length: 2 - Connection: keep-alive - Report-To: 
{"group":"heroku-nel","max_age":3600,"endpoints":[{"url":"https://nel.heroku.com/reports?ts=1723802269&sid=e11707d5-02a7-43ef-b45e-2cf4d2036f7d&s=dkvm744qehjJmab8kgf%2BGuZA8g%2FCCIkfoYc1UdYuZMc%3D"}]} - X-Powered-By: Express - X-Ratelimit-Limit: 1000 - X-Ratelimit-Remaining: 999 - X-Ratelimit-Reset: 1723802321 - Cache-Control: max-age=43200 - Server: cloudflare - - {}""" - - def test_parse_http_response(self): + self.auth_headers = ( + "HTTP/1.1 200 OK\n" + "Authorization: Bearer token\n" + "X-Ratelimit-Limit: 1000\n" + "X-Ratelimit-Remaining: 998\n" + "X-Ratelimit-Reset: 1723802321\n" + "\n" + '[{"message": "Welcome!"}]' + ) + + self.error_body = ( + "HTTP/1.1 403 Forbidden\n" + "Content-Type: application/json\n" + "\n" + '[{"error": "Access denied"}]' + ) + + self.validation_fail = ( + "HTTP/1.1 400 Bad Request\n" + "X-Content-Type-Options: nosniff\n" + "\n" + '[{"error": "Invalid input"}]' + ) + + def test_parse_http_response_success(self): + analyzer = ResponseAnalyzer() + status, headers, body = analyzer.parse_http_response(self.auth_headers) + + self.assertEqual(200, status) + self.assertIn("Authorization", headers) + if isinstance(body, dict): + msg = body.get("message") + self.assertEqual( "Welcome!",msg ) + + def test_parse_http_response_invalid(self): analyzer = ResponseAnalyzer() - status_code, headers, body = analyzer.parse_http_response(self.raw_http_response) + status, headers, body = analyzer.parse_http_response(self.error_body) - self.assertEqual(status_code, 404) - self.assertEqual(headers["Content-Type"], "application/json; charset=utf-8") - self.assertEqual(body, "Empty") + self.assertEqual( 403, status) + if isinstance(body, dict): + msg = body.get("message") - def test_analyze_authentication_authorization(self): - analyzer = ResponseAnalyzer(PromptPurpose.AUTHENTICATION_AUTHORIZATION) - analysis = analyzer.analyze_response(self.raw_http_response) + self.assertEqual( "Access denied", msg) - self.assertEqual(analysis["status_code"], 404) - self.assertEqual(analysis["authentication_status"], "Unknown") - self.assertTrue(analysis["content_body"], "Empty") - self.assertIn("X-Ratelimit-Limit", analysis["rate_limiting"]) + def test_analyze_authentication(self): + analyzer = ResponseAnalyzer(PromptPurpose.AUTHENTICATION) + result = analyzer.analyze_response(self.auth_headers) - def test_analyze_input_validation(self): + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["authentication_status"], "Authenticated") + self.assertTrue(result["auth_headers_present"]) + self.assertIn("X-Ratelimit-Limit", result["rate_limiting"]) + + def test_analyze_input_validation_invalid(self): analyzer = ResponseAnalyzer(PromptPurpose.INPUT_VALIDATION) - analysis = analyzer.analyze_response(self.raw_http_response) + result = analyzer.analyze_response(self.validation_fail) + + self.assertEqual(result["status_code"], 400) + self.assertEqual(result["is_valid_response"], "Invalid") + self.assertTrue(result["security_headers_present"]) - self.assertEqual(analysis["status_code"], 404) - self.assertEqual(analysis["is_valid_response"], "Error") - self.assertTrue(analysis["response_body"], "Empty") - self.assertIn("security_headers_present", analysis) + def test_is_valid_input_response(self): + analyzer = ResponseAnalyzer() + self.assertEqual(analyzer.is_valid_input_response(200, "data"), "Valid") + self.assertEqual(analyzer.is_valid_input_response(400, "error"), "Invalid") + self.assertEqual(analyzer.is_valid_input_response(500, "error"), "Error") + 
self.assertEqual(analyzer.is_valid_input_response(999, "???"), "Unexpected") - @patch("builtins.print") - def test_print_analysis(self, mock_print): + def test_document_findings(self): + analyzer = ResponseAnalyzer() + document = analyzer.document_findings( + status_code=403, + headers={"Content-Type": "application/json"}, + body="Access denied", + expected_behavior="Access should be allowed", + actual_behavior="Access denied" + ) + self.assertEqual(document["Status Code"], 403) + self.assertIn("Access denied", document["Actual Behavior"]) + + def test_print_analysis_output_structure(self): analyzer = ResponseAnalyzer(PromptPurpose.INPUT_VALIDATION) - analysis = analyzer.analyze_response(self.raw_http_response) - analysis_str = analyzer.print_analysis(analysis) + result = analyzer.analyze_response(self.validation_fail) + printed = analyzer.print_analysis(result) + + self.assertIn("HTTP Status Code: 400", printed) + self.assertIn("Valid Response: Invalid", printed) + self.assertIn("Security Headers Present", printed) - # Check that the correct calls were made to print - self.assertIn("HTTP Status Code: 404", analysis_str) - self.assertIn("Response Body: Empty", analysis_str) - self.assertIn("Security Headers Present: Yes", analysis_str) + def test_report_issues_found(self): + analyzer = ResponseAnalyzer() + document = analyzer.document_findings( + status_code=200, + headers={}, + body="test", + expected_behavior="User not authenticated", + actual_behavior="User is authenticated" + ) + # Just ensure no exceptions, prints okay + analyzer.report_issues(document) if __name__ == "__main__": diff --git a/tests/test_response_analyzer_with_llm.py b/tests/test_response_analyzer_with_llm.py new file mode 100644 index 00000000..d384edaf --- /dev/null +++ b/tests/test_response_analyzer_with_llm.py @@ -0,0 +1,85 @@ +import unittest +from unittest.mock import MagicMock + +from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_analyzer_with_llm import ResponseAnalyzerWithLLM +from hackingBuddyGPT.utils.prompt_generation.information import PromptPurpose + + +class TestResponseAnalyzerWithLLM(unittest.TestCase): + def setUp(self): + self.llm_handler = MagicMock() + self.pentesting_info = MagicMock() + self.prompt_helper = MagicMock() + self.analyzer = ResponseAnalyzerWithLLM( + purpose=PromptPurpose.PARSING, + llm_handler=self.llm_handler, + pentesting_info=self.pentesting_info, + capacity=MagicMock(), + prompt_helper=self.prompt_helper + ) + + def test_parse_http_response_success(self): + raw_response = ( + "HTTP/1.1 200 OK\n" + "Content-Type: application/json\n" + "\n" + '{"id": 1, "name": "John"}' + ) + + status_code, headers, body = self.analyzer.parse_http_response(raw_response) + + self.assertEqual(status_code, "200") + self.assertEqual(headers["Content-Type"], "application/json") + self.assertEqual(body, {"id": 1, "name": "John"}) + + def test_parse_http_response_html(self): + raw_response = ( + "HTTP/1.1 200 OK\n" + "Content-Type: text/html\n" + "\n" + "Error Page" + ) + + status_code, headers, body = self.analyzer.parse_http_response(raw_response) + + self.assertEqual(status_code, "200") + self.assertEqual(headers["Content-Type"], "text/html") + self.assertEqual(body, "") + + def test_process_step_calls_llm_handler(self): + step = "Please analyze the response" + prompt_history = [] + capability = "http_request" + + fake_response = MagicMock() + fake_response.execute.return_value = "Execution Result" + + fake_completion = MagicMock() + fake_completion.choices = 
[MagicMock(message=MagicMock(tool_calls=[MagicMock(id="abc123")]))] + + self.llm_handler.execute_prompt_with_specific_capability.return_value = (fake_response, fake_completion) + + updated_history, result = self.analyzer.process_step(step, prompt_history, capability) + + self.assertIn(step, updated_history[0]["content"]) + self.assertEqual(result, "Execution Result") + + def test_get_addition_context(self): + raw_response = ( + "HTTP/1.1 404 Not Found\n" + "Content-Type: application/json\n" + "{}" + ) + step = { + "expected_response_code": ["200", "201"], + "security": "Ensure auth token" + } + + status_code, additional_context, full_response = self.analyzer.get_addition_context(raw_response, step) + + self.assertEqual(status_code, "404") + self.assertIn("Ensure auth token", additional_context) + self.assertIn("Status Code: 404", full_response) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_response_handler.py b/tests/test_response_handler.py index 31a223de..a4f72c87 100644 --- a/tests/test_response_handler.py +++ b/tests/test_response_handler.py @@ -1,28 +1,28 @@ +import os import unittest from unittest.mock import MagicMock, patch +from hackingBuddyGPT.utils.prompt_generation import PromptGenerationHelper +from hackingBuddyGPT.utils.prompt_generation.information import PromptContext from hackingBuddyGPT.usecases.web_api_testing.response_processing.response_handler import ( ResponseHandler, ) +from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler +from hackingBuddyGPT.usecases.web_api_testing.utils.configuration_handler import ConfigurationHandler class TestResponseHandler(unittest.TestCase): def setUp(self): - self.llm_handler_mock = MagicMock() - self.response_handler = ResponseHandler(self.llm_handler_mock) + self.llm_handler_mock = MagicMock(spec=LLMHandler) + self.config_path = os.path.join(os.path.dirname(__file__), "test_files","test_config.json") + self.configuration_handler = ConfigurationHandler(self.config_path) + self.config = self.configuration_handler._load_config(self.config_path) + self.host = "https://reqres.in" + self.description = "Fake API" + self.prompt_helper = PromptGenerationHelper(self.host, self.description) + self.response_handler = ResponseHandler(self.llm_handler_mock, PromptContext.DOCUMENTATION, self.config, + self.prompt_helper, None) - def test_get_response_for_prompt(self): - prompt = "Test prompt" - response_mock = MagicMock() - response_mock.execute.return_value = "Response text" - self.llm_handler_mock.call_llm.return_value = (response_mock, MagicMock()) - - response_text = self.response_handler.get_response_for_prompt(prompt) - - self.llm_handler_mock.call_llm.assert_called_once_with( - [{"role": "user", "content": [{"type": "text", "text": prompt}]}] - ) - self.assertEqual(response_text, "Response text") def test_parse_http_status_line_valid(self): status_line = "HTTP/1.1 200 OK" diff --git a/tests/test_web_api_documentation.py b/tests/test_web_api_documentation.py index 03f79127..c2b951f6 100644 --- a/tests/test_web_api_documentation.py +++ b/tests/test_web_api_documentation.py @@ -1,3 +1,4 @@ +import os import unittest from unittest.mock import MagicMock, patch @@ -23,18 +24,19 @@ def setUp(self, MockOpenAILib): console=console, tag="webApiDocumentation", ) - self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm, log=log) + config_path = os.path.join(os.path.dirname(__file__), "test_files", "test_config.json") + + self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm, log=log, config_path=config_path, 
+ strategy_string="cot") self.agent.init() self.simple_api_testing = SimpleWebAPIDocumentationUseCase( - agent=self.agent, - log=log, - max_turns=len(self.mock_llm.responses), + agent=self.agent ) self.simple_api_testing.init() def test_initial_prompt(self): # Test if the initial prompt is set correctly - expected_prompt = "You're tasked with documenting the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Start with an empty OpenAPI specification.\nMaintain meticulousness in documenting your observations as you traverse the APIs." + expected_prompt = "You're tasked with documenting the REST APIs of a website hosted at https://jsonplaceholder.typicode.com/. The website is See https://jsonplaceholder.typicode.com/. Start with an empty OpenAPI specification and be meticulous in documenting your observations as you traverse the APIs" self.assertIn(expected_prompt, self.agent._prompt_history[0]["content"]) @@ -63,11 +65,26 @@ def test_perform_round(self, mock_perf_counter): ) # Mock the tool execution result - mock_response.execute.return_value = "HTTP/1.1 200 OK" + real_http_response = ( + "HTTP/1.1 200 OK\r\n" + "Date: Fri, 18 Apr 2025 07:31:21 GMT\r\n" + "Content-Type: application/json; charset=utf-8\r\n" + "Transfer-Encoding: chunked\r\n" + "Connection: keep-alive\r\n" + "Content-Encoding: gzip\r\n" + "\r\n" + '{"page":1,"per_page":6,"total":12,"total_pages":2,"data":[{"id":1,"name":"cerulean"}]}' + ) + + mock_response.execute.return_value = real_http_response + mock_response.action.path = "/posts/" + + self.agent.prompt_helper.found_endpoints = ["/users/"] # Perform the round result = self.agent.perform_round(1) + # Assertions self.assertFalse(result) diff --git a/tests/test_web_api_testing.py b/tests/test_web_api_testing.py index 6a320b68..125f44ff 100644 --- a/tests/test_web_api_testing.py +++ b/tests/test_web_api_testing.py @@ -1,13 +1,12 @@ +import os import unittest from unittest.mock import MagicMock, patch - -from hackingBuddyGPT.usecases import SimpleWebAPITesting -from hackingBuddyGPT.utils.logging import LocalLogger from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import ( - SimpleWebAPITestingUseCase, + SimpleWebAPITestingUseCase, SimpleWebAPITesting, ) from hackingBuddyGPT.utils import Console, DbStorage +from hackingBuddyGPT.utils.logging import LocalLogger class TestSimpleWebAPITestingTest(unittest.TestCase): @patch("hackingBuddyGPT.utils.openai.openai_lib.OpenAILib") @@ -21,29 +20,33 @@ def setUp(self, MockOpenAILib): log = LocalLogger( log_db=log_db, console=console, - tag="integration_test_linuxprivesc", ) - self.agent = SimpleWebAPITesting(llm=self.mock_llm, log=log) + config_path = os.path.join(os.path.dirname(__file__), "test_files","fakeapi_config.json") + + self.agent = SimpleWebAPITesting(llm=self.mock_llm, log=log,config_path= config_path, strategy_string= "cot") + self.agent.init() self.simple_api_testing = SimpleWebAPITestingUseCase( - agent=self.agent, - log=log, - max_turns=len(self.mock_llm.responses), + agent=self.agent + ) self.simple_api_testing.init() + + def test_initial_prompt(self): + contents = [prompt_history_entry["content"] for prompt_history_entry in self.agent._prompt_history] # Test if the initial prompt is set correctly self.assertIn( - "You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. 
Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. Remember, if you encounter an HTTP method (A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).), promptly submit it as it is of utmost importance.", - self.agent._prompt_history[0]["content"], + "You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at https://dummyjson.com. Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. Remember, if you encounter an HTTP method (A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).), promptly submit it as it is of utmost importance.", + contents, ) def test_all_flags_found(self): # Mock console.print to suppress output during testing with patch("rich.console.Console.print"): - self.agent.all_http_methods_found() - self.assertFalse(self.agent.all_http_methods_found()) + self.agent.all_test_cases_run() + self.assertFalse(self.agent.all_test_cases_run()) @patch("time.perf_counter", side_effect=[1, 2]) # Mocking perf_counter for consistent timing def test_perform_round(self, mock_perf_counter): @@ -64,7 +67,25 @@ def test_perform_round(self, mock_perf_counter): ) # Mock the tool execution result - mock_response.execute.return_value = "HTTP/1.1 200 OK" + mock_response.execute.return_value = ( + "HTTP/1.1 200 OK\n" + "Date: Wed, 17 Apr 2025 12:00:00 GMT\n" + "Content-Type: application/json; charset=utf-8\n" + "Content-Length: 85\n" + "Connection: keep-alive\n" + "X-Powered-By: Express\n" + "Strict-Transport-Security: max-age=31536000; includeSubDomains\n" + "Cache-Control: no-store\n" + "Set-Cookie: sessionId=abc123; HttpOnly; Secure; Path=/\r\n\r\n" + "\n" + "{\n" + ' "id": 1,\n' + ' "username": "alice@example.com",\n' + ' "role": "user",\n' + ' "token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."\n' + "}" +) + mock_response.action.path = "/users/" # Perform the round