diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 34692dd20cb..26e15bdf0fe 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @siriak +* @imp2002 @vil02 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..7abaea9b883 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,13 @@ +--- +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/.github/workflows/" + schedule: + interval: "weekly" + + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "daily" +... diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 00000000000..3623445a8c7 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,29 @@ +# Pull Request Template + +## Description + +Please include a summary of the change and which issue (if any) is fixed. +A brief description of the algorithm and your implementation method can be helpful too. If the implemented method/algorithm is not so +well-known, it would be helpful to add a link to an article explaining it with more details. + +## Type of change + +Please delete options that are not relevant. + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) + +## Checklist: + +- [ ] I ran the commands below using the latest version of **rust nightly**. +- [ ] I ran `cargo clippy --all -- -D warnings` just before my last commit and fixed any issue that was found. +- [ ] I ran `cargo fmt` just before my last commit. +- [ ] I ran `cargo test` just before my last commit and all tests passed. +- [ ] I added my algorithm to the corresponding `mod.rs` file within its own folder, and in any parent folder(s). +- [ ] I added my algorithm to `DIRECTORY.md` with the correct link. 
+- [ ] I checked `CONTRIBUTING.md` and my code follows its guidelines. + +Please make sure that if there is a test that takes too long to run ( > 300ms), you `#[ignore]` that or +try to optimize your code or make the test easier to run. We have this rule because we have hundreds of +tests to run; if each one of them took 300ms, we would have to wait for a long time. diff --git a/.github/stale.yml b/.github/stale.yml deleted file mode 100644 index 00988710503..00000000000 --- a/.github/stale.yml +++ /dev/null @@ -1,62 +0,0 @@ -# Configuration for probot-stale - https://github.com/probot/stale - -# Number of days of inactivity before an Issue or Pull Request becomes stale -daysUntilStale: 30 - -# Number of days of inactivity before an Issue or Pull Request with the stale label is closed. -# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. -daysUntilClose: 7 - -# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled) -onlyLabels: [] - -# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable -exemptLabels: - - "dont-close" - -# Set to true to ignore issues in a project (defaults to false) -exemptProjects: false - -# Set to true to ignore issues in a milestone (defaults to false) -exemptMilestones: false - -# Set to true to ignore issues with an assignee (defaults to false) -exemptAssignees: false - -# Label to use when marking as stale -staleLabel: abandoned - -# Limit the number of actions per hour, from 1-30. Default is 30 -limitPerRun: 1 - -# Comment to post when removing the stale label. -# unmarkComment: > -# Your comment here. - -# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls': -pulls: - # Comment to post when marking as stale. 
Set to `false` to disable - markComment: > - This pull request has been automatically marked as abandoned because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. - - # Comment to post when closing a stale Pull Request. - closeComment: > - Please ping one of the maintainers once you commit the changes requested - or make improvements on the code. If this is not the case and you need - some help, feel free to ask for help in our [Gitter](https://gitter.im/TheAlgorithms) - channel. Thank you for your contributions! - -issues: - # Comment to post when marking as stale. Set to `false` to disable - markComment: > - This issue has been automatically marked as abandoned because it has not had - recent activity. It will be closed if no further activity occurs. Thank you - for your contributions. - - # Comment to post when closing a stale Issue. - closeComment: > - Please ping one of the maintainers once you add more information and updates here. - If this is not the case and you need some help, feel free to ask for help - in our [Gitter](https://gitter.im/TheAlgorithms) channel. Thank you for your contributions! 
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 00000000000..1ab85c40554 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,35 @@ +name: build + +'on': + pull_request: + workflow_dispatch: + schedule: + - cron: '51 2 * * 4' + +permissions: + contents: read + +jobs: + fmt: + name: cargo fmt + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: cargo fmt + run: cargo fmt --all -- --check + + clippy: + name: cargo clippy + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: cargo clippy + run: cargo clippy --all --all-targets -- -D warnings + + test: + name: cargo test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: cargo test + run: cargo test diff --git a/.github/workflows/code_ql.yml b/.github/workflows/code_ql.yml new file mode 100644 index 00000000000..707822d15a3 --- /dev/null +++ b/.github/workflows/code_ql.yml @@ -0,0 +1,35 @@ +--- +name: code_ql + +'on': + workflow_dispatch: + push: + branches: + - master + pull_request: + schedule: + - cron: '10 7 * * 1' + +jobs: + analyze_actions: + name: Analyze Actions + runs-on: 'ubuntu-latest' + permissions: + actions: read + contents: read + security-events: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: 'actions' + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:actions" +... 
diff --git a/.github/workflows/directory_workflow.yml b/.github/workflows/directory_workflow.yml index bc46852fbe3..9595c7ad8cb 100644 --- a/.github/workflows/directory_workflow.yml +++ b/.github/workflows/directory_workflow.yml @@ -1,58 +1,30 @@ -name: directory_md -on: [push, pull_request] +name: build_directory_md +on: + push: + branches: [master] + +permissions: + contents: read jobs: MainSequence: name: DIRECTORY.md runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 # v2 is broken for git diff - - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 - name: Setup Git Specs run: | - git config --global user.name github-actions - git config --global user.email '${GITHUB_ACTOR}@users.noreply.github.com' + git config --global user.name "$GITHUB_ACTOR" + git config --global user.email "$GITHUB_ACTOR@users.noreply.github.com" git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY - name: Update DIRECTORY.md - shell: python run: | - import os - from typing import Iterator - URL_BASE = "https://github.com/TheAlgorithms/Rust/blob/master" - g_output = [] - def good_filepaths(top_dir: str = ".") -> Iterator[str]: - fs_exts = tuple(".rs".split()) - for dirpath, dirnames, filenames in os.walk(top_dir): - dirnames[:] = [d for d in dirnames if d[0] not in "._"] - for filename in filenames: - if os.path.splitext(filename)[1].lower() in fs_exts: - yield os.path.join(dirpath, filename).lstrip("./") - def md_prefix(i): - return f"{i * ' '}*" if i else "\n##" - def print_path(old_path: str, new_path: str) -> str: - global g_output - old_parts = old_path.split(os.sep) - for i, new_part in enumerate(new_path.split(os.sep)): - if i + 1 > len(old_parts) or old_parts[i] != new_part: - if new_part: - g_output.append(f"{md_prefix(i)} {new_part.replace('_', ' ').title()}") - return new_path - def build_directory_md(top_dir: str = ".") -> str: - global g_output - 
old_path = "" - for filepath in sorted(good_filepaths(), key=str.lower): - filepath, filename = os.path.split(filepath) - if filepath != old_path: - old_path = print_path(old_path, filepath) - indent = (filepath.count(os.sep) + 1) if filepath else 0 - url = "/".join((URL_BASE, filepath, filename)).replace(" ", "%20") - filename = os.path.splitext(filename.replace("_", " ").title())[0] - g_output.append(f"{md_prefix(indent)} [{filename}]({url})") - return "# List of all files\n" + "\n".join(g_output) - with open("DIRECTORY.md", "w") as out_file: - out_file.write(build_directory_md(".") + "\n") + cargo run --manifest-path=.github/workflows/scripts/build_directory/Cargo.toml - name: Commit DIRECTORY.md run: | - git commit -m "updating DIRECTORY.md" DIRECTORY.md || true - git diff DIRECTORY.md - git push --force origin HEAD:$GITHUB_REF || true + git add DIRECTORY.md + git commit -m "Update DIRECTORY.md [skip actions]" || true + git push origin HEAD:$GITHUB_REF || true diff --git a/.github/workflows/scripts/build_directory/Cargo.toml b/.github/workflows/scripts/build_directory/Cargo.toml new file mode 100644 index 00000000000..fd8a7760e45 --- /dev/null +++ b/.github/workflows/scripts/build_directory/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "build_directory" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/.github/workflows/scripts/build_directory/src/lib.rs b/.github/workflows/scripts/build_directory/src/lib.rs new file mode 100644 index 00000000000..2a5aba9700e --- /dev/null +++ b/.github/workflows/scripts/build_directory/src/lib.rs @@ -0,0 +1,124 @@ +use std::{ + error::Error, + fs, + path::{Path, PathBuf}, +}; + +static URL_BASE: &str = "https://github.com/TheAlgorithms/Rust/blob/master"; + +fn good_filepaths(top_dir: &Path) -> Result, Box> { + let mut good_fs = Vec::new(); + if top_dir.is_dir() { + for entry in fs::read_dir(top_dir)? 
{ + let entry = entry?; + let path = entry.path(); + if entry.file_name().to_str().unwrap().starts_with('.') + || entry.file_name().to_str().unwrap().starts_with('_') + { + continue; + } + if path.is_dir() { + let mut other = good_filepaths(&path)?; + good_fs.append(&mut other); + } else if entry.file_name().to_str().unwrap().ends_with(".rs") + && entry.file_name().to_str().unwrap() != "mod.rs" + { + good_fs.push( + path.into_os_string() + .into_string() + .unwrap() + .split_at(2) + .1 + .to_string(), + ); + } + } + } + good_fs.sort(); + Ok(good_fs) +} + +fn md_prefix(indent_count: usize) -> String { + if indent_count > 0 { + format!("{}*", " ".repeat(indent_count)) + } else { + "\n##".to_string() + } +} + +fn print_path(old_path: String, new_path: String) -> (String, String) { + let old_parts = old_path + .split(std::path::MAIN_SEPARATOR) + .collect::>(); + let mut result = String::new(); + for (count, new_part) in new_path.split(std::path::MAIN_SEPARATOR).enumerate() { + if count + 1 > old_parts.len() || old_parts[count] != new_part { + println!("{} {}", md_prefix(count), to_title(new_part)); + result.push_str(format!("{} {}\n", md_prefix(count), to_title(new_part)).as_str()); + } + } + (new_path, result) +} + +pub fn build_directory_md(top_dir: &Path) -> Result> { + let mut old_path = String::from(""); + let mut result = String::new(); + for filepath in good_filepaths(top_dir)? 
{ + let mut filepath = PathBuf::from(filepath); + let filename = filepath.file_name().unwrap().to_owned(); + filepath.pop(); + let filepath = filepath.into_os_string().into_string().unwrap(); + if filepath != old_path { + let path_res = print_path(old_path, filepath); + old_path = path_res.0; + result.push_str(path_res.1.as_str()); + } + let url = format!("{}/{}", old_path, filename.to_string_lossy()); + let url = get_addr(&url); + let indent = old_path.matches(std::path::MAIN_SEPARATOR).count() + 1; + let filename = to_title(filename.to_str().unwrap().split('.').collect::>()[0]); + println!("{} [{}]({})", md_prefix(indent), filename, url); + result.push_str(format!("{} [{}]({})\n", md_prefix(indent), filename, url).as_str()); + } + Ok(result) +} + +fn to_title(name: &str) -> String { + let mut change = true; + name.chars() + .map(move |letter| { + if change && !letter.is_numeric() { + change = false; + letter.to_uppercase().next().unwrap() + } else if letter == '_' { + change = true; + ' ' + } else { + if letter.is_numeric() || !letter.is_alphanumeric() { + change = true; + } + letter + } + }) + .collect::() +} + +fn get_addr(addr: &str) -> String { + if cfg!(windows) { + format!("{}/{}", URL_BASE, switch_backslash(addr)) + } else { + format!("{}/{}", URL_BASE, addr) + } +} + +// Function that changes '\' to '/' (for Windows builds only) +fn switch_backslash(addr: &str) -> String { + addr.chars() + .map(|mut symbol| { + if symbol == '\\' { + symbol = '/'; + } + symbol + }) + .collect::() +} diff --git a/.github/workflows/scripts/build_directory/src/main.rs b/.github/workflows/scripts/build_directory/src/main.rs new file mode 100644 index 00000000000..9a54f0c0e3e --- /dev/null +++ b/.github/workflows/scripts/build_directory/src/main.rs @@ -0,0 +1,17 @@ +use std::{fs::File, io::Write, path::Path}; + +use build_directory::build_directory_md; +fn main() -> Result<(), std::io::Error> { + let mut file = File::create("DIRECTORY.md").unwrap(); // unwrap for panic + + 
match build_directory_md(Path::new(".")) { + Ok(buf) => { + file.write_all("# List of all files\n".as_bytes())?; + file.write_all(buf.as_bytes())?; + } + Err(err) => { + panic!("Error while creating string: {err}"); + } + } + Ok(()) +} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 00000000000..3e99d1d726d --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,24 @@ +name: 'Close stale issues and PRs' +on: + schedule: + - cron: '0 0 * * *' +permissions: + contents: read + +jobs: + stale: + permissions: + issues: write + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v9 + with: + stale-issue-message: 'This issue has been automatically marked as abandoned because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.' + close-issue-message: 'Please ping one of the maintainers once you add more information and updates here. If this is not the case and you need some help, feel free to ask for help in our [Gitter](https://gitter.im/TheAlgorithms) channel. Thank you for your contributions!' + stale-pr-message: 'This pull request has been automatically marked as abandoned because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.' + close-pr-message: 'Please ping one of the maintainers once you commit the changes requested or make improvements on the code. If this is not the case and you need some help, feel free to ask for help in our [Gitter](https://gitter.im/TheAlgorithms) channel. Thank you for your contributions!' 
+ exempt-issue-labels: 'dont-close' + exempt-pr-labels: 'dont-close' + days-before-stale: 30 + days-before-close: 7 diff --git a/.github/workflows/upload_coverage_report.yml b/.github/workflows/upload_coverage_report.yml new file mode 100644 index 00000000000..ebe347c99e4 --- /dev/null +++ b/.github/workflows/upload_coverage_report.yml @@ -0,0 +1,48 @@ +--- +name: upload_coverage_report + +# yamllint disable-line rule:truthy +on: + workflow_dispatch: + push: + branches: + - master + pull_request: + +permissions: + contents: read + +env: + REPORT_NAME: "lcov.info" + +jobs: + upload_coverage_report: + runs-on: ubuntu-latest + env: + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - uses: taiki-e/install-action@cargo-llvm-cov + - name: Generate code coverage + run: > + cargo llvm-cov + --all-features + --workspace + --lcov + --output-path "${{ env.REPORT_NAME }}" + - name: Upload coverage to codecov (tokenless) + if: >- + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository + uses: codecov/codecov-action@v5 + with: + files: "${{ env.REPORT_NAME }}" + fail_ci_if_error: true + - name: Upload coverage to codecov (with token) + if: "! github.event.pull_request.head.repo.fork " + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: "${{ env.REPORT_NAME }}" + fail_ci_if_error: true +... 
diff --git a/.gitignore b/.gitignore index f0d308e3499..15348ed2e37 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ **/*.rs.bk Cargo.lock /.idea/ +.vscode diff --git a/.gitpod.Dockerfile b/.gitpod.Dockerfile new file mode 100644 index 00000000000..7faaafa9c0c --- /dev/null +++ b/.gitpod.Dockerfile @@ -0,0 +1,3 @@ +FROM gitpod/workspace-rust:2024-06-05-14-45-28 + +USER gitpod diff --git a/.gitpod.yml b/.gitpod.yml new file mode 100644 index 00000000000..26a402b692b --- /dev/null +++ b/.gitpod.yml @@ -0,0 +1,11 @@ +--- +image: + file: .gitpod.Dockerfile + +tasks: + - init: cargo build + +vscode: + extensions: + - rust-lang.rust-analyzer +... diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index db5320c1708..00000000000 --- a/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: rust - -before_script: - - rustup component add rustfmt-preview - - rustup component add clippy -script: - - cargo fmt --all -- --check - - cargo clippy --all -- -D warnings - - cargo test diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cbb926877a4..e6c0aecd597 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ This project aims at showcasing common algorithms implemented in `Rust`, with an ## Project structure -The project is organized as follow: +The project is organized as follows: `src/` - `my_algo_category/` @@ -39,12 +39,13 @@ mod tests { } ``` -## Before submitting you PR +## Before submitting your PR Do **not** use acronyms: `DFS` should be `depth_first_search`. -Make sure you ran +Make sure you run * `cargo test` * `cargo fmt` - + * `cargo clippy --all -- -D warnings` + And that's about it ! 
diff --git a/Cargo.toml b/Cargo.toml index eabc8f02107..71d16d2eafb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,171 @@ [package] name = "the_algorithms_rust" +edition = "2021" version = "0.1.0" authors = ["Anshul Malik "] [dependencies] +num-bigint = { version = "0.4", optional = true } +num-traits = { version = "0.2", optional = true } +rand = "0.9" +nalgebra = "0.33.0" + +[dev-dependencies] +quickcheck = "1.0" +quickcheck_macros = "1.0" + +[features] +default = ["big-math"] +big-math = ["dep:num-bigint", "dep:num-traits"] + +[lints.clippy] +pedantic = "warn" +restriction = "warn" +nursery = "warn" +cargo = "warn" +# pedantic-lints: +cast_lossless = { level = "allow", priority = 1 } +cast_possible_truncation = { level = "allow", priority = 1 } +cast_possible_wrap = { level = "allow", priority = 1 } +cast_precision_loss = { level = "allow", priority = 1 } +cast_sign_loss = { level = "allow", priority = 1 } +cloned_instead_of_copied = { level = "allow", priority = 1 } +doc_markdown = { level = "allow", priority = 1 } +explicit_deref_methods = { level = "allow", priority = 1 } +explicit_iter_loop = { level = "allow", priority = 1 } +float_cmp = { level = "allow", priority = 1 } +if_not_else = { level = "allow", priority = 1 } +implicit_clone = { level = "allow", priority = 1 } +implicit_hasher = { level = "allow", priority = 1 } +items_after_statements = { level = "allow", priority = 1 } +iter_without_into_iter = { level = "allow", priority = 1 } +linkedlist = { level = "allow", priority = 1 } +manual_assert = { level = "allow", priority = 1 } +manual_let_else = { level = "allow", priority = 1 } +manual_string_new = { level = "allow", priority = 1 } +many_single_char_names = { level = "allow", priority = 1 } +match_on_vec_items = { level = "allow", priority = 1 } +match_wildcard_for_single_variants = { level = "allow", priority = 1 } +missing_errors_doc = { level = "allow", priority = 1 } +missing_fields_in_debug = { level = "allow", priority = 1 } 
+missing_panics_doc = { level = "allow", priority = 1 } +module_name_repetitions = { level = "allow", priority = 1 } +must_use_candidate = { level = "allow", priority = 1 } +needless_pass_by_value = { level = "allow", priority = 1 } +redundant_closure_for_method_calls = { level = "allow", priority = 1 } +return_self_not_must_use = { level = "allow", priority = 1 } +semicolon_if_nothing_returned = { level = "allow", priority = 1 } +should_panic_without_expect = { level = "allow", priority = 1 } +similar_names = { level = "allow", priority = 1 } +single_match_else = { level = "allow", priority = 1 } +stable_sort_primitive = { level = "allow", priority = 1 } +too_many_lines = { level = "allow", priority = 1 } +trivially_copy_pass_by_ref = { level = "allow", priority = 1 } +unnecessary_box_returns = { level = "allow", priority = 1 } +unnested_or_patterns = { level = "allow", priority = 1 } +unreadable_literal = { level = "allow", priority = 1 } +unused_self = { level = "allow", priority = 1 } +used_underscore_binding = { level = "allow", priority = 1 } +ref_option = { level = "allow", priority = 1 } +unnecessary_semicolon = { level = "allow", priority = 1 } +ignore_without_reason = { level = "allow", priority = 1 } +# restriction-lints: +absolute_paths = { level = "allow", priority = 1 } +arithmetic_side_effects = { level = "allow", priority = 1 } +as_conversions = { level = "allow", priority = 1 } +assertions_on_result_states = { level = "allow", priority = 1 } +blanket_clippy_restriction_lints = { level = "allow", priority = 1 } +clone_on_ref_ptr = { level = "allow", priority = 1 } +dbg_macro = { level = "allow", priority = 1 } +decimal_literal_representation = { level = "allow", priority = 1 } +default_numeric_fallback = { level = "allow", priority = 1 } +deref_by_slicing = { level = "allow", priority = 1 } +else_if_without_else = { level = "allow", priority = 1 } +exhaustive_enums = { level = "allow", priority = 1 } +exhaustive_structs = { level = "allow", priority 
= 1 } +expect_used = { level = "allow", priority = 1 } +float_arithmetic = { level = "allow", priority = 1 } +float_cmp_const = { level = "allow", priority = 1 } +if_then_some_else_none = { level = "allow", priority = 1 } +impl_trait_in_params = { level = "allow", priority = 1 } +implicit_return = { level = "allow", priority = 1 } +indexing_slicing = { level = "allow", priority = 1 } +integer_division = { level = "allow", priority = 1 } +integer_division_remainder_used = { level = "allow", priority = 1 } +iter_over_hash_type = { level = "allow", priority = 1 } +little_endian_bytes = { level = "allow", priority = 1 } +map_err_ignore = { level = "allow", priority = 1 } +min_ident_chars = { level = "allow", priority = 1 } +missing_assert_message = { level = "allow", priority = 1 } +missing_asserts_for_indexing = { level = "allow", priority = 1 } +missing_docs_in_private_items = { level = "allow", priority = 1 } +missing_inline_in_public_items = { level = "allow", priority = 1 } +missing_trait_methods = { level = "allow", priority = 1 } +mod_module_files = { level = "allow", priority = 1 } +modulo_arithmetic = { level = "allow", priority = 1 } +multiple_unsafe_ops_per_block = { level = "allow", priority = 1 } +non_ascii_literal = { level = "allow", priority = 1 } +panic = { level = "allow", priority = 1 } +partial_pub_fields = { level = "allow", priority = 1 } +pattern_type_mismatch = { level = "allow", priority = 1 } +print_stderr = { level = "allow", priority = 1 } +print_stdout = { level = "allow", priority = 1 } +pub_use = { level = "allow", priority = 1 } +pub_with_shorthand = { level = "allow", priority = 1 } +question_mark_used = { level = "allow", priority = 1 } +same_name_method = { level = "allow", priority = 1 } +semicolon_outside_block = { level = "allow", priority = 1 } +separated_literal_suffix = { level = "allow", priority = 1 } +shadow_reuse = { level = "allow", priority = 1 } +shadow_same = { level = "allow", priority = 1 } +shadow_unrelated = { level 
= "allow", priority = 1 } +single_call_fn = { level = "allow", priority = 1 } +single_char_lifetime_names = { level = "allow", priority = 1 } +std_instead_of_alloc = { level = "allow", priority = 1 } +std_instead_of_core = { level = "allow", priority = 1 } +str_to_string = { level = "allow", priority = 1 } +string_add = { level = "allow", priority = 1 } +string_slice = { level = "allow", priority = 1 } +undocumented_unsafe_blocks = { level = "allow", priority = 1 } +unnecessary_safety_comment = { level = "allow", priority = 1 } +unreachable = { level = "allow", priority = 1 } +unseparated_literal_suffix = { level = "allow", priority = 1 } +unwrap_in_result = { level = "allow", priority = 1 } +unwrap_used = { level = "allow", priority = 1 } +use_debug = { level = "allow", priority = 1 } +wildcard_enum_match_arm = { level = "allow", priority = 1 } +renamed_function_params = { level = "allow", priority = 1 } +allow_attributes_without_reason = { level = "allow", priority = 1 } +allow_attributes = { level = "allow", priority = 1 } +cfg_not_test = { level = "allow", priority = 1 } +field_scoped_visibility_modifiers = { level = "allow", priority = 1 } +unused_trait_names = { level = "allow", priority = 1 } +used_underscore_items = { level = "allow", priority = 1 } +arbitrary_source_item_ordering = { level = "allow", priority = 1 } +map_with_unused_argument_over_ranges = { level = "allow", priority = 1 } +precedence_bits = { level = "allow", priority = 1 } +redundant_test_prefix = { level = "allow", priority = 1 } +# nursery-lints: +branches_sharing_code = { level = "allow", priority = 1 } +cognitive_complexity = { level = "allow", priority = 1 } +derive_partial_eq_without_eq = { level = "allow", priority = 1 } +empty_line_after_doc_comments = { level = "allow", priority = 1 } +fallible_impl_from = { level = "allow", priority = 1 } +imprecise_flops = { level = "allow", priority = 1 } +missing_const_for_fn = { level = "allow", priority = 1 } +nonstandard_macro_braces = { 
level = "allow", priority = 1 } +option_if_let_else = { level = "allow", priority = 1 } +suboptimal_flops = { level = "allow", priority = 1 } +suspicious_operation_groupings = { level = "allow", priority = 1 } +use_self = { level = "allow", priority = 1 } +while_float = { level = "allow", priority = 1 } +too_long_first_doc_paragraph = { level = "allow", priority = 1 } +# cargo-lints: +cargo_common_metadata = { level = "allow", priority = 1 } +# style-lints: +doc_lazy_continuation = { level = "allow", priority = 1 } +needless_return = { level = "allow", priority = 1 } +doc_overindented_list_items = { level = "allow", priority = 1 } +# complexity-lints +precedence = { level = "allow", priority = 1 } +manual_div_ceil = { level = "allow", priority = 1 } diff --git a/DIRECTORY.md b/DIRECTORY.md index 4ae7ff4746e..564a7813807 100644 --- a/DIRECTORY.md +++ b/DIRECTORY.md @@ -1,46 +1,344 @@ # List of all files ## Src + * Backtracking + * [All Combination Of Size K](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/all_combination_of_size_k.rs) + * [Graph Coloring](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/graph_coloring.rs) + * [Hamiltonian Cycle](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/hamiltonian_cycle.rs) + * [Knight Tour](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/knight_tour.rs) + * [N Queens](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/n_queens.rs) + * [Parentheses Generator](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/parentheses_generator.rs) + * [Permutations](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/permutations.rs) + * [Rat In Maze](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/rat_in_maze.rs) + * [Subset Sum](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/subset_sum.rs) + * [Sudoku](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/sudoku.rs) 
+ * Big Integer + * [Fast Factorial](https://github.com/TheAlgorithms/Rust/blob/master/src/big_integer/fast_factorial.rs) + * [Multiply](https://github.com/TheAlgorithms/Rust/blob/master/src/big_integer/multiply.rs) + * [Poly1305](https://github.com/TheAlgorithms/Rust/blob/master/src/big_integer/poly1305.rs) + * Bit Manipulation + * [Counting Bits](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/counting_bits.rs) + * [Highest Set Bit](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/highest_set_bit.rs) + * [N Bits Gray Code](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/n_bits_gray_code.rs) + * [Sum Of Two Integers](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/sum_of_two_integers.rs) * Ciphers + * [Aes](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/aes.rs) + * [Another Rot13](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/another_rot13.rs) + * [Baconian Cipher](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/baconian_cipher.rs) + * [Base64](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/base64.rs) + * [Blake2B](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/blake2b.rs) * [Caesar](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/caesar.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/mod.rs) + * [Chacha](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/chacha.rs) + * [Diffie Hellman](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/diffie_hellman.rs) + * [Hashing Traits](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/hashing_traits.rs) + * [Kerninghan](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/kerninghan.rs) + * [Morse Code](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/morse_code.rs) + * [Polybius](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/polybius.rs) + * [Rail 
Fence](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/rail_fence.rs) * [Rot13](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/rot13.rs) + * [Salsa](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/salsa.rs) + * [Sha256](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/sha256.rs) + * [Sha3](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/sha3.rs) + * [Tea](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/tea.rs) + * [Theoretical Rot13](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/theoretical_rot13.rs) + * [Transposition](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/transposition.rs) * [Vigenere](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/vigenere.rs) + * [Xor](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/xor.rs) + * Compression + * [Move To Front](https://github.com/TheAlgorithms/Rust/blob/master/src/compression/move_to_front.rs) + * [Run Length Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/compression/run_length_encoding.rs) + * Conversions + * [Binary To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/binary_to_decimal.rs) + * [Binary To Hexadecimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/binary_to_hexadecimal.rs) + * [Decimal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/decimal_to_binary.rs) + * [Decimal To Hexadecimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/decimal_to_hexadecimal.rs) + * [Hexadecimal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/hexadecimal_to_binary.rs) + * [Hexadecimal To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/hexadecimal_to_decimal.rs) + * [Length Conversion](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/length_conversion.rs) + * [Octal To 
Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/octal_to_binary.rs) + * [Octal To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/octal_to_decimal.rs) + * [Rgb Cmyk Conversion](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/rgb_cmyk_conversion.rs) * Data Structures + * [Avl Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/avl_tree.rs) * [B Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/b_tree.rs) * [Binary Search Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/binary_search_tree.rs) + * [Fenwick Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/fenwick_tree.rs) + * [Floyds Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/floyds_algorithm.rs) + * [Graph](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/graph.rs) + * [Hash Table](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/hash_table.rs) * [Heap](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/heap.rs) + * [Lazy Segment Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/lazy_segment_tree.rs) * [Linked List](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/linked_list.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/mod.rs) + * Probabilistic + * [Bloom Filter](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/probabilistic/bloom_filter.rs) + * [Count Min Sketch](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/probabilistic/count_min_sketch.rs) + * [Queue](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/queue.rs) + * [Range Minimum Query](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/range_minimum_query.rs) + * [Rb 
Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/rb_tree.rs) + * [Segment Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/segment_tree.rs) + * [Segment Tree Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/segment_tree_recursive.rs) + * [Stack Using Singly Linked List](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/stack_using_singly_linked_list.rs) + * [Treap](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/treap.rs) + * [Trie](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/trie.rs) + * [Union Find](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/union_find.rs) + * [Veb Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/veb_tree.rs) * Dynamic Programming - * [Edit Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/edit_distance.rs) + * [Coin Change](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/coin_change.rs) * [Egg Dropping](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/egg_dropping.rs) * [Fibonacci](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/fibonacci.rs) + * [Fractional Knapsack](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/fractional_knapsack.rs) + * [Is Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/is_subsequence.rs) * [Knapsack](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/knapsack.rs) * [Longest Common Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_common_subsequence.rs) - * [Coin Change](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/coin_change.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/mod.rs) + * [Longest Common 
Substring](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_common_substring.rs) + * [Longest Continuous Increasing Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_continuous_increasing_subsequence.rs) + * [Longest Increasing Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_increasing_subsequence.rs) + * [Matrix Chain Multiply](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/matrix_chain_multiply.rs) + * [Maximal Square](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/maximal_square.rs) + * [Maximum Subarray](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/maximum_subarray.rs) + * [Minimum Cost Path](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/minimum_cost_path.rs) + * [Optimal Bst](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/optimal_bst.rs) + * [Rod Cutting](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/rod_cutting.rs) + * [Snail](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/snail.rs) + * [Subset Generation](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/subset_generation.rs) + * [Trapped Rainwater](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/trapped_rainwater.rs) + * [Word Break](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/word_break.rs) + * Financial + * [Present Value](https://github.com/TheAlgorithms/Rust/blob/master/src/financial/present_value.rs) * General * [Convex Hull](https://github.com/TheAlgorithms/Rust/blob/master/src/general/convex_hull.rs) + * [Fisher Yates Shuffle](https://github.com/TheAlgorithms/Rust/blob/master/src/general/fisher_yates_shuffle.rs) + * [Genetic](https://github.com/TheAlgorithms/Rust/blob/master/src/general/genetic.rs) * 
[Hanoi](https://github.com/TheAlgorithms/Rust/blob/master/src/general/hanoi.rs) + * [Huffman Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/general/huffman_encoding.rs) + * [Kadane Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/general/kadane_algorithm.rs) * [Kmeans](https://github.com/TheAlgorithms/Rust/blob/master/src/general/kmeans.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/general/mod.rs) + * [Mex](https://github.com/TheAlgorithms/Rust/blob/master/src/general/mex.rs) + * Permutations + * [Heap](https://github.com/TheAlgorithms/Rust/blob/master/src/general/permutations/heap.rs) + * [Naive](https://github.com/TheAlgorithms/Rust/blob/master/src/general/permutations/naive.rs) + * [Steinhaus Johnson Trotter](https://github.com/TheAlgorithms/Rust/blob/master/src/general/permutations/steinhaus_johnson_trotter.rs) + * [Two Sum](https://github.com/TheAlgorithms/Rust/blob/master/src/general/two_sum.rs) + * Geometry + * [Closest Points](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/closest_points.rs) + * [Graham Scan](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/graham_scan.rs) + * [Jarvis Scan](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/jarvis_scan.rs) + * [Point](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/point.rs) + * [Polygon Points](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/polygon_points.rs) + * [Ramer Douglas Peucker](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/ramer_douglas_peucker.rs) + * [Segment](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/segment.rs) + * Graph + * [Astar](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/astar.rs) + * [Bellman Ford](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/bellman_ford.rs) + * [Bipartite Matching](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/bipartite_matching.rs) + * [Breadth First 
Search](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/breadth_first_search.rs) + * [Centroid Decomposition](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/centroid_decomposition.rs) + * [Decremental Connectivity](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/decremental_connectivity.rs) + * [Depth First Search](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/depth_first_search.rs) + * [Depth First Search Tic Tac Toe](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/depth_first_search_tic_tac_toe.rs) + * [Detect Cycle](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/detect_cycle.rs) + * [Dijkstra](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/dijkstra.rs) + * [Dinic Maxflow](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/dinic_maxflow.rs) + * [Disjoint Set Union](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/disjoint_set_union.rs) + * [Eulerian Path](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/eulerian_path.rs) + * [Floyd Warshall](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/floyd_warshall.rs) + * [Ford Fulkerson](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/ford_fulkerson.rs) + * [Graph Enumeration](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/graph_enumeration.rs) + * [Heavy Light Decomposition](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/heavy_light_decomposition.rs) + * [Kosaraju](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/kosaraju.rs) + * [Lee Breadth First Search](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/lee_breadth_first_search.rs) + * [Lowest Common Ancestor](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/lowest_common_ancestor.rs) + * [Minimum Spanning Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/minimum_spanning_tree.rs) + * 
[Prim](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/prim.rs) + * [Prufer Code](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/prufer_code.rs) + * [Strongly Connected Components](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/strongly_connected_components.rs) + * [Tarjans Ssc](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/tarjans_ssc.rs) + * [Topological Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/topological_sort.rs) + * [Two Satisfiability](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/two_satisfiability.rs) + * Greedy + * [Stable Matching](https://github.com/TheAlgorithms/Rust/blob/master/src/greedy/stable_matching.rs) * [Lib](https://github.com/TheAlgorithms/Rust/blob/master/src/lib.rs) + * Machine Learning + * [Cholesky](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/cholesky.rs) + * [K Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs) + * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs) + * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs) + * Loss Function + * [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs) + * [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs) + * [Huber Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/huber_loss.rs) + * [Kl Divergence Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/kl_divergence_loss.rs) + * [Mean Absolute Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_absolute_error_loss.rs) + * [Mean Squared Error 
Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_squared_error_loss.rs) + * [Negative Log Likelihood](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/negative_log_likelihood.rs) + * Optimization + * [Adam](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/adam.rs) + * [Gradient Descent](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/gradient_descent.rs) + * Math + * [Abs](https://github.com/TheAlgorithms/Rust/blob/master/src/math/abs.rs) + * [Aliquot Sum](https://github.com/TheAlgorithms/Rust/blob/master/src/math/aliquot_sum.rs) + * [Amicable Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/amicable_numbers.rs) + * [Area Of Polygon](https://github.com/TheAlgorithms/Rust/blob/master/src/math/area_of_polygon.rs) + * [Area Under Curve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/area_under_curve.rs) + * [Armstrong Number](https://github.com/TheAlgorithms/Rust/blob/master/src/math/armstrong_number.rs) + * [Average](https://github.com/TheAlgorithms/Rust/blob/master/src/math/average.rs) + * [Baby Step Giant Step](https://github.com/TheAlgorithms/Rust/blob/master/src/math/baby_step_giant_step.rs) + * [Bell Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/bell_numbers.rs) + * [Binary Exponentiation](https://github.com/TheAlgorithms/Rust/blob/master/src/math/binary_exponentiation.rs) + * [Binomial Coefficient](https://github.com/TheAlgorithms/Rust/blob/master/src/math/binomial_coefficient.rs) + * [Catalan Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/catalan_numbers.rs) + * [Ceil](https://github.com/TheAlgorithms/Rust/blob/master/src/math/ceil.rs) + * [Chinese Remainder Theorem](https://github.com/TheAlgorithms/Rust/blob/master/src/math/chinese_remainder_theorem.rs) + * [Collatz 
Sequence](https://github.com/TheAlgorithms/Rust/blob/master/src/math/collatz_sequence.rs) + * [Combinations](https://github.com/TheAlgorithms/Rust/blob/master/src/math/combinations.rs) + * [Cross Entropy Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/math/cross_entropy_loss.rs) + * [Decimal To Fraction](https://github.com/TheAlgorithms/Rust/blob/master/src/math/decimal_to_fraction.rs) + * [Doomsday](https://github.com/TheAlgorithms/Rust/blob/master/src/math/doomsday.rs) + * [Elliptic Curve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/elliptic_curve.rs) + * [Euclidean Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/math/euclidean_distance.rs) + * [Exponential Linear Unit](https://github.com/TheAlgorithms/Rust/blob/master/src/math/exponential_linear_unit.rs) + * [Extended Euclidean Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/extended_euclidean_algorithm.rs) + * [Factorial](https://github.com/TheAlgorithms/Rust/blob/master/src/math/factorial.rs) + * [Factors](https://github.com/TheAlgorithms/Rust/blob/master/src/math/factors.rs) + * [Fast Fourier Transform](https://github.com/TheAlgorithms/Rust/blob/master/src/math/fast_fourier_transform.rs) + * [Fast Power](https://github.com/TheAlgorithms/Rust/blob/master/src/math/fast_power.rs) + * [Faster Perfect Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/faster_perfect_numbers.rs) + * [Field](https://github.com/TheAlgorithms/Rust/blob/master/src/math/field.rs) + * [Frizzy Number](https://github.com/TheAlgorithms/Rust/blob/master/src/math/frizzy_number.rs) + * [Gaussian Elimination](https://github.com/TheAlgorithms/Rust/blob/master/src/math/gaussian_elimination.rs) + * [Gaussian Error Linear Unit](https://github.com/TheAlgorithms/Rust/blob/master/src/math/gaussian_error_linear_unit.rs) + * [Gcd Of N Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/gcd_of_n_numbers.rs) + * [Geometric 
Series](https://github.com/TheAlgorithms/Rust/blob/master/src/math/geometric_series.rs) + * [Greatest Common Divisor](https://github.com/TheAlgorithms/Rust/blob/master/src/math/greatest_common_divisor.rs) + * [Huber Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/math/huber_loss.rs) + * [Infix To Postfix](https://github.com/TheAlgorithms/Rust/blob/master/src/math/infix_to_postfix.rs) + * [Interest](https://github.com/TheAlgorithms/Rust/blob/master/src/math/interest.rs) + * [Interpolation](https://github.com/TheAlgorithms/Rust/blob/master/src/math/interpolation.rs) + * [Interquartile Range](https://github.com/TheAlgorithms/Rust/blob/master/src/math/interquartile_range.rs) + * [Karatsuba Multiplication](https://github.com/TheAlgorithms/Rust/blob/master/src/math/karatsuba_multiplication.rs) + * [Lcm Of N Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/lcm_of_n_numbers.rs) + * [Leaky Relu](https://github.com/TheAlgorithms/Rust/blob/master/src/math/leaky_relu.rs) + * [Least Square Approx](https://github.com/TheAlgorithms/Rust/blob/master/src/math/least_square_approx.rs) + * [Linear Sieve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/linear_sieve.rs) + * [Logarithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/logarithm.rs) + * [Lucas Series](https://github.com/TheAlgorithms/Rust/blob/master/src/math/lucas_series.rs) + * [Matrix Ops](https://github.com/TheAlgorithms/Rust/blob/master/src/math/matrix_ops.rs) + * [Mersenne Primes](https://github.com/TheAlgorithms/Rust/blob/master/src/math/mersenne_primes.rs) + * [Miller Rabin](https://github.com/TheAlgorithms/Rust/blob/master/src/math/miller_rabin.rs) + * [Modular Exponential](https://github.com/TheAlgorithms/Rust/blob/master/src/math/modular_exponential.rs) + * [Newton Raphson](https://github.com/TheAlgorithms/Rust/blob/master/src/math/newton_raphson.rs) + * [Nthprime](https://github.com/TheAlgorithms/Rust/blob/master/src/math/nthprime.rs) + * [Pascal 
Triangle](https://github.com/TheAlgorithms/Rust/blob/master/src/math/pascal_triangle.rs) + * [Perfect Cube](https://github.com/TheAlgorithms/Rust/blob/master/src/math/perfect_cube.rs) + * [Perfect Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/perfect_numbers.rs) + * [Perfect Square](https://github.com/TheAlgorithms/Rust/blob/master/src/math/perfect_square.rs) + * [Pollard Rho](https://github.com/TheAlgorithms/Rust/blob/master/src/math/pollard_rho.rs) + * [Postfix Evaluation](https://github.com/TheAlgorithms/Rust/blob/master/src/math/postfix_evaluation.rs) + * [Prime Check](https://github.com/TheAlgorithms/Rust/blob/master/src/math/prime_check.rs) + * [Prime Factors](https://github.com/TheAlgorithms/Rust/blob/master/src/math/prime_factors.rs) + * [Prime Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/prime_numbers.rs) + * [Quadratic Residue](https://github.com/TheAlgorithms/Rust/blob/master/src/math/quadratic_residue.rs) + * [Random](https://github.com/TheAlgorithms/Rust/blob/master/src/math/random.rs) + * [Relu](https://github.com/TheAlgorithms/Rust/blob/master/src/math/relu.rs) + * [Sieve Of Eratosthenes](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sieve_of_eratosthenes.rs) + * [Sigmoid](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sigmoid.rs) + * [Signum](https://github.com/TheAlgorithms/Rust/blob/master/src/math/signum.rs) + * [Simpsons Integration](https://github.com/TheAlgorithms/Rust/blob/master/src/math/simpsons_integration.rs) + * [Softmax](https://github.com/TheAlgorithms/Rust/blob/master/src/math/softmax.rs) + * [Sprague Grundy Theorem](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sprague_grundy_theorem.rs) + * [Square Pyramidal Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/square_pyramidal_numbers.rs) + * [Square Root](https://github.com/TheAlgorithms/Rust/blob/master/src/math/square_root.rs) + * [Sum Of 
Digits](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sum_of_digits.rs) + * [Sum Of Geometric Progression](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sum_of_geometric_progression.rs) + * [Sum Of Harmonic Series](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sum_of_harmonic_series.rs) + * [Sylvester Sequence](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sylvester_sequence.rs) + * [Tanh](https://github.com/TheAlgorithms/Rust/blob/master/src/math/tanh.rs) + * [Trapezoidal Integration](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trapezoidal_integration.rs) + * [Trial Division](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trial_division.rs) + * [Trig Functions](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trig_functions.rs) + * [Vector Cross Product](https://github.com/TheAlgorithms/Rust/blob/master/src/math/vector_cross_product.rs) + * [Zellers Congruence Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/zellers_congruence_algorithm.rs) + * Navigation + * [Bearing](https://github.com/TheAlgorithms/Rust/blob/master/src/navigation/bearing.rs) + * [Haversine](https://github.com/TheAlgorithms/Rust/blob/master/src/navigation/haversine.rs) + * Number Theory + * [Compute Totient](https://github.com/TheAlgorithms/Rust/blob/master/src/number_theory/compute_totient.rs) + * [Euler Totient](https://github.com/TheAlgorithms/Rust/blob/master/src/number_theory/euler_totient.rs) + * [Kth Factor](https://github.com/TheAlgorithms/Rust/blob/master/src/number_theory/kth_factor.rs) * Searching * [Binary Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/binary_search.rs) + * [Binary Search Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/binary_search_recursive.rs) + * [Exponential Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/exponential_search.rs) + * [Fibonacci 
Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/fibonacci_search.rs) + * [Interpolation Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/interpolation_search.rs) + * [Jump Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/jump_search.rs) + * [Kth Smallest](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/kth_smallest.rs) + * [Kth Smallest Heap](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/kth_smallest_heap.rs) * [Linear Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/linear_search.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/mod.rs) + * [Moore Voting](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/moore_voting.rs) + * [Quick Select](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/quick_select.rs) + * [Saddleback Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/saddleback_search.rs) + * [Ternary Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search.rs) + * [Ternary Search Min Max](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_min_max.rs) + * [Ternary Search Min Max Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_min_max_recursive.rs) + * [Ternary Search Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_recursive.rs) * Sorting + * [Bead Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bead_sort.rs) + * [Binary Insertion Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/binary_insertion_sort.rs) + * [Bingo Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bingo_sort.rs) + * [Bitonic Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bitonic_sort.rs) + * [Bogo Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bogo_sort.rs) 
* [Bubble Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bubble_sort.rs) + * [Bucket Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bucket_sort.rs) + * [Cocktail Shaker Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/cocktail_shaker_sort.rs) + * [Comb Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/comb_sort.rs) * [Counting Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/counting_sort.rs) + * [Cycle Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/cycle_sort.rs) + * [Dutch National Flag Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/dutch_national_flag_sort.rs) + * [Exchange Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/exchange_sort.rs) + * [Gnome Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/gnome_sort.rs) * [Heap Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/heap_sort.rs) - * [Insertion](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/insertion.rs) + * [Insertion Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/insertion_sort.rs) + * [Intro Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/intro_sort.rs) * [Merge Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/merge_sort.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/mod.rs) + * [Odd Even Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/odd_even_sort.rs) + * [Pancake Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/pancake_sort.rs) + * [Patience Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/patience_sort.rs) + * [Pigeonhole Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/pigeonhole_sort.rs) * [Quick Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/quick_sort.rs) + * [Quick Sort 
3_ways](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/quick_sort_3_ways.rs) * [Radix Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/radix_sort.rs) * [Selection Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/selection_sort.rs) * [Shell Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/shell_sort.rs) + * [Sleep Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/sleep_sort.rs) + * [Sort Utils](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/sort_utils.rs) + * [Stooge Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/stooge_sort.rs) + * [Tim Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/tim_sort.rs) + * [Tree Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/tree_sort.rs) + * [Wave Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/wave_sort.rs) + * [Wiggle Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/wiggle_sort.rs) * String + * [Aho Corasick](https://github.com/TheAlgorithms/Rust/blob/master/src/string/aho_corasick.rs) + * [Anagram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/anagram.rs) + * [Autocomplete Using Trie](https://github.com/TheAlgorithms/Rust/blob/master/src/string/autocomplete_using_trie.rs) + * [Boyer Moore Search](https://github.com/TheAlgorithms/Rust/blob/master/src/string/boyer_moore_search.rs) + * [Burrows Wheeler Transform](https://github.com/TheAlgorithms/Rust/blob/master/src/string/burrows_wheeler_transform.rs) + * [Duval Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/string/duval_algorithm.rs) + * [Hamming Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/hamming_distance.rs) + * [Isogram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/isogram.rs) + * [Isomorphism](https://github.com/TheAlgorithms/Rust/blob/master/src/string/isomorphism.rs) + * [Jaro Winkler 
Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/jaro_winkler_distance.rs) * [Knuth Morris Pratt](https://github.com/TheAlgorithms/Rust/blob/master/src/string/knuth_morris_pratt.rs) - * [Mod](https://github.com/TheAlgorithms/Rust/blob/master/src/string/mod.rs) + * [Levenshtein Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/levenshtein_distance.rs) + * [Lipogram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/lipogram.rs) + * [Manacher](https://github.com/TheAlgorithms/Rust/blob/master/src/string/manacher.rs) + * [Palindrome](https://github.com/TheAlgorithms/Rust/blob/master/src/string/palindrome.rs) + * [Pangram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/pangram.rs) + * [Rabin Karp](https://github.com/TheAlgorithms/Rust/blob/master/src/string/rabin_karp.rs) + * [Reverse](https://github.com/TheAlgorithms/Rust/blob/master/src/string/reverse.rs) + * [Run Length Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/string/run_length_encoding.rs) + * [Shortest Palindrome](https://github.com/TheAlgorithms/Rust/blob/master/src/string/shortest_palindrome.rs) + * [Suffix Array](https://github.com/TheAlgorithms/Rust/blob/master/src/string/suffix_array.rs) + * [Suffix Array Manber Myers](https://github.com/TheAlgorithms/Rust/blob/master/src/string/suffix_array_manber_myers.rs) + * [Suffix Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/string/suffix_tree.rs) + * [Z Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/string/z_algorithm.rs) diff --git a/README.md b/README.md index 872bdf02730..fcab70096eb 100644 --- a/README.md +++ b/README.md @@ -1,83 +1,32 @@ -# The Algorithms - Rust [![Gitter](https://img.shields.io/gitter/room/the-algorithms/rust.svg?style=flat-square)](https://gitter.im/the-algorithms/rust) [![Build Status](https://travis-ci.com/TheAlgorithms/Rust.svg?branch=master)](https://travis-ci.com/TheAlgorithms/Rust) +
+ + + +

The Algorithms - Rust

+ + + + Gitpod Ready-to-Code + + + Build workflow + + + + + + Discord community + + + Gitter chat + + + +

All algorithms implemented in Rust - for education

+
+ +### List of Algorithms +See our [directory](DIRECTORY.md) for easier navigation and a better overview of the project. - - - - -### All algorithms implemented in Rust (for educational purposes) -These are for demonstration purposes only. - -## [Sort Algorithms](./src/sorting) - -- [x] [Bubble](./src/sorting/bubble_sort.rs) -- [x] [Counting](./src/sorting/counting_sort.rs) -- [x] [Heap](./src/sorting/heap_sort.rs) -- [x] [Insertion](./src/sorting/insertion_sort.rs) -- [x] [Merge](./src/sorting/merge_sort.rs) -- [x] [Quick](./src/sorting/quick_sort.rs) -- [x] [Radix](./src/sorting/radix_sort.rs) -- [x] [Selection](./src/sorting/selection_sort.rs) -- [x] [Shell](./src/sorting/shell_sort.rs) - -## Graphs - -- [ ] Dijkstra -- [ ] Kruskal's Minimum Spanning Tree -- [ ] Prim's Minimum Spanning Tree -- [ ] BFS -- [ ] DFS - -## [Dynamic Programming](./src/dynamic_programming) - -- [x] [0-1 Knapsack](./src/dynamic_programming/knapsack.rs) -- [x] [Edit Distance](./src/dynamic_programming/edit_distance.rs) -- [x] [Longest common subsequence](./src/dynamic_programming/longest_common_subsequence.rs) -- [ ] Longest increasing subsequence -- [x] [K-Means Clustering](./src/general/kmeans.rs) -- [ ] Coin Change -- [ ] Rod cut -- [x] [Egg Dropping Puzzle](./src/dynamic_programming/egg_dropping.rs) - -## [Data Structures](./src/data_structures) - -- [ ] Queue -- [x] [Heap](./src/data_structures/heap.rs) -- [x] [Linked List](./src/data_structures/linked_list.rs) -- Graph - - [ ] Directed - - [ ] Undirected -- [ ] Trie -- [x] [Binary Search Tree](./src/data_structures/binary_search_tree.rs) -- [x] [B-Tree](./src/data_structures/b_tree.rs) -- [ ] AVL Tree - -## [Strings](./src/string) - -- [x] [Knuth Morris Pratt](./src/string/knuth_morris_pratt.rs) -- [ ] Rabin Carp - -## [General](./src/general) - -- [x] [Convex Hull: Graham Scan](./src/general/convex_hull.rs) -- [ ] N-Queensp -- [ ] Graph Coloringp -- [x] [Tower of Hanoi](./src/general/hanoi.rs) - -## [Search 
Algorithms](./src/searching) - -- [x] [Linear](./src/searching/linear_search.rs) -- [x] [Binary](./src/searching/binary_search.rs) - -## [Ciphers](./src/ciphers) - -- [x] [Caesar](./src/ciphers/caesar.rs) -- [x] [Vigenère](./src/ciphers/vigenere.rs) -- [ ] Transposition - ---- - -### All implemented Algos -See [DIRECTORY.md](./DIRECTORY.md) ### Contributing - -See [CONTRIBUTING.md](CONTRIBUTING.md) +Read through our [Contribution Guidelines](CONTRIBUTING.md) before you contribute. diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000000..1b3dd21fbf7 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,4 @@ +allowed-duplicate-crates = [ + "zerocopy", + "zerocopy-derive", +] diff --git a/src/backtracking/all_combination_of_size_k.rs b/src/backtracking/all_combination_of_size_k.rs new file mode 100644 index 00000000000..65b6b643b97 --- /dev/null +++ b/src/backtracking/all_combination_of_size_k.rs @@ -0,0 +1,123 @@ +//! This module provides a function to generate all possible combinations +//! of `k` numbers out of `0...n-1` using a backtracking algorithm. + +/// Custom error type for combination generation. +#[derive(Debug, PartialEq)] +pub enum CombinationError { + KGreaterThanN, + InvalidZeroRange, +} + +/// Generates all possible combinations of `k` numbers out of `0...n-1`. +/// +/// # Arguments +/// +/// * `n` - The upper limit of the range (`0` to `n-1`). +/// * `k` - The number of elements in each combination. +/// +/// # Returns +/// +/// A `Result` containing a vector with all possible combinations of `k` numbers out of `0...n-1`, +/// or a `CombinationError` if the input is invalid. 
+pub fn generate_all_combinations(n: usize, k: usize) -> Result>, CombinationError> { + if n == 0 && k > 0 { + return Err(CombinationError::InvalidZeroRange); + } + + if k > n { + return Err(CombinationError::KGreaterThanN); + } + + let mut combinations = vec![]; + let mut current = vec![0; k]; + backtrack(0, n, k, 0, &mut current, &mut combinations); + Ok(combinations) +} + +/// Helper function to generate combinations recursively. +/// +/// # Arguments +/// +/// * `start` - The current number to start the combination with. +/// * `n` - The upper limit of the range (`0` to `n-1`). +/// * `k` - The number of elements left to complete the combination. +/// * `index` - The current index being filled in the combination. +/// * `current` - A mutable reference to the current combination being constructed. +/// * `combinations` - A mutable reference to the vector holding all combinations. +fn backtrack( + start: usize, + n: usize, + k: usize, + index: usize, + current: &mut Vec, + combinations: &mut Vec>, +) { + if index == k { + combinations.push(current.clone()); + return; + } + + for num in start..=(n - k + index) { + current[index] = num; + backtrack(num + 1, n, k, index + 1, current, combinations); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! combination_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (n, k, expected) = $test_case; + assert_eq!(generate_all_combinations(n, k), expected); + } + )* + } + } + + combination_tests! 
{ + test_generate_4_2: (4, 2, Ok(vec![ + vec![0, 1], + vec![0, 2], + vec![0, 3], + vec![1, 2], + vec![1, 3], + vec![2, 3], + ])), + test_generate_4_3: (4, 3, Ok(vec![ + vec![0, 1, 2], + vec![0, 1, 3], + vec![0, 2, 3], + vec![1, 2, 3], + ])), + test_generate_5_3: (5, 3, Ok(vec![ + vec![0, 1, 2], + vec![0, 1, 3], + vec![0, 1, 4], + vec![0, 2, 3], + vec![0, 2, 4], + vec![0, 3, 4], + vec![1, 2, 3], + vec![1, 2, 4], + vec![1, 3, 4], + vec![2, 3, 4], + ])), + test_generate_5_1: (5, 1, Ok(vec![ + vec![0], + vec![1], + vec![2], + vec![3], + vec![4], + ])), + test_empty: (0, 0, Ok(vec![vec![]])), + test_generate_n_eq_k: (3, 3, Ok(vec![ + vec![0, 1, 2], + ])), + test_generate_k_greater_than_n: (3, 4, Err(CombinationError::KGreaterThanN)), + test_zero_range_with_nonzero_k: (0, 1, Err(CombinationError::InvalidZeroRange)), + } +} diff --git a/src/backtracking/graph_coloring.rs b/src/backtracking/graph_coloring.rs new file mode 100644 index 00000000000..40bdf398e91 --- /dev/null +++ b/src/backtracking/graph_coloring.rs @@ -0,0 +1,370 @@ +//! This module provides functionality for generating all possible colorings of a undirected (or directed) graph +//! given a certain number of colors. It includes the GraphColoring struct and methods +//! for validating color assignments and finding all valid colorings. + +/// Represents potential errors when coloring on an adjacency matrix. +#[derive(Debug, PartialEq, Eq)] +pub enum GraphColoringError { + // Indicates that the adjacency matrix is empty + EmptyAdjacencyMatrix, + // Indicates that the adjacency matrix is not squared + ImproperAdjacencyMatrix, +} + +/// Generates all possible valid colorings of a graph. +/// +/// # Arguments +/// +/// * `adjacency_matrix` - A 2D vector representing the adjacency matrix of the graph. +/// * `num_colors` - The number of colors available for coloring the graph. 
+/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a vector of solutions or a `GraphColoringError` if +/// there is an issue with the matrix. +pub fn generate_colorings( + adjacency_matrix: Vec>, + num_colors: usize, +) -> Result>>, GraphColoringError> { + Ok(GraphColoring::new(adjacency_matrix)?.find_solutions(num_colors)) +} + +/// A struct representing a graph coloring problem. +struct GraphColoring { + // The adjacency matrix of the graph + adjacency_matrix: Vec>, + // The current colors assigned to each vertex + vertex_colors: Vec, + // Vector of all valid color assignments for the vertices found during the search + solutions: Vec>, +} + +impl GraphColoring { + /// Creates a new GraphColoring instance. + /// + /// # Arguments + /// + /// * `adjacency_matrix` - A 2D vector representing the adjacency matrix of the graph. + /// + /// # Returns + /// + /// * A new instance of GraphColoring or a `GraphColoringError` if the matrix is empty or non-square. + fn new(adjacency_matrix: Vec>) -> Result { + let num_vertices = adjacency_matrix.len(); + + // Check if the adjacency matrix is empty + if num_vertices == 0 { + return Err(GraphColoringError::EmptyAdjacencyMatrix); + } + + // Check if the adjacency matrix is square + if adjacency_matrix.iter().any(|row| row.len() != num_vertices) { + return Err(GraphColoringError::ImproperAdjacencyMatrix); + } + + Ok(GraphColoring { + adjacency_matrix, + vertex_colors: vec![usize::MAX; num_vertices], + solutions: Vec::new(), + }) + } + + /// Returns the number of vertices in the graph. + fn num_vertices(&self) -> usize { + self.adjacency_matrix.len() + } + + /// Checks if a given color can be assigned to a vertex without causing a conflict. + /// + /// # Arguments + /// + /// * `vertex` - The index of the vertex to be colored. + /// * `color` - The color to be assigned to the vertex. + /// + /// # Returns + /// + /// * `true` if the color can be assigned to the vertex, `false` otherwise. 
+ fn is_color_valid(&self, vertex: usize, color: usize) -> bool { + for neighbor in 0..self.num_vertices() { + // Check outgoing edges from vertex and incoming edges to vertex + if (self.adjacency_matrix[vertex][neighbor] || self.adjacency_matrix[neighbor][vertex]) + && self.vertex_colors[neighbor] == color + { + return false; + } + } + true + } + + /// Recursively finds all valid colorings for the graph. + /// + /// # Arguments + /// + /// * `vertex` - The current vertex to be colored. + /// * `num_colors` - The number of colors available for coloring the graph. + fn find_colorings(&mut self, vertex: usize, num_colors: usize) { + if vertex == self.num_vertices() { + self.solutions.push(self.vertex_colors.clone()); + return; + } + + for color in 0..num_colors { + if self.is_color_valid(vertex, color) { + self.vertex_colors[vertex] = color; + self.find_colorings(vertex + 1, num_colors); + self.vertex_colors[vertex] = usize::MAX; + } + } + } + + /// Finds all solutions for the graph coloring problem. + /// + /// # Arguments + /// + /// * `num_colors` - The number of colors available for coloring the graph. + /// + /// # Returns + /// + /// * A `Result` containing an `Option` with a vector of solutions or a `GraphColoringError`. + fn find_solutions(&mut self, num_colors: usize) -> Option>> { + self.find_colorings(0, num_colors); + if self.solutions.is_empty() { + None + } else { + Some(std::mem::take(&mut self.solutions)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_graph_coloring { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (adjacency_matrix, num_colors, expected) = $test_case; + let actual = generate_colorings(adjacency_matrix, num_colors); + assert_eq!(actual, expected); + } + )* + }; + } + + test_graph_coloring! 
{ + test_complete_graph_with_3_colors: ( + vec![ + vec![false, true, true, true], + vec![true, false, true, false], + vec![true, true, false, true], + vec![true, false, true, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 1], + vec![0, 2, 1, 2], + vec![1, 0, 2, 0], + vec![1, 2, 0, 2], + vec![2, 0, 1, 0], + vec![2, 1, 0, 1], + ])) + ), + test_linear_graph_with_2_colors: ( + vec![ + vec![false, true, false, false], + vec![true, false, true, false], + vec![false, true, false, true], + vec![false, false, true, false], + ], + 2, + Ok(Some(vec![ + vec![0, 1, 0, 1], + vec![1, 0, 1, 0], + ])) + ), + test_incomplete_graph_with_insufficient_colors: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 1, + Ok(None::>>) + ), + test_empty_graph: ( + vec![], + 1, + Err(GraphColoringError::EmptyAdjacencyMatrix) + ), + test_non_square_matrix: ( + vec![ + vec![false, true, true], + vec![true, false, true], + ], + 3, + Err(GraphColoringError::ImproperAdjacencyMatrix) + ), + test_single_vertex_graph: ( + vec![ + vec![false], + ], + 1, + Ok(Some(vec![ + vec![0], + ])) + ), + test_bipartite_graph_with_2_colors: ( + vec![ + vec![false, true, false, true], + vec![true, false, true, false], + vec![false, true, false, true], + vec![true, false, true, false], + ], + 2, + Ok(Some(vec![ + vec![0, 1, 0, 1], + vec![1, 0, 1, 0], + ])) + ), + test_large_graph_with_3_colors: ( + vec![ + vec![false, true, true, false, true, true, false, true, true, false], + vec![true, false, true, true, false, true, true, false, true, true], + vec![true, true, false, true, true, false, true, true, false, true], + vec![false, true, true, false, true, true, false, true, true, false], + vec![true, false, true, true, false, true, true, false, true, true], + vec![true, true, false, true, true, false, true, true, false, true], + vec![false, true, true, false, true, true, false, true, true, false], + vec![true, false, true, true, false, true, true, false, true, true], + 
vec![true, true, false, true, true, false, true, true, false, true], + vec![false, true, true, false, true, true, false, true, true, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0], + vec![0, 2, 1, 0, 2, 1, 0, 2, 1, 0], + vec![1, 0, 2, 1, 0, 2, 1, 0, 2, 1], + vec![1, 2, 0, 1, 2, 0, 1, 2, 0, 1], + vec![2, 0, 1, 2, 0, 1, 2, 0, 1, 2], + vec![2, 1, 0, 2, 1, 0, 2, 1, 0, 2], + ])) + ), + test_disconnected_graph: ( + vec![ + vec![false, false, false], + vec![false, false, false], + vec![false, false, false], + ], + 2, + Ok(Some(vec![ + vec![0, 0, 0], + vec![0, 0, 1], + vec![0, 1, 0], + vec![0, 1, 1], + vec![1, 0, 0], + vec![1, 0, 1], + vec![1, 1, 0], + vec![1, 1, 1], + ])) + ), + test_no_valid_coloring: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 2, + Ok(None::>>) + ), + test_more_colors_than_nodes: ( + vec![ + vec![true, true], + vec![true, true], + ], + 3, + Ok(Some(vec![ + vec![0, 1], + vec![0, 2], + vec![1, 0], + vec![1, 2], + vec![2, 0], + vec![2, 1], + ])) + ), + test_no_coloring_with_zero_colors: ( + vec![ + vec![true], + ], + 0, + Ok(None::>>) + ), + test_complete_graph_with_3_vertices_and_3_colors: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2], + vec![0, 2, 1], + vec![1, 0, 2], + vec![1, 2, 0], + vec![2, 0, 1], + vec![2, 1, 0], + ])) + ), + test_directed_graph_with_3_colors: ( + vec![ + vec![false, true, false, true], + vec![false, false, true, false], + vec![true, false, false, true], + vec![true, false, false, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 1], + vec![0, 2, 1, 2], + vec![1, 0, 2, 0], + vec![1, 2, 0, 2], + vec![2, 0, 1, 0], + vec![2, 1, 0, 1], + ])) + ), + test_directed_graph_no_valid_coloring: ( + vec![ + vec![false, true, false, true], + vec![false, false, true, true], + vec![true, false, false, true], + vec![true, false, false, false], + ], + 3, + Ok(None::>>) + ), + 
test_large_directed_graph_with_3_colors: ( + vec![ + vec![false, true, false, false, true, false, false, true, false, false], + vec![false, false, true, false, false, true, false, false, true, false], + vec![false, false, false, true, false, false, true, false, false, true], + vec![true, false, false, false, true, false, false, true, false, false], + vec![false, true, false, false, false, true, false, false, true, false], + vec![false, false, true, false, false, false, true, false, false, true], + vec![true, false, false, false, true, false, false, true, false, false], + vec![false, true, false, false, false, true, false, false, true, false], + vec![false, false, true, false, false, false, true, false, false, true], + vec![true, false, false, false, true, false, false, true, false, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 1, 2, 0, 1, 2, 0, 1], + vec![0, 2, 1, 2, 1, 0, 2, 1, 0, 2], + vec![1, 0, 2, 0, 2, 1, 0, 2, 1, 0], + vec![1, 2, 0, 2, 0, 1, 2, 0, 1, 2], + vec![2, 0, 1, 0, 1, 2, 0, 1, 2, 0], + vec![2, 1, 0, 1, 0, 2, 1, 0, 2, 1] + ])) + ), + } +} diff --git a/src/backtracking/hamiltonian_cycle.rs b/src/backtracking/hamiltonian_cycle.rs new file mode 100644 index 00000000000..2eacd5feb3c --- /dev/null +++ b/src/backtracking/hamiltonian_cycle.rs @@ -0,0 +1,310 @@ +//! This module provides functionality to find a Hamiltonian cycle in a directed or undirected graph. +//! Source: [Wikipedia](https://en.wikipedia.org/wiki/Hamiltonian_path_problem) + +/// Represents potential errors when finding hamiltonian cycle on an adjacency matrix. +#[derive(Debug, PartialEq, Eq)] +pub enum FindHamiltonianCycleError { + /// Indicates that the adjacency matrix is empty. + EmptyAdjacencyMatrix, + /// Indicates that the adjacency matrix is not square. + ImproperAdjacencyMatrix, + /// Indicates that the starting vertex is out of bounds. + StartOutOfBound, +} + +/// Represents a graph using an adjacency matrix. +struct Graph { + /// The adjacency matrix representing the graph. 
+ adjacency_matrix: Vec>, +} + +impl Graph { + /// Creates a new graph with the provided adjacency matrix. + /// + /// # Arguments + /// + /// * `adjacency_matrix` - A square matrix where each element indicates + /// the presence (`true`) or absence (`false`) of an edge + /// between two vertices. + /// + /// # Returns + /// + /// A `Result` containing the graph if successful, or an `FindHamiltonianCycleError` if there is an issue with the matrix. + fn new(adjacency_matrix: Vec>) -> Result { + // Check if the adjacency matrix is empty. + if adjacency_matrix.is_empty() { + return Err(FindHamiltonianCycleError::EmptyAdjacencyMatrix); + } + + // Validate that the adjacency matrix is square. + if adjacency_matrix + .iter() + .any(|row| row.len() != adjacency_matrix.len()) + { + return Err(FindHamiltonianCycleError::ImproperAdjacencyMatrix); + } + + Ok(Self { adjacency_matrix }) + } + + /// Returns the number of vertices in the graph. + fn num_vertices(&self) -> usize { + self.adjacency_matrix.len() + } + + /// Determines if it is safe to include vertex `v` in the Hamiltonian cycle path. + /// + /// # Arguments + /// + /// * `v` - The index of the vertex being considered. + /// * `visited` - A reference to the vector representing the visited vertices. + /// * `path` - A reference to the current path being explored. + /// * `pos` - The position of the current vertex being considered. + /// + /// # Returns + /// + /// `true` if it is safe to include `v` in the path, `false` otherwise. + fn is_safe(&self, v: usize, visited: &[bool], path: &[Option], pos: usize) -> bool { + // Check if the current vertex and the last vertex in the path are adjacent. + if !self.adjacency_matrix[path[pos - 1].unwrap()][v] { + return false; + } + + // Check if the vertex has already been included in the path. + !visited[v] + } + + /// Recursively searches for a Hamiltonian cycle. + /// + /// This function is called by `find_hamiltonian_cycle`. 
+ /// + /// # Arguments + /// + /// * `path` - A mutable vector representing the current path being explored. + /// * `visited` - A mutable vector representing the visited vertices. + /// * `pos` - The position of the current vertex being considered. + /// + /// # Returns + /// + /// `true` if a Hamiltonian cycle is found, `false` otherwise. + fn hamiltonian_cycle_util( + &self, + path: &mut [Option], + visited: &mut [bool], + pos: usize, + ) -> bool { + if pos == self.num_vertices() { + // Check if there is an edge from the last included vertex to the first vertex. + return self.adjacency_matrix[path[pos - 1].unwrap()][path[0].unwrap()]; + } + + for v in 0..self.num_vertices() { + if self.is_safe(v, visited, path, pos) { + path[pos] = Some(v); + visited[v] = true; + if self.hamiltonian_cycle_util(path, visited, pos + 1) { + return true; + } + path[pos] = None; + visited[v] = false; + } + } + + false + } + + /// Attempts to find a Hamiltonian cycle in the graph, starting from the specified vertex. + /// + /// A Hamiltonian cycle visits every vertex exactly once and returns to the starting vertex. + /// + /// # Note + /// This implementation may not find all possible Hamiltonian cycles. + /// It stops as soon as it finds one valid cycle. If multiple Hamiltonian cycles exist, + /// only one will be returned. + /// + /// # Returns + /// + /// `Ok(Some(path))` if a Hamiltonian cycle is found, where `path` is a vector + /// containing the indices of vertices in the cycle, starting and ending with the same vertex. + /// + /// `Ok(None)` if no Hamiltonian cycle exists. + fn find_hamiltonian_cycle( + &self, + start_vertex: usize, + ) -> Result>, FindHamiltonianCycleError> { + // Validate the start vertex. + if start_vertex >= self.num_vertices() { + return Err(FindHamiltonianCycleError::StartOutOfBound); + } + + // Initialize the path. + let mut path = vec![None; self.num_vertices()]; + // Start at the specified vertex. 
+ path[0] = Some(start_vertex); + + // Initialize the visited vector. + let mut visited = vec![false; self.num_vertices()]; + visited[start_vertex] = true; + + if self.hamiltonian_cycle_util(&mut path, &mut visited, 1) { + // Complete the cycle by returning to the starting vertex. + path.push(Some(start_vertex)); + Ok(Some(path.into_iter().map(Option::unwrap).collect())) + } else { + Ok(None) + } + } +} + +/// Attempts to find a Hamiltonian cycle in a graph represented by an adjacency matrix, starting from a specified vertex. +pub fn find_hamiltonian_cycle( + adjacency_matrix: Vec>, + start_vertex: usize, +) -> Result>, FindHamiltonianCycleError> { + Graph::new(adjacency_matrix)?.find_hamiltonian_cycle(start_vertex) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! hamiltonian_cycle_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (adjacency_matrix, start_vertex, expected) = $test_case; + let result = find_hamiltonian_cycle(adjacency_matrix, start_vertex); + assert_eq!(result, expected); + } + )* + }; + } + + hamiltonian_cycle_tests! 
{ + test_complete_graph: ( + vec![ + vec![false, true, true, true], + vec![true, false, true, true], + vec![true, true, false, true], + vec![true, true, true, false], + ], + 0, + Ok(Some(vec![0, 1, 2, 3, 0])) + ), + test_directed_graph_with_cycle: ( + vec![ + vec![false, true, false, false, false], + vec![false, false, true, true, false], + vec![true, false, false, true, true], + vec![false, false, true, false, true], + vec![true, true, false, false, false], + ], + 2, + Ok(Some(vec![2, 3, 4, 0, 1, 2])) + ), + test_undirected_graph_with_cycle: ( + vec![ + vec![false, true, false, false, true], + vec![true, false, true, false, false], + vec![false, true, false, true, false], + vec![false, false, true, false, true], + vec![true, false, false, true, false], + ], + 2, + Ok(Some(vec![2, 1, 0, 4, 3, 2])) + ), + test_directed_graph_no_cycle: ( + vec![ + vec![false, true, false, true, false], + vec![false, false, true, true, false], + vec![false, false, false, true, false], + vec![false, false, false, false, true], + vec![false, false, true, false, false], + ], + 0, + Ok(None::>) + ), + test_undirected_graph_no_cycle: ( + vec![ + vec![false, true, false, false, false], + vec![true, false, true, true, false], + vec![false, true, false, true, true], + vec![false, true, true, false, true], + vec![false, false, true, true, false], + ], + 0, + Ok(None::>) + ), + test_triangle_graph: ( + vec![ + vec![false, true, false], + vec![false, false, true], + vec![true, false, false], + ], + 1, + Ok(Some(vec![1, 2, 0, 1])) + ), + test_tree_graph: ( + vec![ + vec![false, true, false, true, false], + vec![true, false, true, true, false], + vec![false, true, false, false, false], + vec![true, true, false, false, true], + vec![false, false, false, true, false], + ], + 0, + Ok(None::>) + ), + test_empty_graph: ( + vec![], + 0, + Err(FindHamiltonianCycleError::EmptyAdjacencyMatrix) + ), + test_improper_graph: ( + vec![ + vec![false, true], + vec![true], + vec![false, true, true], + vec![true, 
true, true, false] + ], + 0, + Err(FindHamiltonianCycleError::ImproperAdjacencyMatrix) + ), + test_start_out_of_bound: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 3, + Err(FindHamiltonianCycleError::StartOutOfBound) + ), + test_complex_directed_graph: ( + vec![ + vec![false, true, false, true, false, false], + vec![false, false, true, false, true, false], + vec![false, false, false, true, false, false], + vec![false, true, false, false, true, false], + vec![false, false, true, false, false, true], + vec![true, false, false, false, false, false], + ], + 0, + Ok(Some(vec![0, 1, 2, 3, 4, 5, 0])) + ), + single_node_self_loop: ( + vec![ + vec![true], + ], + 0, + Ok(Some(vec![0, 0])) + ), + single_node: ( + vec![ + vec![false], + ], + 0, + Ok(None) + ), + } +} diff --git a/src/backtracking/knight_tour.rs b/src/backtracking/knight_tour.rs new file mode 100644 index 00000000000..26e9e36d682 --- /dev/null +++ b/src/backtracking/knight_tour.rs @@ -0,0 +1,195 @@ +//! This module contains the implementation of the Knight's Tour problem. +//! +//! The Knight's Tour is a classic chess problem where the objective is to move a knight to every square on a chessboard exactly once. + +/// Finds the Knight's Tour starting from the specified position. +/// +/// # Arguments +/// +/// * `size_x` - The width of the chessboard. +/// * `size_y` - The height of the chessboard. +/// * `start_x` - The x-coordinate of the starting position. +/// * `start_y` - The y-coordinate of the starting position. +/// +/// # Returns +/// +/// A tour matrix if the tour was found or None if not found. +/// The tour matrix returned is essentially the board field of the `KnightTour` +/// struct `Vec>`. It represents the sequence of moves made by the +/// knight on the chessboard, with each cell containing the order in which the knight visited that square. 
///
/// The returned matrix is indexed as `tour[x][y]`: it has `size_x` rows of `size_y` cells,
/// and `tour[start_x][start_y]` is `1`.
pub fn find_knight_tour(
    size_x: usize,
    size_y: usize,
    start_x: usize,
    start_y: usize,
) -> Option<Vec<Vec<usize>>> {
    let mut tour = KnightTour::new(size_x, size_y);
    tour.find_tour(start_x, start_y)
}

/// Represents the KnightTour struct which implements the Knight's Tour problem.
struct KnightTour {
    /// Visit order per square, indexed `board[x][y]`; `0` marks a square not visited yet.
    board: Vec<Vec<usize>>,
}

impl KnightTour {
    /// Possible moves of the knight on the board
    const MOVES: [(isize, isize); 8] = [
        (2, 1),
        (1, 2),
        (-1, 2),
        (-2, 1),
        (-2, -1),
        (-1, -2),
        (1, -2),
        (2, -1),
    ];

    /// Constructs a new KnightTour instance with the given board size.
    ///
    /// # Arguments
    ///
    /// * `size_x` - The width of the chessboard.
    /// * `size_y` - The height of the chessboard.
    ///
    /// # Returns
    ///
    /// A new KnightTour instance whose board holds `size_x` rows of `size_y` unvisited cells.
    fn new(size_x: usize, size_y: usize) -> Self {
        // The outer dimension must be `size_x` because every access is `board[x][y]`.
        KnightTour {
            board: vec![vec![0; size_y]; size_x],
        }
    }

    /// Returns the width of the chessboard.
    fn size_x(&self) -> usize {
        self.board.len()
    }

    /// Returns the height of the chessboard (`0` when the board has no rows at all).
    fn size_y(&self) -> usize {
        self.board.first().map_or(0, Vec::len)
    }

    /// Checks if the given position is safe to move to.
    ///
    /// # Arguments
    ///
    /// * `x` - The x-coordinate of the position.
    /// * `y` - The y-coordinate of the position.
    ///
    /// # Returns
    ///
    /// `true` when the position lies on the board and has not been visited yet.
    fn is_safe(&self, x: isize, y: isize) -> bool {
        if x < 0 || y < 0 {
            return false;
        }
        // Both values are non-negative here, so the casts are lossless; `get` covers the
        // upper bounds.
        self.board
            .get(x as usize)
            .and_then(|column| column.get(y as usize))
            == Some(&0)
    }

    /// Recursively solves the Knight's Tour problem.
    ///
    /// # Arguments
    ///
    /// * `x` - The current x-coordinate of the knight.
    /// * `y` - The current y-coordinate of the knight.
    /// * `move_count` - The number of squares visited so far.
    ///
    /// # Returns
    ///
    /// A boolean indicating whether a solution was found.
    fn solve_tour(&mut self, x: isize, y: isize, move_count: usize) -> bool {
        if move_count == self.size_x() * self.size_y() {
            return true;
        }

        for &(dx, dy) in &Self::MOVES {
            let (next_x, next_y) = (x + dx, y + dy);
            if !self.is_safe(next_x, next_y) {
                continue;
            }

            let (ux, uy) = (next_x as usize, next_y as usize);
            self.board[ux][uy] = move_count + 1;
            if self.solve_tour(next_x, next_y, move_count + 1) {
                return true;
            }
            // Backtrack
            self.board[ux][uy] = 0;
        }

        false
    }

    /// Finds the Knight's Tour starting from the specified position.
    ///
    /// # Arguments
    ///
    /// * `start_x` - The x-coordinate of the starting position.
    /// * `start_y` - The y-coordinate of the starting position.
    ///
    /// # Returns
    ///
    /// A tour matrix if the tour was found or None if not found (including when the start
    /// position is off the board).
    fn find_tour(&mut self, start_x: usize, start_y: usize) -> Option<Vec<Vec<usize>>> {
        // Bounds-check in `usize` first so the `isize` casts below cannot wrap.
        if start_x >= self.size_x() || start_y >= self.size_y() {
            return None;
        }

        self.board[start_x][start_y] = 1;

        if self.solve_tour(start_x as isize, start_y as isize, 1) {
            // Hand the finished board to the caller without copying it.
            Some(std::mem::take(&mut self.board))
        } else {
            None
        }
    }
}
{ + test_knight_tour_5x5: (5, 5, 0, 0, Some(vec![ + vec![1, 6, 15, 10, 21], + vec![14, 9, 20, 5, 16], + vec![19, 2, 7, 22, 11], + vec![8, 13, 24, 17, 4], + vec![25, 18, 3, 12, 23], + ])), + test_knight_tour_6x6: (6, 6, 0, 0, Some(vec![ + vec![1, 16, 7, 26, 11, 14], + vec![34, 25, 12, 15, 6, 27], + vec![17, 2, 33, 8, 13, 10], + vec![32, 35, 24, 21, 28, 5], + vec![23, 18, 3, 30, 9, 20], + vec![36, 31, 22, 19, 4, 29], + ])), + test_knight_tour_8x8: (8, 8, 0, 0, Some(vec![ + vec![1, 60, 39, 34, 31, 18, 9, 64], + vec![38, 35, 32, 61, 10, 63, 30, 17], + vec![59, 2, 37, 40, 33, 28, 19, 8], + vec![36, 49, 42, 27, 62, 11, 16, 29], + vec![43, 58, 3, 50, 41, 24, 7, 20], + vec![48, 51, 46, 55, 26, 21, 12, 15], + vec![57, 44, 53, 4, 23, 14, 25, 6], + vec![52, 47, 56, 45, 54, 5, 22, 13], + ])), + test_no_solution: (5, 5, 2, 1, None::>>), + test_invalid_start_position: (8, 8, 10, 10, None::>>), + } +} diff --git a/src/backtracking/mod.rs b/src/backtracking/mod.rs new file mode 100644 index 00000000000..182c6fbbc01 --- /dev/null +++ b/src/backtracking/mod.rs @@ -0,0 +1,21 @@ +mod all_combination_of_size_k; +mod graph_coloring; +mod hamiltonian_cycle; +mod knight_tour; +mod n_queens; +mod parentheses_generator; +mod permutations; +mod rat_in_maze; +mod subset_sum; +mod sudoku; + +pub use all_combination_of_size_k::generate_all_combinations; +pub use graph_coloring::generate_colorings; +pub use hamiltonian_cycle::find_hamiltonian_cycle; +pub use knight_tour::find_knight_tour; +pub use n_queens::n_queens_solver; +pub use parentheses_generator::generate_parentheses; +pub use permutations::permute; +pub use rat_in_maze::find_path_in_maze; +pub use subset_sum::has_subset_with_sum; +pub use sudoku::sudoku_solver; diff --git a/src/backtracking/n_queens.rs b/src/backtracking/n_queens.rs new file mode 100644 index 00000000000..234195e97ea --- /dev/null +++ b/src/backtracking/n_queens.rs @@ -0,0 +1,221 @@ +//! This module provides functionality to solve the N-Queens problem. +//! +//! 
/// Solves the N-Queens problem for a given size and returns a vector of solutions.
///
/// The N-Queens problem is a classic chessboard puzzle where the goal is to
/// place N queens on an NxN chessboard so that no two queens threaten each
/// other. Queens can attack each other if they share the same row, column, or
/// diagonal.
///
/// This implementation solves the N-Queens problem using a backtracking algorithm.
/// It starts with an empty chessboard and iteratively tries to place queens in
/// different rows, ensuring they do not conflict with each other. If a valid
/// solution is found, it's added to the list of solutions.
///
/// # Arguments
///
/// * `n` - The size of the chessboard (NxN).
///
/// # Returns
///
/// A vector containing all solutions to the N-Queens problem. Each solution is a list of
/// `n` strings (one per row) in which `'Q'` marks a queen and `'.'` an empty square.
pub fn n_queens_solver(n: usize) -> Vec<Vec<String>> {
    let mut solver = NQueensSolver::new(n);
    solver.solve()
}

/// Represents a solver for the N-Queens problem.
struct NQueensSolver {
    // The size of the chessboard
    size: usize,
    // A 2D vector representing the chessboard where '.' denotes an empty space and 'Q' denotes a queen
    board: Vec<Vec<char>>,
    // A vector to store all valid solutions
    solutions: Vec<Vec<String>>,
}

impl NQueensSolver {
    /// Creates a new `NQueensSolver` instance with the given size.
    ///
    /// # Arguments
    ///
    /// * `size` - The size of the chessboard (N×N).
    ///
    /// # Returns
    ///
    /// A new `NQueensSolver` instance.
    fn new(size: usize) -> Self {
        NQueensSolver {
            size,
            board: vec![vec!['.'; size]; size],
            solutions: Vec::new(),
        }
    }

    /// Solves the N-Queens problem and returns a vector of solutions.
    ///
    /// # Returns
    ///
    /// A vector containing all solutions to the N-Queens problem.
    fn solve(&mut self) -> Vec<Vec<String>> {
        self.solve_helper(0);
        std::mem::take(&mut self.solutions)
    }

    /// Checks if it's safe to place a queen at the specified position (row, col).
    ///
    /// Only the rows above `row` are inspected, because queens are placed row by row.
    ///
    /// # Arguments
    ///
    /// * `row` - The row index of the position to check.
    /// * `col` - The column index of the position to check.
    ///
    /// # Returns
    ///
    /// `true` if it's safe to place a queen at the specified position, `false` otherwise.
    fn is_safe(&self, row: usize, col: usize) -> bool {
        self.board[..row].iter().enumerate().all(|(i, placed)| {
            // Horizontal offset at which a diagonal through (row, col) crosses row `i`.
            let distance = row - i;
            // `None` means the left diagonal has already left the board in this row.
            let left_diagonal = col.checked_sub(distance);
            let right_diagonal = col + distance;

            placed[col] != 'Q'
                && left_diagonal.map_or(true, |c| placed[c] != 'Q')
                && (right_diagonal >= self.size || placed[right_diagonal] != 'Q')
        })
    }

    /// Recursive helper function to solve the N-Queens problem.
    ///
    /// # Arguments
    ///
    /// * `row` - The current row being processed.
    fn solve_helper(&mut self, row: usize) {
        if row == self.size {
            // Every row holds a queen: snapshot the board as one string per row.
            self.solutions
                .push(self.board.iter().map(|line| line.iter().collect()).collect());
            return;
        }

        for col in 0..self.size {
            if self.is_safe(row, col) {
                self.board[row][col] = 'Q';
                self.solve_helper(row + 1);
                // Backtrack
                self.board[row][col] = '.';
            }
        }
    }
}
{ + test_0_queens: (0, vec![Vec::::new()]), + test_1_queen: (1, vec![vec!["Q"]]), + test_2_queens:(2, Vec::>::new()), + test_3_queens:(3, Vec::>::new()), + test_4_queens: (4, vec![ + vec![".Q..", + "...Q", + "Q...", + "..Q."], + vec!["..Q.", + "Q...", + "...Q", + ".Q.."], + ]), + test_5_queens:(5, vec![ + vec!["Q....", + "..Q..", + "....Q", + ".Q...", + "...Q."], + vec!["Q....", + "...Q.", + ".Q...", + "....Q", + "..Q.."], + vec![".Q...", + "...Q.", + "Q....", + "..Q..", + "....Q"], + vec![".Q...", + "....Q", + "..Q..", + "Q....", + "...Q."], + vec!["..Q..", + "Q....", + "...Q.", + ".Q...", + "....Q"], + vec!["..Q..", + "....Q", + ".Q...", + "...Q.", + "Q...."], + vec!["...Q.", + "Q....", + "..Q..", + "....Q", + ".Q..."], + vec!["...Q.", + ".Q...", + "....Q", + "..Q..", + "Q...."], + vec!["....Q", + ".Q...", + "...Q.", + "Q....", + "..Q.."], + vec!["....Q", + "..Q..", + "Q....", + "...Q.", + ".Q..."], + ]), + test_6_queens: (6, vec![ + vec![".Q....", + "...Q..", + ".....Q", + "Q.....", + "..Q...", + "....Q."], + vec!["..Q...", + ".....Q", + ".Q....", + "....Q.", + "Q.....", + "...Q.."], + vec!["...Q..", + "Q.....", + "....Q.", + ".Q....", + ".....Q", + "..Q..."], + vec!["....Q.", + "..Q...", + "Q.....", + ".....Q", + "...Q..", + ".Q...."], + ]), + } +} diff --git a/src/backtracking/parentheses_generator.rs b/src/backtracking/parentheses_generator.rs new file mode 100644 index 00000000000..9aafe81fa7a --- /dev/null +++ b/src/backtracking/parentheses_generator.rs @@ -0,0 +1,76 @@ +/// Generates all combinations of well-formed parentheses given a non-negative integer `n`. +/// +/// This function uses backtracking to generate all possible combinations of well-formed +/// parentheses. The resulting combinations are returned as a vector of strings. +/// +/// # Arguments +/// +/// * `n` - A non-negative integer representing the number of pairs of parentheses. 
+pub fn generate_parentheses(n: usize) -> Vec { + let mut result = Vec::new(); + if n > 0 { + generate("", 0, 0, n, &mut result); + } + result +} + +/// Helper function for generating parentheses recursively. +/// +/// This function is called recursively to build combinations of well-formed parentheses. +/// It tracks the number of open and close parentheses added so far and adds a new parenthesis +/// if it's valid to do so. +/// +/// # Arguments +/// +/// * `current` - The current string of parentheses being built. +/// * `open_count` - The count of open parentheses in the current string. +/// * `close_count` - The count of close parentheses in the current string. +/// * `n` - The total number of pairs of parentheses to be generated. +/// * `result` - A mutable reference to the vector storing the generated combinations. +fn generate( + current: &str, + open_count: usize, + close_count: usize, + n: usize, + result: &mut Vec, +) { + if current.len() == (n * 2) { + result.push(current.to_string()); + return; + } + + if open_count < n { + let new_str = current.to_string() + "("; + generate(&new_str, open_count + 1, close_count, n, result); + } + + if close_count < open_count { + let new_str = current.to_string() + ")"; + generate(&new_str, open_count, close_count + 1, n, result); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! generate_parentheses_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected_result) = $test_case; + assert_eq!(generate_parentheses(n), expected_result); + } + )* + }; + } + + generate_parentheses_tests! 
{ + test_generate_parentheses_0: (0, Vec::::new()), + test_generate_parentheses_1: (1, vec!["()"]), + test_generate_parentheses_2: (2, vec!["(())", "()()"]), + test_generate_parentheses_3: (3, vec!["((()))", "(()())", "(())()", "()(())", "()()()"]), + test_generate_parentheses_4: (4, vec!["(((())))", "((()()))", "((())())", "((()))()", "(()(()))", "(()()())", "(()())()", "(())(())", "(())()()", "()((()))", "()(()())", "()(())()", "()()(())", "()()()()"]), + } +} diff --git a/src/backtracking/permutations.rs b/src/backtracking/permutations.rs new file mode 100644 index 00000000000..8859a633310 --- /dev/null +++ b/src/backtracking/permutations.rs @@ -0,0 +1,141 @@ +//! This module provides a function to generate all possible distinct permutations +//! of a given collection of integers using a backtracking algorithm. + +/// Generates all possible distinct permutations of a given vector of integers. +/// +/// # Arguments +/// +/// * `nums` - A vector of integers. The input vector is sorted before generating +/// permutations to handle duplicates effectively. +/// +/// # Returns +/// +/// A vector containing all possible distinct permutations of the input vector. +pub fn permute(mut nums: Vec) -> Vec> { + let mut permutations = Vec::new(); + let mut current = Vec::new(); + let mut used = vec![false; nums.len()]; + + nums.sort(); + generate(&nums, &mut current, &mut used, &mut permutations); + + permutations +} + +/// Helper function for the `permute` function to generate distinct permutations recursively. +/// +/// # Arguments +/// +/// * `nums` - A reference to the sorted slice of integers. +/// * `current` - A mutable reference to the vector holding the current permutation. +/// * `used` - A mutable reference to a vector tracking which elements are used. +/// * `permutations` - A mutable reference to the vector holding all generated distinct permutations. 
+fn generate( + nums: &[isize], + current: &mut Vec, + used: &mut Vec, + permutations: &mut Vec>, +) { + if current.len() == nums.len() { + permutations.push(current.clone()); + return; + } + + for idx in 0..nums.len() { + if used[idx] { + continue; + } + + if idx > 0 && nums[idx] == nums[idx - 1] && !used[idx - 1] { + continue; + } + + current.push(nums[idx]); + used[idx] = true; + + generate(nums, current, used, permutations); + + current.pop(); + used[idx] = false; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! permute_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(permute(input), expected); + } + )* + } + } + + permute_tests! { + test_permute_basic: (vec![1, 2, 3], vec![ + vec![1, 2, 3], + vec![1, 3, 2], + vec![2, 1, 3], + vec![2, 3, 1], + vec![3, 1, 2], + vec![3, 2, 1], + ]), + test_permute_empty: (Vec::::new(), vec![vec![]]), + test_permute_single: (vec![1], vec![vec![1]]), + test_permute_duplicates: (vec![1, 1, 2], vec![ + vec![1, 1, 2], + vec![1, 2, 1], + vec![2, 1, 1], + ]), + test_permute_all_duplicates: (vec![1, 1, 1, 1], vec![ + vec![1, 1, 1, 1], + ]), + test_permute_negative: (vec![-1, -2, -3], vec![ + vec![-3, -2, -1], + vec![-3, -1, -2], + vec![-2, -3, -1], + vec![-2, -1, -3], + vec![-1, -3, -2], + vec![-1, -2, -3], + ]), + test_permute_mixed: (vec![-1, 0, 1], vec![ + vec![-1, 0, 1], + vec![-1, 1, 0], + vec![0, -1, 1], + vec![0, 1, -1], + vec![1, -1, 0], + vec![1, 0, -1], + ]), + test_permute_larger: (vec![1, 2, 3, 4], vec![ + vec![1, 2, 3, 4], + vec![1, 2, 4, 3], + vec![1, 3, 2, 4], + vec![1, 3, 4, 2], + vec![1, 4, 2, 3], + vec![1, 4, 3, 2], + vec![2, 1, 3, 4], + vec![2, 1, 4, 3], + vec![2, 3, 1, 4], + vec![2, 3, 4, 1], + vec![2, 4, 1, 3], + vec![2, 4, 3, 1], + vec![3, 1, 2, 4], + vec![3, 1, 4, 2], + vec![3, 2, 1, 4], + vec![3, 2, 4, 1], + vec![3, 4, 1, 2], + vec![3, 4, 2, 1], + vec![4, 1, 2, 3], + vec![4, 1, 3, 2], + vec![4, 2, 1, 3], + 
vec![4, 2, 3, 1], + vec![4, 3, 1, 2], + vec![4, 3, 2, 1], + ]), + } +} diff --git a/src/backtracking/rat_in_maze.rs b/src/backtracking/rat_in_maze.rs new file mode 100644 index 00000000000..fb658697b39 --- /dev/null +++ b/src/backtracking/rat_in_maze.rs @@ -0,0 +1,327 @@ +//! This module contains the implementation of the Rat in Maze problem. +//! +//! The Rat in Maze problem is a classic algorithmic problem where the +//! objective is to find a path from the starting position to the exit +//! position in a maze. + +/// Enum representing various errors that can occur while working with mazes. +#[derive(Debug, PartialEq, Eq)] +pub enum MazeError { + /// Indicates that the maze is empty (zero rows). + EmptyMaze, + /// Indicates that the starting position is out of bounds. + OutOfBoundPos, + /// Indicates an improper representation of the maze (e.g., non-rectangular maze). + ImproperMazeRepr, +} + +/// Finds a path through the maze starting from the specified position. +/// +/// # Arguments +/// +/// * `maze` - The maze represented as a vector of vectors where each +/// inner vector represents a row in the maze grid. +/// * `start_x` - The x-coordinate of the starting position. +/// * `start_y` - The y-coordinate of the starting position. +/// +/// # Returns +/// +/// A `Result` where: +/// - `Ok(Some(solution))` if a path is found and contains the solution matrix. +/// - `Ok(None)` if no path is found. +/// - `Err(MazeError)` for various error conditions such as out-of-bound start position or improper maze representation. +/// +/// # Solution Selection +/// +/// The function returns the first successful path it discovers based on the predefined order of moves. +/// The order of moves is defined in the `MOVES` constant of the `Maze` struct. +/// +/// The backtracking algorithm explores each direction in this order. If multiple solutions exist, +/// the algorithm returns the first path it finds according to this sequence. 
It recursively explores +/// each direction, marks valid moves, and backtracks if necessary, ensuring that the solution is found +/// efficiently and consistently. +pub fn find_path_in_maze( + maze: &[Vec], + start_x: usize, + start_y: usize, +) -> Result>>, MazeError> { + if maze.is_empty() { + return Err(MazeError::EmptyMaze); + } + + // Validate start position + if start_x >= maze.len() || start_y >= maze[0].len() { + return Err(MazeError::OutOfBoundPos); + } + + // Validate maze representation (if necessary) + if maze.iter().any(|row| row.len() != maze[0].len()) { + return Err(MazeError::ImproperMazeRepr); + } + + // If validations pass, proceed with finding the path + let maze_instance = Maze::new(maze.to_owned()); + Ok(maze_instance.find_path(start_x, start_y)) +} + +/// Represents a maze. +struct Maze { + maze: Vec>, +} + +impl Maze { + /// Represents possible moves in the maze. + const MOVES: [(isize, isize); 4] = [(0, 1), (1, 0), (0, -1), (-1, 0)]; + + /// Constructs a new Maze instance. + /// # Arguments + /// + /// * `maze` - The maze represented as a vector of vectors where each + /// inner vector represents a row in the maze grid. + /// + /// # Returns + /// + /// A new Maze instance. + fn new(maze: Vec>) -> Self { + Maze { maze } + } + + /// Returns the width of the maze. + /// + /// # Returns + /// + /// The width of the maze. + fn width(&self) -> usize { + self.maze[0].len() + } + + /// Returns the height of the maze. + /// + /// # Returns + /// + /// The height of the maze. + fn height(&self) -> usize { + self.maze.len() + } + + /// Finds a path through the maze starting from the specified position. + /// + /// # Arguments + /// + /// * `start_x` - The x-coordinate of the starting position. + /// * `start_y` - The y-coordinate of the starting position. + /// + /// # Returns + /// + /// A solution matrix if a path is found or None if not found. 
+ fn find_path(&self, start_x: usize, start_y: usize) -> Option>> { + let mut solution = vec![vec![false; self.width()]; self.height()]; + if self.solve(start_x as isize, start_y as isize, &mut solution) { + Some(solution) + } else { + None + } + } + + /// Recursively solves the Rat in Maze problem using backtracking. + /// + /// # Arguments + /// + /// * `x` - The current x-coordinate. + /// * `y` - The current y-coordinate. + /// * `solution` - The current solution matrix. + /// + /// # Returns + /// + /// A boolean indicating whether a solution was found. + fn solve(&self, x: isize, y: isize, solution: &mut [Vec]) -> bool { + if x == (self.height() as isize - 1) && y == (self.width() as isize - 1) { + solution[x as usize][y as usize] = true; + return true; + } + + if self.is_valid(x, y, solution) { + solution[x as usize][y as usize] = true; + + for &(dx, dy) in &Self::MOVES { + if self.solve(x + dx, y + dy, solution) { + return true; + } + } + + // If none of the directions lead to the solution, backtrack + solution[x as usize][y as usize] = false; + return false; + } + false + } + + /// Checks if a given position is valid in the maze. + /// + /// # Arguments + /// + /// * `x` - The x-coordinate of the position. + /// * `y` - The y-coordinate of the position. + /// * `solution` - The current solution matrix. + /// + /// # Returns + /// + /// A boolean indicating whether the position is valid. + fn is_valid(&self, x: isize, y: isize, solution: &[Vec]) -> bool { + x >= 0 + && y >= 0 + && x < self.height() as isize + && y < self.width() as isize + && self.maze[x as usize][y as usize] + && !solution[x as usize][y as usize] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! 
test_find_path_in_maze { + ($($name:ident: $start_x:expr, $start_y:expr, $maze:expr, $expected:expr,)*) => { + $( + #[test] + fn $name() { + let solution = find_path_in_maze($maze, $start_x, $start_y); + assert_eq!(solution, $expected); + if let Ok(Some(expected_solution)) = &solution { + assert_eq!(expected_solution[$start_x][$start_y], true); + } + } + )* + } + } + + test_find_path_in_maze! { + maze_with_solution_5x5: 0, 0, &[ + vec![true, false, true, false, false], + vec![true, true, false, true, false], + vec![false, true, true, true, false], + vec![false, false, false, true, true], + vec![false, true, false, false, true], + ], Ok(Some(vec![ + vec![true, false, false, false, false], + vec![true, true, false, false, false], + vec![false, true, true, true, false], + vec![false, false, false, true, true], + vec![false, false, false, false, true], + ])), + maze_with_solution_6x6: 0, 0, &[ + vec![true, false, true, false, true, false], + vec![true, true, false, true, false, true], + vec![false, true, true, true, true, false], + vec![false, false, false, true, true, true], + vec![false, true, false, false, true, false], + vec![true, true, true, true, true, true], + ], Ok(Some(vec![ + vec![true, false, false, false, false, false], + vec![true, true, false, false, false, false], + vec![false, true, true, true, true, false], + vec![false, false, false, false, true, false], + vec![false, false, false, false, true, false], + vec![false, false, false, false, true, true], + ])), + maze_with_solution_8x8: 0, 0, &[ + vec![true, false, false, false, false, false, false, true], + vec![true, true, false, true, true, true, false, false], + vec![false, true, true, true, false, false, false, false], + vec![false, false, false, true, false, true, true, false], + vec![false, true, false, true, true, true, false, true], + vec![true, false, true, false, false, true, true, true], + vec![false, false, true, true, true, false, true, true], + vec![true, true, true, false, true, true, 
true, true], + ], Ok(Some(vec![ + vec![true, false, false, false, false, false, false, false], + vec![true, true, false, false, false, false, false, false], + vec![false, true, true, true, false, false, false, false], + vec![false, false, false, true, false, false, false, false], + vec![false, false, false, true, true, true, false, false], + vec![false, false, false, false, false, true, true, true], + vec![false, false, false, false, false, false, false, true], + vec![false, false, false, false, false, false, false, true], + ])), + maze_without_solution_4x4: 0, 0, &[ + vec![true, false, false, false], + vec![true, true, false, false], + vec![false, false, true, false], + vec![false, false, false, true], + ], Ok(None::>>), + maze_with_solution_3x4: 0, 0, &[ + vec![true, false, true, true], + vec![true, true, true, false], + vec![false, true, true, true], + ], Ok(Some(vec![ + vec![true, false, false, false], + vec![true, true, true, false], + vec![false, false, true, true], + ])), + maze_without_solution_3x4: 0, 0, &[ + vec![true, false, true, true], + vec![true, false, true, false], + vec![false, true, false, true], + ], Ok(None::>>), + improper_maze_representation: 0, 0, &[ + vec![true], + vec![true, true], + vec![true, true, true], + vec![true, true, true, true] + ], Err(MazeError::ImproperMazeRepr), + out_of_bound_start: 0, 3, &[ + vec![true, false, true], + vec![true, true], + vec![false, true, true], + ], Err(MazeError::OutOfBoundPos), + empty_maze: 0, 0, &[], Err(MazeError::EmptyMaze), + maze_with_single_cell: 0, 0, &[ + vec![true], + ], Ok(Some(vec![ + vec![true] + ])), + maze_with_one_row_and_multiple_columns: 0, 0, &[ + vec![true, false, true, true, false] + ], Ok(None::>>), + maze_with_multiple_rows_and_one_column: 0, 0, &[ + vec![true], + vec![true], + vec![false], + vec![true], + ], Ok(None::>>), + maze_with_walls_surrounding_border: 0, 0, &[ + vec![false, false, false], + vec![false, true, false], + vec![false, false, false], + ], Ok(None::>>), + 
maze_with_no_walls: 0, 0, &[ + vec![true, true, true], + vec![true, true, true], + vec![true, true, true], + ], Ok(Some(vec![ + vec![true, true, true], + vec![false, false, true], + vec![false, false, true], + ])), + maze_with_going_back: 0, 0, &[ + vec![true, true, true, true, true, true], + vec![false, false, false, true, false, true], + vec![true, true, true, true, false, false], + vec![true, false, false, false, false, false], + vec![true, false, false, false, true, true], + vec![true, false, true, true, true, false], + vec![true, false, true , false, true, false], + vec![true, true, true, false, true, true], + ], Ok(Some(vec![ + vec![true, true, true, true, false, false], + vec![false, false, false, true, false, false], + vec![true, true, true, true, false, false], + vec![true, false, false, false, false, false], + vec![true, false, false, false, false, false], + vec![true, false, true, true, true, false], + vec![true, false, true , false, true, false], + vec![true, true, true, false, true, true], + ])), + } +} diff --git a/src/backtracking/subset_sum.rs b/src/backtracking/subset_sum.rs new file mode 100644 index 00000000000..3e69b380b58 --- /dev/null +++ b/src/backtracking/subset_sum.rs @@ -0,0 +1,55 @@ +//! This module provides functionality to check if there exists a subset of a given set of integers +//! that sums to a target value. The implementation uses a recursive backtracking approach. + +/// Checks if there exists a subset of the given set that sums to the target value. 
+pub fn has_subset_with_sum(set: &[isize], target: isize) -> bool { + backtrack(set, set.len(), target) +} + +fn backtrack(set: &[isize], remaining_items: usize, target: isize) -> bool { + // Found a subset with the required sum + if target == 0 { + return true; + } + // No more elements to process + if remaining_items == 0 { + return false; + } + // Check if we can find a subset including or excluding the last element + backtrack(set, remaining_items - 1, target) + || backtrack(set, remaining_items - 1, target - set[remaining_items - 1]) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! has_subset_with_sum_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (set, target, expected) = $test_case; + assert_eq!(has_subset_with_sum(set, target), expected); + } + )* + } + } + + has_subset_with_sum_tests! { + test_small_set_with_sum: (&[3, 34, 4, 12, 5, 2], 9, true), + test_small_set_without_sum: (&[3, 34, 4, 12, 5, 2], 30, false), + test_consecutive_set_with_sum: (&[1, 2, 3, 4, 5, 6], 10, true), + test_consecutive_set_without_sum: (&[1, 2, 3, 4, 5, 6], 22, false), + test_large_set_with_sum: (&[5, 10, 12, 13, 15, 18, -1, 10, 50, -2, 3, 4], 30, true), + test_empty_set: (&[], 0, true), + test_empty_set_with_nonzero_sum: (&[], 10, false), + test_single_element_equal_to_sum: (&[10], 10, true), + test_single_element_not_equal_to_sum: (&[5], 10, false), + test_negative_set_with_sum: (&[-7, -3, -2, 5, 8], 0, true), + test_negative_sum: (&[1, 2, 3, 4, 5], -1, false), + test_negative_sum_with_negatives: (&[-7, -3, -2, 5, 8], -4, true), + test_negative_sum_with_negatives_no_solution: (&[-7, -3, -2, 5, 8], -14, false), + test_even_inputs_odd_target: (&[2, 4, 6, 2, 8, -2, 10, 12, -24, 8, 12, 18], 3, false), + } +} diff --git a/src/backtracking/sudoku.rs b/src/backtracking/sudoku.rs new file mode 100644 index 00000000000..bb6c13cbde6 --- /dev/null +++ b/src/backtracking/sudoku.rs @@ -0,0 +1,164 @@ +//! 
A Rust implementation of Sudoku solver using Backtracking. +//! +//! This module provides functionality to solve Sudoku puzzles using the backtracking algorithm. +//! +//! GeeksForGeeks: [Sudoku Backtracking](https://www.geeksforgeeks.org/sudoku-backtracking-7/) + +/// Solves a Sudoku puzzle. +/// +/// Given a partially filled Sudoku puzzle represented by a 9x9 grid, this function attempts to +/// solve the puzzle using the backtracking algorithm. +/// +/// Returns the solved Sudoku board if a solution exists, or `None` if no solution is found. +pub fn sudoku_solver(board: &[[u8; 9]; 9]) -> Option<[[u8; 9]; 9]> { + let mut solver = SudokuSolver::new(*board); + if solver.solve() { + Some(solver.board) + } else { + None + } +} + +/// Represents a Sudoku puzzle solver. +struct SudokuSolver { + /// The Sudoku board represented by a 9x9 grid. + board: [[u8; 9]; 9], +} + +impl SudokuSolver { + /// Creates a new Sudoku puzzle solver with the given board. + fn new(board: [[u8; 9]; 9]) -> SudokuSolver { + SudokuSolver { board } + } + + /// Finds an empty cell in the Sudoku board. + /// + /// Returns the coordinates of an empty cell `(row, column)` if found, or `None` if all cells are filled. + fn find_empty_cell(&self) -> Option<(usize, usize)> { + // Find an empty cell in the board (returns None if all cells are filled) + for row in 0..9 { + for column in 0..9 { + if self.board[row][column] == 0 { + return Some((row, column)); + } + } + } + + None + } + + /// Checks whether a given value can be placed in a specific cell according to Sudoku rules. + /// + /// Returns `true` if the value can be placed in the cell, otherwise `false`. 
+ fn is_value_valid(&self, coordinates: (usize, usize), value: u8) -> bool { + let (row, column) = coordinates; + + // Checks if the value to be added in the board is an acceptable value for the cell + // Checking through the row + for current_column in 0..9 { + if self.board[row][current_column] == value { + return false; + } + } + + // Checking through the column + for current_row in 0..9 { + if self.board[current_row][column] == value { + return false; + } + } + + // Checking through the 3x3 block of the cell + let start_row = row / 3 * 3; + let start_column = column / 3 * 3; + + for current_row in start_row..start_row + 3 { + for current_column in start_column..start_column + 3 { + if self.board[current_row][current_column] == value { + return false; + } + } + } + + true + } + + /// Solves the Sudoku puzzle recursively using backtracking. + /// + /// Returns `true` if a solution is found, otherwise `false`. + fn solve(&mut self) -> bool { + let empty_cell = self.find_empty_cell(); + + if let Some((row, column)) = empty_cell { + for value in 1..=9 { + if self.is_value_valid((row, column), value) { + self.board[row][column] = value; + if self.solve() { + return true; + } + // Backtracking if the board cannot be solved using the current configuration + self.board[row][column] = 0; + } + } + } else { + // If the board is complete + return true; + } + + // Returning false if the board cannot be solved using the current configuration + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_sudoku_solver { + ($($name:ident: $board:expr, $expected:expr,)*) => { + $( + #[test] + fn $name() { + let result = sudoku_solver(&$board); + assert_eq!(result, $expected); + } + )* + }; + } + + test_sudoku_solver! 
{ + test_sudoku_correct: [ + [3, 0, 6, 5, 0, 8, 4, 0, 0], + [5, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 8, 7, 0, 0, 0, 0, 3, 1], + [0, 0, 3, 0, 1, 0, 0, 8, 0], + [9, 0, 0, 8, 6, 3, 0, 0, 5], + [0, 5, 0, 0, 9, 0, 6, 0, 0], + [1, 3, 0, 0, 0, 0, 2, 5, 0], + [0, 0, 0, 0, 0, 0, 0, 7, 4], + [0, 0, 5, 2, 0, 6, 3, 0, 0], + ], Some([ + [3, 1, 6, 5, 7, 8, 4, 9, 2], + [5, 2, 9, 1, 3, 4, 7, 6, 8], + [4, 8, 7, 6, 2, 9, 5, 3, 1], + [2, 6, 3, 4, 1, 5, 9, 8, 7], + [9, 7, 4, 8, 6, 3, 1, 2, 5], + [8, 5, 1, 7, 9, 2, 6, 4, 3], + [1, 3, 8, 9, 4, 7, 2, 5, 6], + [6, 9, 2, 3, 5, 1, 8, 7, 4], + [7, 4, 5, 2, 8, 6, 3, 1, 9], + ]), + + test_sudoku_incorrect: [ + [6, 0, 3, 5, 0, 8, 4, 0, 0], + [5, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 8, 7, 0, 0, 0, 0, 3, 1], + [0, 0, 3, 0, 1, 0, 0, 8, 0], + [9, 0, 0, 8, 6, 3, 0, 0, 5], + [0, 5, 0, 0, 9, 0, 6, 0, 0], + [1, 3, 0, 0, 0, 0, 2, 5, 0], + [0, 0, 0, 0, 0, 0, 0, 7, 4], + [0, 0, 5, 2, 0, 6, 3, 0, 0], + ], None::<[[u8; 9]; 9]>, + } +} diff --git a/src/big_integer/fast_factorial.rs b/src/big_integer/fast_factorial.rs new file mode 100644 index 00000000000..8272f9ee100 --- /dev/null +++ b/src/big_integer/fast_factorial.rs @@ -0,0 +1,87 @@ +// Algorithm created by Peter Borwein in 1985 +// https://doi.org/10.1016/0196-6774(85)90006-9 + +use crate::math::sieve_of_eratosthenes; +use num_bigint::BigUint; +use num_traits::One; +use std::collections::BTreeMap; + +/// Calculate the sum of n / p^i with integer division for all values of i +fn index(p: usize, n: usize) -> usize { + let mut index = 0; + let mut i = 1; + let mut quot = n / p; + + while quot > 0 { + index += quot; + i += 1; + quot = n / p.pow(i); + } + + index +} + +/// Calculate the factorial with time complexity O(log(log(n)) * M(n * log(n))) where M(n) is the time complexity of multiplying two n-digit numbers together. +pub fn fast_factorial(n: usize) -> BigUint { + if n < 2 { + return BigUint::one(); + } + + // get list of primes that will be factors of n! 
+ let primes = sieve_of_eratosthenes(n); + + // Map the primes with their index + let p_indices = primes + .into_iter() + .map(|p| (p, index(p, n))) + .collect::>(); + + let max_bits = p_indices[&2].next_power_of_two().ilog2() + 1; + + // Create a Vec of 1's + let mut a = vec![BigUint::one(); max_bits as usize]; + + // For every prime p, multiply a[i] by p if the ith bit of p's index is 1 + for (p, i) in p_indices { + let mut bit = 1usize; + while bit.ilog2() < max_bits { + if (bit & i) > 0 { + a[bit.ilog2() as usize] *= p; + } + + bit <<= 1; + } + } + + a.into_iter() + .enumerate() + .map(|(i, a_i)| a_i.pow(2u32.pow(i as u32))) // raise every a[i] to the 2^ith power + .product() // we get our answer by multiplying the result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::math::factorial::factorial_bigmath; + + #[test] + fn fact() { + assert_eq!(fast_factorial(0), BigUint::one()); + assert_eq!(fast_factorial(1), BigUint::one()); + assert_eq!(fast_factorial(2), factorial_bigmath(2)); + assert_eq!(fast_factorial(3), factorial_bigmath(3)); + assert_eq!(fast_factorial(6), factorial_bigmath(6)); + assert_eq!(fast_factorial(7), factorial_bigmath(7)); + assert_eq!(fast_factorial(10), factorial_bigmath(10)); + assert_eq!(fast_factorial(11), factorial_bigmath(11)); + assert_eq!(fast_factorial(18), factorial_bigmath(18)); + assert_eq!(fast_factorial(19), factorial_bigmath(19)); + assert_eq!(fast_factorial(30), factorial_bigmath(30)); + assert_eq!(fast_factorial(34), factorial_bigmath(34)); + assert_eq!(fast_factorial(35), factorial_bigmath(35)); + assert_eq!(fast_factorial(52), factorial_bigmath(52)); + assert_eq!(fast_factorial(100), factorial_bigmath(100)); + assert_eq!(fast_factorial(1000), factorial_bigmath(1000)); + assert_eq!(fast_factorial(5000), factorial_bigmath(5000)); + } +} diff --git a/src/big_integer/mod.rs b/src/big_integer/mod.rs new file mode 100644 index 00000000000..13c6767b36b --- /dev/null +++ b/src/big_integer/mod.rs @@ -0,0 +1,9 @@ 
+#![cfg(feature = "big-math")] + +mod fast_factorial; +mod multiply; +mod poly1305; + +pub use self::fast_factorial::fast_factorial; +pub use self::multiply::multiply; +pub use self::poly1305::Poly1305; diff --git a/src/big_integer/multiply.rs b/src/big_integer/multiply.rs new file mode 100644 index 00000000000..1f7d1a57de3 --- /dev/null +++ b/src/big_integer/multiply.rs @@ -0,0 +1,77 @@ +/// Performs long multiplication on string representations of non-negative numbers. +pub fn multiply(num1: &str, num2: &str) -> String { + if !is_valid_nonnegative(num1) || !is_valid_nonnegative(num2) { + panic!("String does not conform to specification") + } + + if num1 == "0" || num2 == "0" { + return "0".to_string(); + } + let output_size = num1.len() + num2.len(); + + let mut mult = vec![0; output_size]; + for (i, c1) in num1.chars().rev().enumerate() { + for (j, c2) in num2.chars().rev().enumerate() { + let mul = c1.to_digit(10).unwrap() * c2.to_digit(10).unwrap(); + // It could be a two-digit number here. + mult[i + j + 1] += (mult[i + j] + mul) / 10; + // Handling rounding. Here's a single digit. + mult[i + j] = (mult[i + j] + mul) % 10; + } + } + if mult[output_size - 1] == 0 { + mult.pop(); + } + mult.iter().rev().map(|&n| n.to_string()).collect() +} + +pub fn is_valid_nonnegative(num: &str) -> bool { + num.chars().all(char::is_numeric) && !num.is_empty() && (!num.starts_with('0') || num == "0") +} + +#[cfg(test)] +mod tests { + use super::*; + macro_rules! test_multiply { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $inputs; + assert_eq!(multiply(s, t), expected); + assert_eq!(multiply(t, s), expected); + } + )* + } + } + + test_multiply! { + multiply0: ("2", "3", "6"), + multiply1: ("123", "456", "56088"), + multiply_zero: ("0", "222", "0"), + other_1: ("99", "99", "9801"), + other_2: ("999", "99", "98901"), + other_3: ("9999", "99", "989901"), + other_4: ("192939", "9499596", "1832842552644"), + } + + macro_rules! 
test_multiply_with_wrong_input { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + #[should_panic] + fn $name() { + let (s, t) = $inputs; + multiply(s, t); + } + )* + } + } + test_multiply_with_wrong_input! { + empty_input: ("", "121"), + leading_zero: ("01", "3"), + wrong_characters: ("2", "12d4"), + wrong_input_and_zero_1: ("0", "x"), + wrong_input_and_zero_2: ("y", "0"), + } +} diff --git a/src/big_integer/poly1305.rs b/src/big_integer/poly1305.rs new file mode 100644 index 00000000000..eae3c016eaa --- /dev/null +++ b/src/big_integer/poly1305.rs @@ -0,0 +1,98 @@ +use num_bigint::BigUint; +use num_traits::Num; +use num_traits::Zero; + +macro_rules! hex_uint { + ($a:literal) => { + BigUint::from_str_radix($a, 16).unwrap() + }; +} + +/** + * Poly1305 Message Authentication Code: + * This implementation is based on RFC8439. + * Note that the Big Integer library we are using may not be suitable for + * cryptographic applications due to non constant time operations. +*/ +pub struct Poly1305 { + p: BigUint, + r: BigUint, + s: BigUint, + /// The accumulator + pub acc: BigUint, +} + +impl Default for Poly1305 { + fn default() -> Self { + Self::new() + } +} + +impl Poly1305 { + pub fn new() -> Self { + Poly1305 { + p: hex_uint!("3fffffffffffffffffffffffffffffffb"), // 2^130 - 5 + r: Zero::zero(), + s: Zero::zero(), + acc: Zero::zero(), + } + } + pub fn clamp_r(&mut self) { + self.r &= hex_uint!("0ffffffc0ffffffc0ffffffc0fffffff"); + } + pub fn set_key(&mut self, key: &[u8; 32]) { + self.r = BigUint::from_bytes_le(&key[..16]); + self.s = BigUint::from_bytes_le(&key[16..]); + self.clamp_r(); + } + /// process a 16-byte-long message block. If message is not long enough, + /// fill the `msg` array with zeros, but set `msg_bytes` to the original + /// chunk length in bytes. See `basic_tv1` for example usage. 
+ pub fn add_msg(&mut self, msg: &[u8; 16], msg_bytes: u64) { + let mut n = BigUint::from_bytes_le(msg); + n.set_bit(msg_bytes * 8, true); + self.acc += n; + self.acc *= &self.r; + self.acc %= &self.p; + } + /// The result is guaranteed to be 16 bytes long + pub fn get_tag(&self) -> Vec { + let result = &self.acc + &self.s; + let mut bytes = result.to_bytes_le(); + bytes.resize(16, 0); + bytes + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fmt::Write; + fn get_tag_hex(tag: &[u8]) -> String { + let mut result = String::new(); + for &x in tag { + write!(result, "{x:02x}").unwrap(); + } + result + } + #[test] + fn basic_tv1() { + let mut mac = Poly1305::default(); + let key: [u8; 32] = [ + 0x85, 0xd6, 0xbe, 0x78, 0x57, 0x55, 0x6d, 0x33, 0x7f, 0x44, 0x52, 0xfe, 0x42, 0xd5, + 0x06, 0xa8, 0x01, 0x03, 0x80, 0x8a, 0xfb, 0x0d, 0xb2, 0xfd, 0x4a, 0xbf, 0xf6, 0xaf, + 0x41, 0x49, 0xf5, 0x1b, + ]; + let mut tmp_buffer = [0_u8; 16]; + mac.set_key(&key); + mac.add_msg(b"Cryptographic Fo", 16); + mac.add_msg(b"rum Research Gro", 16); + tmp_buffer[..2].copy_from_slice(b"up"); + mac.add_msg(&tmp_buffer, 2); + let result = mac.get_tag(); + assert_eq!( + get_tag_hex(result.as_slice()), + "a8061dc1305136c6c22b8baf0c0127a9" + ); + } +} diff --git a/src/bit_manipulation/counting_bits.rs b/src/bit_manipulation/counting_bits.rs new file mode 100644 index 00000000000..9357ca3080c --- /dev/null +++ b/src/bit_manipulation/counting_bits.rs @@ -0,0 +1,54 @@ +//! This module implements a function to count the number of set bits (1s) +//! in the binary representation of an unsigned integer. +//! It uses Brian Kernighan's algorithm, which efficiently clears the least significant +//! set bit in each iteration until all bits are cleared. +//! The algorithm runs in O(k), where k is the number of set bits. + +/// Counts the number of set bits in an unsigned integer. +/// +/// # Arguments +/// +/// * `n` - An unsigned 32-bit integer whose set bits will be counted. 
// src/bit_manipulation/counting_bits.rs
mod counting_bits {
    //! Counts the set bits (1s) in the binary representation of an unsigned integer
    //! using Brian Kernighan's algorithm: every iteration clears the least significant
    //! set bit, so the loop body runs once per set bit.

    /// Counts the number of set bits in an unsigned integer.
    ///
    /// # Arguments
    ///
    /// * `n` - The `usize` value whose set bits will be counted.
    ///
    /// # Returns
    ///
    /// * `usize` - The number of set bits (1s) in the binary representation of `n`.
    pub fn count_set_bits(mut n: usize) -> usize {
        let mut count = 0;
        while n > 0 {
            // `n - 1` flips the lowest set bit and every bit below it, so the AND
            // removes exactly that one bit.
            n &= n - 1;
            count += 1;
        }

        count
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        macro_rules! test_count_set_bits {
            ($($name:ident: $test_case:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (input, expected) = $test_case;
                        assert_eq!(count_set_bits(input), expected);
                    }
                )*
            };
        }

        test_count_set_bits! {
            test_count_set_bits_zero: (0, 0),
            test_count_set_bits_one: (1, 1),
            test_count_set_bits_power_of_two: (16, 1),
            test_count_set_bits_all_set_bits: (usize::MAX, usize::BITS as usize),
            test_count_set_bits_alternating_bits: (0b10101010, 4),
            test_count_set_bits_mixed_bits: (0b11011011, 6),
        }
    }
}

// src/bit_manipulation/highest_set_bit.rs
mod highest_set_bit {
    //! Finds the position of the most significant bit (MSB) set to 1 in an integer.

    /// Finds the zero-based position of the highest (most significant) set bit.
    ///
    /// # Arguments
    ///
    /// * `num` - The value for which the highest set bit will be determined.
    ///
    /// # Returns
    ///
    /// * `Some(position)` if a set bit exists (bit 0 is the least significant bit),
    ///   or `None` if `num` is zero.
    pub fn find_highest_set_bit(num: usize) -> Option<usize> {
        if num == 0 {
            return None;
        }

        // Count how many right shifts are needed after the first one until nothing
        // is left; that count is the index of the top bit.
        let mut position = 0;
        let mut remaining = num >> 1;
        while remaining > 0 {
            remaining >>= 1;
            position += 1;
        }

        Some(position)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        macro_rules! test_find_highest_set_bit {
            ($($name:ident: $test_case:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (input, expected) = $test_case;
                        assert_eq!(find_highest_set_bit(input), expected);
                    }
                )*
            };
        }

        test_find_highest_set_bit! {
            test_positive_number: (18, Some(4)),
            test_0: (0, None),
            test_1: (1, Some(0)),
            test_2: (2, Some(1)),
            test_3: (3, Some(1)),
        }
    }
}

// src/bit_manipulation/n_bits_gray_code.rs
mod n_bits_gray_code {
    //! Generates n-bit Gray code sequences.

    /// Custom error type for Gray code generation.
    #[derive(Debug, PartialEq)]
    pub enum GrayCodeError {
        /// The requested bit count was zero.
        ZeroBitCount,
        /// The requested bit count is at least the width of `usize`, so the number of
        /// codes (`2^n`) cannot be represented.
        BitCountTooLarge,
    }

    /// Generates an n-bit Gray code sequence using the direct formula `i ^ (i >> 1)`.
    ///
    /// # Arguments
    ///
    /// * `n` - The number of bits for the Gray code.
    ///
    /// # Returns
    ///
    /// A vector of `2^n` strings, each `n` characters of `'0'`/`'1'`, most significant
    /// bit first.
    ///
    /// # Errors
    ///
    /// * [`GrayCodeError::ZeroBitCount`] if `n` is zero.
    /// * [`GrayCodeError::BitCountTooLarge`] if `n >= usize::BITS`, where `1 << n`
    ///   would overflow.
    pub fn generate_gray_code(n: usize) -> Result<Vec<String>, GrayCodeError> {
        if n == 0 {
            return Err(GrayCodeError::ZeroBitCount);
        }
        // Guard the shift below: an unchecked `1 << n` panics in debug builds and
        // silently yields the wrong count in release builds.
        if n >= usize::BITS as usize {
            return Err(GrayCodeError::BitCountTooLarge);
        }

        let num_codes = 1usize << n;
        let codes = (0..num_codes)
            .map(|i| {
                let gray = i ^ (i >> 1);
                // Zero-padded binary rendering, exactly `n` digits wide.
                format!("{:0width$b}", gray, width = n)
            })
            .collect();

        Ok(codes)
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        macro_rules! gray_code_tests {
            ($($name:ident: $test_case:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (input, expected) = $test_case;
                        assert_eq!(generate_gray_code(input), expected);
                    }
                )*
            };
        }

        gray_code_tests! {
            zero_bit_count: (0, Err(GrayCodeError::ZeroBitCount)),
            bit_count_too_large: (usize::BITS as usize, Err(GrayCodeError::BitCountTooLarge)),
            gray_code_1_bit: (1, Ok(vec![
                "0".to_string(),
                "1".to_string(),
            ])),
            gray_code_2_bit: (2, Ok(vec![
                "00".to_string(),
                "01".to_string(),
                "11".to_string(),
                "10".to_string(),
            ])),
            gray_code_3_bit: (3, Ok(vec![
                "000".to_string(),
                "001".to_string(),
                "011".to_string(),
                "010".to_string(),
                "110".to_string(),
                "111".to_string(),
                "101".to_string(),
                "100".to_string(),
            ])),
        }
    }
}

// src/bit_manipulation/sum_of_two_integers.rs
mod sum_of_two_integers {
    //! Adds two integers without using the `+` operator, relying on XOR for the
    //! carry-less sum and AND + shift for the carries.

    /// Adds two integers using bitwise operations.
    ///
    /// # Arguments
    ///
    /// * `a` - The first integer to be added.
    /// * `b` - The second integer to be added.
    ///
    /// # Returns
    ///
    /// * `isize` - The result of adding the two integers.
    pub fn add_two_integers(mut a: isize, mut b: isize) -> isize {
        while b != 0 {
            // Carries are the positions where both operands have a 1, moved one bit up.
            let carry = (a & b) << 1;
            // XOR adds every bit position while ignoring those carries.
            a ^= b;
            b = carry;
        }

        a
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        macro_rules! test_add_two_integers {
            ($($name:ident: $test_case:expr,)*) => {
                $(
                    #[test]
                    fn $name() {
                        let (a, b) = $test_case;
                        assert_eq!(add_two_integers(a, b), a + b);
                        assert_eq!(add_two_integers(b, a), a + b);
                    }
                )*
            };
        }

        test_add_two_integers! {
            test_add_two_integers_positive: (3, 5),
            test_add_two_integers_large_positive: (100, 200),
            test_add_two_integers_edge_positive: (65535, 1),
            test_add_two_integers_negative: (-10, 6),
            test_add_two_integers_both_negative: (-50, -30),
            test_add_two_integers_edge_negative: (-1, -1),
            test_add_two_integers_zero: (0, 0),
            test_add_two_integers_zero_with_positive: (0, 42),
            test_add_two_integers_zero_with_negative: (0, -42),
        }
    }
}

// src/bit_manipulation/mod.rs
pub use counting_bits::count_set_bits;
pub use highest_set_bit::find_highest_set_bit;
// `GrayCodeError` is re-exported as well so callers can name the error they receive.
pub use n_bits_gray_code::{generate_gray_code, GrayCodeError};
pub use sum_of_two_integers::add_two_integers;
{ + test_add_two_integers_positive: (3, 5), + test_add_two_integers_large_positive: (100, 200), + test_add_two_integers_edge_positive: (65535, 1), + test_add_two_integers_negative: (-10, 6), + test_add_two_integers_both_negative: (-50, -30), + test_add_two_integers_edge_negative: (-1, -1), + test_add_two_integers_zero: (0, 0), + test_add_two_integers_zero_with_positive: (0, 42), + test_add_two_integers_zero_with_negative: (0, -42), + } +} diff --git a/src/ciphers/README.md b/src/ciphers/README.md index bbd09b1a97c..fb54477c191 100644 --- a/src/ciphers/README.md +++ b/src/ciphers/README.md @@ -8,6 +8,21 @@ The method is named after **Julius Caesar**, who used it in his private correspo The encryption step performed by a Caesar cipher is often incorporated as part of more complex schemes, such as the Vigenère cipher, and still has modern application in the ROT13 system. As with all single-alphabet substitution ciphers, the Caesar cipher is easily broken and in modern practice offers essentially no communication security. ###### Source: [Wikipedia](https://en.wikipedia.org/wiki/Caesar_cipher) +### [Polybius](./polybius.rs) +The **Polybius square**, also known as the Polybius checkerboard, is a device invented by the ancient Greeks Cleoxenus and Democleitus, and made famous by the historian and scholar Polybius.
+The device is used for fractionating plaintext characters so that they can be represented by a smaller set of symbols, which is useful for telegraphy, steganography, and cryptography.
+The **Polybius square** is also used as a basic cipher called the Polybius cipher. This cipher is a **substitution cipher** in which each character is replaced by a pair of digits. + +#### Example cipher + Δ | 1 | 2 | 3 | 4 | 5 +---|---|---|---| --- |--- +1 | a | b | c | d | e +2 | f | g | h | i/j | k +3 | l | m | n | o | p +4 | q | r | s | t | u +5 | v | w | x | y | z +###### Source: [Wikipedia](https://en.wikipedia.org/wiki/Polybius_square) + ### [Vigenère](./vigenere.rs) The **Vigenère cipher** is a method of encrypting alphabetic text by using a series of **interwoven Caesar ciphers** based on the letters of a keyword. It is **a form of polyalphabetic substitution**.
The Vigenère cipher has been reinvented many times. The method was originally described by Giovan Battista Bellaso in his 1553 book La cifra del. Sig. Giovan Battista Bellaso; however, the scheme was later misattributed to Blaise de Vigenère in the 19th century, and is now widely known as the "Vigenère cipher".
@@ -15,10 +30,20 @@ Though the cipher is easy to understand and implement, for three centuries it re Many people have tried to implement encryption schemes that are essentially Vigenère ciphers. Friedrich Kasiski was the first to publish a general method of deciphering a Vigenère cipher in 1863. ###### Source: [Wikipedia](https://en.wikipedia.org/wiki/Vigen%C3%A8re_cipher) -### Transposition _(Not implemented yet)_ +### [SHA-2](./sha256.rs) +SHA-2 (Secure Hash Algorithm 2) is a set of cryptographic hash functions designed by the United States National Security Agency (NSA) and first published in 2001. They are built using the Merkle–Damgård structure, from a one-way compression function itself built using the Davies–Meyer structure from a (classified) specialized block cipher. +###### Source: [Wikipedia](https://en.wikipedia.org/wiki/SHA-2) + +### [Transposition](./transposition.rs) In cryptography, a **transposition cipher** is a method of encryption by which the positions held by units of plaintext (which are commonly characters or groups of characters) are shifted according to a regular system, so that the ciphertext constitutes a permutation of the plaintext. That is, the order of the units is changed (the plaintext is reordered).
Mathematically a bijective function is used on the characters' positions to encrypt and an inverse function to decrypt. ###### Source: [Wikipedia](https://en.wikipedia.org/wiki/Transposition_cipher) [caesar]: https://upload.wikimedia.org/wikipedia/commons/4/4a/Caesar_cipher_left_shift_of_3.svg +### [AES](./aes.rs) +The Advanced Encryption Standard (AES), also known by its original name Rijndael (Dutch pronunciation: [ˈrɛindaːl]), is a specification for the encryption of electronic data established by the U.S. National Institute of Standards and Technology (NIST) in 2001. + +###### Source: [Wikipedia](https://en.wikipedia.org/wiki/Advanced_Encryption_Standard) + +![aes](https://upload.wikimedia.org/wikipedia/commons/5/50/AES_%28Rijndael%29_Round_Function.png) \ No newline at end of file diff --git a/src/ciphers/aes.rs b/src/ciphers/aes.rs new file mode 100644 index 00000000000..5d2eb98ece0 --- /dev/null +++ b/src/ciphers/aes.rs @@ -0,0 +1,544 @@ +const AES_WORD_SIZE: usize = 4; +const AES_BLOCK_SIZE: usize = 16; +const AES_NUM_BLOCK_WORDS: usize = AES_BLOCK_SIZE / AES_WORD_SIZE; + +type Byte = u8; +type Word = u32; + +type AesWord = [Byte; AES_WORD_SIZE]; + +/// Precalculated values of 2 to the power of x in Rijndael's Galois field. +/// Used as 'RCON' during the key expansion. 
+const RCON: [Word; 256] = [ + // 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, + 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, + 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, + 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, + 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, + 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, + 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, + 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, + 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, + 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, + 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, + 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, 0x61, 0xc2, 0x9f, + 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, 0x01, 0x02, 0x04, + 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, + 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91, 0x39, 0x72, 0xe4, 0xd3, 0xbd, + 0x61, 0xc2, 0x9f, 0x25, 0x4a, 0x94, 0x33, 0x66, 0xcc, 0x83, 0x1d, 0x3a, 0x74, 0xe8, 0xcb, 0x8d, +]; + +/// Rijndael S-box Substitution table used for encryption in the subBytes +/// step, as well as the key expansion. 
+const SBOX: [Byte; 256] = [ + // 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0x63, 0x7C, 0x77, 0x7B, 0xF2, 0x6B, 0x6F, 0xC5, 0x30, 0x01, 0x67, 0x2B, 0xFE, 0xD7, 0xAB, 0x76, + 0xCA, 0x82, 0xC9, 0x7D, 0xFA, 0x59, 0x47, 0xF0, 0xAD, 0xD4, 0xA2, 0xAF, 0x9C, 0xA4, 0x72, 0xC0, + 0xB7, 0xFD, 0x93, 0x26, 0x36, 0x3F, 0xF7, 0xCC, 0x34, 0xA5, 0xE5, 0xF1, 0x71, 0xD8, 0x31, 0x15, + 0x04, 0xC7, 0x23, 0xC3, 0x18, 0x96, 0x05, 0x9A, 0x07, 0x12, 0x80, 0xE2, 0xEB, 0x27, 0xB2, 0x75, + 0x09, 0x83, 0x2C, 0x1A, 0x1B, 0x6E, 0x5A, 0xA0, 0x52, 0x3B, 0xD6, 0xB3, 0x29, 0xE3, 0x2F, 0x84, + 0x53, 0xD1, 0x00, 0xED, 0x20, 0xFC, 0xB1, 0x5B, 0x6A, 0xCB, 0xBE, 0x39, 0x4A, 0x4C, 0x58, 0xCF, + 0xD0, 0xEF, 0xAA, 0xFB, 0x43, 0x4D, 0x33, 0x85, 0x45, 0xF9, 0x02, 0x7F, 0x50, 0x3C, 0x9F, 0xA8, + 0x51, 0xA3, 0x40, 0x8F, 0x92, 0x9D, 0x38, 0xF5, 0xBC, 0xB6, 0xDA, 0x21, 0x10, 0xFF, 0xF3, 0xD2, + 0xCD, 0x0C, 0x13, 0xEC, 0x5F, 0x97, 0x44, 0x17, 0xC4, 0xA7, 0x7E, 0x3D, 0x64, 0x5D, 0x19, 0x73, + 0x60, 0x81, 0x4F, 0xDC, 0x22, 0x2A, 0x90, 0x88, 0x46, 0xEE, 0xB8, 0x14, 0xDE, 0x5E, 0x0B, 0xDB, + 0xE0, 0x32, 0x3A, 0x0A, 0x49, 0x06, 0x24, 0x5C, 0xC2, 0xD3, 0xAC, 0x62, 0x91, 0x95, 0xE4, 0x79, + 0xE7, 0xC8, 0x37, 0x6D, 0x8D, 0xD5, 0x4E, 0xA9, 0x6C, 0x56, 0xF4, 0xEA, 0x65, 0x7A, 0xAE, 0x08, + 0xBA, 0x78, 0x25, 0x2E, 0x1C, 0xA6, 0xB4, 0xC6, 0xE8, 0xDD, 0x74, 0x1F, 0x4B, 0xBD, 0x8B, 0x8A, + 0x70, 0x3E, 0xB5, 0x66, 0x48, 0x03, 0xF6, 0x0E, 0x61, 0x35, 0x57, 0xB9, 0x86, 0xC1, 0x1D, 0x9E, + 0xE1, 0xF8, 0x98, 0x11, 0x69, 0xD9, 0x8E, 0x94, 0x9B, 0x1E, 0x87, 0xE9, 0xCE, 0x55, 0x28, 0xDF, + 0x8C, 0xA1, 0x89, 0x0D, 0xBF, 0xE6, 0x42, 0x68, 0x41, 0x99, 0x2D, 0x0F, 0xB0, 0x54, 0xBB, 0x16, +]; + +/// Inverse Rijndael S-box Substitution table used for decryption in the +/// subBytesDec step. 
+const INV_SBOX: [Byte; 256] = [ + // 0 1 2 3 4 5 6 7 8 9 A B C D E F + 0x52, 0x09, 0x6A, 0xD5, 0x30, 0x36, 0xA5, 0x38, 0xBF, 0x40, 0xA3, 0x9E, 0x81, 0xF3, 0xD7, 0xFB, + 0x7C, 0xE3, 0x39, 0x82, 0x9B, 0x2F, 0xFF, 0x87, 0x34, 0x8E, 0x43, 0x44, 0xC4, 0xDE, 0xE9, 0xCB, + 0x54, 0x7B, 0x94, 0x32, 0xA6, 0xC2, 0x23, 0x3D, 0xEE, 0x4C, 0x95, 0x0B, 0x42, 0xFA, 0xC3, 0x4E, + 0x08, 0x2E, 0xA1, 0x66, 0x28, 0xD9, 0x24, 0xB2, 0x76, 0x5B, 0xA2, 0x49, 0x6D, 0x8B, 0xD1, 0x25, + 0x72, 0xF8, 0xF6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xD4, 0xA4, 0x5C, 0xCC, 0x5D, 0x65, 0xB6, 0x92, + 0x6C, 0x70, 0x48, 0x50, 0xFD, 0xED, 0xB9, 0xDA, 0x5E, 0x15, 0x46, 0x57, 0xA7, 0x8D, 0x9D, 0x84, + 0x90, 0xD8, 0xAB, 0x00, 0x8C, 0xBC, 0xD3, 0x0A, 0xF7, 0xE4, 0x58, 0x05, 0xB8, 0xB3, 0x45, 0x06, + 0xD0, 0x2C, 0x1E, 0x8F, 0xCA, 0x3F, 0x0F, 0x02, 0xC1, 0xAF, 0xBD, 0x03, 0x01, 0x13, 0x8A, 0x6B, + 0x3A, 0x91, 0x11, 0x41, 0x4F, 0x67, 0xDC, 0xEA, 0x97, 0xF2, 0xCF, 0xCE, 0xF0, 0xB4, 0xE6, 0x73, + 0x96, 0xAC, 0x74, 0x22, 0xE7, 0xAD, 0x35, 0x85, 0xE2, 0xF9, 0x37, 0xE8, 0x1C, 0x75, 0xDF, 0x6E, + 0x47, 0xF1, 0x1A, 0x71, 0x1D, 0x29, 0xC5, 0x89, 0x6F, 0xB7, 0x62, 0x0E, 0xAA, 0x18, 0xBE, 0x1B, + 0xFC, 0x56, 0x3E, 0x4B, 0xC6, 0xD2, 0x79, 0x20, 0x9A, 0xDB, 0xC0, 0xFE, 0x78, 0xCD, 0x5A, 0xF4, + 0x1F, 0xDD, 0xA8, 0x33, 0x88, 0x07, 0xC7, 0x31, 0xB1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xEC, 0x5F, + 0x60, 0x51, 0x7F, 0xA9, 0x19, 0xB5, 0x4A, 0x0D, 0x2D, 0xE5, 0x7A, 0x9F, 0x93, 0xC9, 0x9C, 0xEF, + 0xA0, 0xE0, 0x3B, 0x4D, 0xAE, 0x2A, 0xF5, 0xB0, 0xC8, 0xEB, 0xBB, 0x3C, 0x83, 0x53, 0x99, 0x61, + 0x17, 0x2B, 0x04, 0x7E, 0xBA, 0x77, 0xD6, 0x26, 0xE1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0C, 0x7D, +]; + +#[rustfmt::skip] +const GF_MUL_TABLE: [[Byte; 256]; 16] = [ + /* 0 */ [0u8; 256], + /* 1 */ + [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 
0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, + 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, + 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, + 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, + ], + /* 2 */ + [ + 0x00, 0x02, 0x04, 0x06, 0x08, 0x0a, 0x0c, 0x0e, 0x10, 0x12, 0x14, 0x16, 0x18, 0x1a, 0x1c, 0x1e, + 0x20, 0x22, 0x24, 0x26, 0x28, 0x2a, 0x2c, 0x2e, 0x30, 0x32, 0x34, 0x36, 0x38, 0x3a, 0x3c, 0x3e, + 0x40, 0x42, 0x44, 0x46, 0x48, 0x4a, 0x4c, 0x4e, 0x50, 0x52, 0x54, 0x56, 0x58, 0x5a, 0x5c, 0x5e, + 0x60, 0x62, 0x64, 0x66, 0x68, 0x6a, 0x6c, 0x6e, 0x70, 0x72, 0x74, 0x76, 0x78, 0x7a, 0x7c, 0x7e, + 0x80, 0x82, 0x84, 0x86, 0x88, 0x8a, 0x8c, 0x8e, 0x90, 0x92, 0x94, 0x96, 0x98, 0x9a, 0x9c, 0x9e, + 0xa0, 0xa2, 0xa4, 0xa6, 0xa8, 0xaa, 0xac, 0xae, 0xb0, 0xb2, 0xb4, 0xb6, 0xb8, 0xba, 0xbc, 0xbe, + 0xc0, 0xc2, 0xc4, 0xc6, 0xc8, 0xca, 0xcc, 0xce, 0xd0, 0xd2, 0xd4, 0xd6, 0xd8, 
0xda, 0xdc, 0xde, + 0xe0, 0xe2, 0xe4, 0xe6, 0xe8, 0xea, 0xec, 0xee, 0xf0, 0xf2, 0xf4, 0xf6, 0xf8, 0xfa, 0xfc, 0xfe, + 0x1b, 0x19, 0x1f, 0x1d, 0x13, 0x11, 0x17, 0x15, 0x0b, 0x09, 0x0f, 0x0d, 0x03, 0x01, 0x07, 0x05, + 0x3b, 0x39, 0x3f, 0x3d, 0x33, 0x31, 0x37, 0x35, 0x2b, 0x29, 0x2f, 0x2d, 0x23, 0x21, 0x27, 0x25, + 0x5b, 0x59, 0x5f, 0x5d, 0x53, 0x51, 0x57, 0x55, 0x4b, 0x49, 0x4f, 0x4d, 0x43, 0x41, 0x47, 0x45, + 0x7b, 0x79, 0x7f, 0x7d, 0x73, 0x71, 0x77, 0x75, 0x6b, 0x69, 0x6f, 0x6d, 0x63, 0x61, 0x67, 0x65, + 0x9b, 0x99, 0x9f, 0x9d, 0x93, 0x91, 0x97, 0x95, 0x8b, 0x89, 0x8f, 0x8d, 0x83, 0x81, 0x87, 0x85, + 0xbb, 0xb9, 0xbf, 0xbd, 0xb3, 0xb1, 0xb7, 0xb5, 0xab, 0xa9, 0xaf, 0xad, 0xa3, 0xa1, 0xa7, 0xa5, + 0xdb, 0xd9, 0xdf, 0xdd, 0xd3, 0xd1, 0xd7, 0xd5, 0xcb, 0xc9, 0xcf, 0xcd, 0xc3, 0xc1, 0xc7, 0xc5, + 0xfb, 0xf9, 0xff, 0xfd, 0xf3, 0xf1, 0xf7, 0xf5, 0xeb, 0xe9, 0xef, 0xed, 0xe3, 0xe1, 0xe7, 0xe5 + ], + /* 3 */ + [ + 0x00, 0x03, 0x06, 0x05, 0x0c, 0x0f, 0x0a, 0x09, 0x18, 0x1b, 0x1e, 0x1d, 0x14, 0x17, 0x12, 0x11, + 0x30, 0x33, 0x36, 0x35, 0x3c, 0x3f, 0x3a, 0x39, 0x28, 0x2b, 0x2e, 0x2d, 0x24, 0x27, 0x22, 0x21, + 0x60, 0x63, 0x66, 0x65, 0x6c, 0x6f, 0x6a, 0x69, 0x78, 0x7b, 0x7e, 0x7d, 0x74, 0x77, 0x72, 0x71, + 0x50, 0x53, 0x56, 0x55, 0x5c, 0x5f, 0x5a, 0x59, 0x48, 0x4b, 0x4e, 0x4d, 0x44, 0x47, 0x42, 0x41, + 0xc0, 0xc3, 0xc6, 0xc5, 0xcc, 0xcf, 0xca, 0xc9, 0xd8, 0xdb, 0xde, 0xdd, 0xd4, 0xd7, 0xd2, 0xd1, + 0xf0, 0xf3, 0xf6, 0xf5, 0xfc, 0xff, 0xfa, 0xf9, 0xe8, 0xeb, 0xee, 0xed, 0xe4, 0xe7, 0xe2, 0xe1, + 0xa0, 0xa3, 0xa6, 0xa5, 0xac, 0xaf, 0xaa, 0xa9, 0xb8, 0xbb, 0xbe, 0xbd, 0xb4, 0xb7, 0xb2, 0xb1, + 0x90, 0x93, 0x96, 0x95, 0x9c, 0x9f, 0x9a, 0x99, 0x88, 0x8b, 0x8e, 0x8d, 0x84, 0x87, 0x82, 0x81, + 0x9b, 0x98, 0x9d, 0x9e, 0x97, 0x94, 0x91, 0x92, 0x83, 0x80, 0x85, 0x86, 0x8f, 0x8c, 0x89, 0x8a, + 0xab, 0xa8, 0xad, 0xae, 0xa7, 0xa4, 0xa1, 0xa2, 0xb3, 0xb0, 0xb5, 0xb6, 0xbf, 0xbc, 0xb9, 0xba, + 0xfb, 0xf8, 0xfd, 0xfe, 0xf7, 0xf4, 0xf1, 0xf2, 0xe3, 0xe0, 0xe5, 0xe6, 0xef, 0xec, 0xe9, 0xea, + 
0xcb, 0xc8, 0xcd, 0xce, 0xc7, 0xc4, 0xc1, 0xc2, 0xd3, 0xd0, 0xd5, 0xd6, 0xdf, 0xdc, 0xd9, 0xda, + 0x5b, 0x58, 0x5d, 0x5e, 0x57, 0x54, 0x51, 0x52, 0x43, 0x40, 0x45, 0x46, 0x4f, 0x4c, 0x49, 0x4a, + 0x6b, 0x68, 0x6d, 0x6e, 0x67, 0x64, 0x61, 0x62, 0x73, 0x70, 0x75, 0x76, 0x7f, 0x7c, 0x79, 0x7a, + 0x3b, 0x38, 0x3d, 0x3e, 0x37, 0x34, 0x31, 0x32, 0x23, 0x20, 0x25, 0x26, 0x2f, 0x2c, 0x29, 0x2a, + 0x0b, 0x08, 0x0d, 0x0e, 0x07, 0x04, 0x01, 0x02, 0x13, 0x10, 0x15, 0x16, 0x1f, 0x1c, 0x19, 0x1a, + ], + /* 4 */ [0u8; 256], + /* 5 */ [0u8; 256], + /* 6 */ [0u8; 256], + /* 7 */ [0u8; 256], + /* 8 */ [0u8; 256], + /* 9 */ + [ + 0x00, 0x09, 0x12, 0x1b, 0x24, 0x2d, 0x36, 0x3f, 0x48, 0x41, 0x5a, 0x53, 0x6c, 0x65, 0x7e, 0x77, + 0x90, 0x99, 0x82, 0x8b, 0xb4, 0xbd, 0xa6, 0xaf, 0xd8, 0xd1, 0xca, 0xc3, 0xfc, 0xf5, 0xee, 0xe7, + 0x3b, 0x32, 0x29, 0x20, 0x1f, 0x16, 0x0d, 0x04, 0x73, 0x7a, 0x61, 0x68, 0x57, 0x5e, 0x45, 0x4c, + 0xab, 0xa2, 0xb9, 0xb0, 0x8f, 0x86, 0x9d, 0x94, 0xe3, 0xea, 0xf1, 0xf8, 0xc7, 0xce, 0xd5, 0xdc, + 0x76, 0x7f, 0x64, 0x6d, 0x52, 0x5b, 0x40, 0x49, 0x3e, 0x37, 0x2c, 0x25, 0x1a, 0x13, 0x08, 0x01, + 0xe6, 0xef, 0xf4, 0xfd, 0xc2, 0xcb, 0xd0, 0xd9, 0xae, 0xa7, 0xbc, 0xb5, 0x8a, 0x83, 0x98, 0x91, + 0x4d, 0x44, 0x5f, 0x56, 0x69, 0x60, 0x7b, 0x72, 0x05, 0x0c, 0x17, 0x1e, 0x21, 0x28, 0x33, 0x3a, + 0xdd, 0xd4, 0xcf, 0xc6, 0xf9, 0xf0, 0xeb, 0xe2, 0x95, 0x9c, 0x87, 0x8e, 0xb1, 0xb8, 0xa3, 0xaa, + 0xec, 0xe5, 0xfe, 0xf7, 0xc8, 0xc1, 0xda, 0xd3, 0xa4, 0xad, 0xb6, 0xbf, 0x80, 0x89, 0x92, 0x9b, + 0x7c, 0x75, 0x6e, 0x67, 0x58, 0x51, 0x4a, 0x43, 0x34, 0x3d, 0x26, 0x2f, 0x10, 0x19, 0x02, 0x0b, + 0xd7, 0xde, 0xc5, 0xcc, 0xf3, 0xfa, 0xe1, 0xe8, 0x9f, 0x96, 0x8d, 0x84, 0xbb, 0xb2, 0xa9, 0xa0, + 0x47, 0x4e, 0x55, 0x5c, 0x63, 0x6a, 0x71, 0x78, 0x0f, 0x06, 0x1d, 0x14, 0x2b, 0x22, 0x39, 0x30, + 0x9a, 0x93, 0x88, 0x81, 0xbe, 0xb7, 0xac, 0xa5, 0xd2, 0xdb, 0xc0, 0xc9, 0xf6, 0xff, 0xe4, 0xed, + 0x0a, 0x03, 0x18, 0x11, 0x2e, 0x27, 0x3c, 0x35, 0x42, 0x4b, 0x50, 0x59, 0x66, 0x6f, 0x74, 0x7d, + 0xa1, 
0xa8, 0xb3, 0xba, 0x85, 0x8c, 0x97, 0x9e, 0xe9, 0xe0, 0xfb, 0xf2, 0xcd, 0xc4, 0xdf, 0xd6, + 0x31, 0x38, 0x23, 0x2a, 0x15, 0x1c, 0x07, 0x0e, 0x79, 0x70, 0x6b, 0x62, 0x5d, 0x54, 0x4f, 0x46, + ], + /* A */ [0u8; 256], + /* B */ + [ + 0x00, 0x0b, 0x16, 0x1d, 0x2c, 0x27, 0x3a, 0x31, 0x58, 0x53, 0x4e, 0x45, 0x74, 0x7f, 0x62, 0x69, + 0xb0, 0xbb, 0xa6, 0xad, 0x9c, 0x97, 0x8a, 0x81, 0xe8, 0xe3, 0xfe, 0xf5, 0xc4, 0xcf, 0xd2, 0xd9, + 0x7b, 0x70, 0x6d, 0x66, 0x57, 0x5c, 0x41, 0x4a, 0x23, 0x28, 0x35, 0x3e, 0x0f, 0x04, 0x19, 0x12, + 0xcb, 0xc0, 0xdd, 0xd6, 0xe7, 0xec, 0xf1, 0xfa, 0x93, 0x98, 0x85, 0x8e, 0xbf, 0xb4, 0xa9, 0xa2, + 0xf6, 0xfd, 0xe0, 0xeb, 0xda, 0xd1, 0xcc, 0xc7, 0xae, 0xa5, 0xb8, 0xb3, 0x82, 0x89, 0x94, 0x9f, + 0x46, 0x4d, 0x50, 0x5b, 0x6a, 0x61, 0x7c, 0x77, 0x1e, 0x15, 0x08, 0x03, 0x32, 0x39, 0x24, 0x2f, + 0x8d, 0x86, 0x9b, 0x90, 0xa1, 0xaa, 0xb7, 0xbc, 0xd5, 0xde, 0xc3, 0xc8, 0xf9, 0xf2, 0xef, 0xe4, + 0x3d, 0x36, 0x2b, 0x20, 0x11, 0x1a, 0x07, 0x0c, 0x65, 0x6e, 0x73, 0x78, 0x49, 0x42, 0x5f, 0x54, + 0xf7, 0xfc, 0xe1, 0xea, 0xdb, 0xd0, 0xcd, 0xc6, 0xaf, 0xa4, 0xb9, 0xb2, 0x83, 0x88, 0x95, 0x9e, + 0x47, 0x4c, 0x51, 0x5a, 0x6b, 0x60, 0x7d, 0x76, 0x1f, 0x14, 0x09, 0x02, 0x33, 0x38, 0x25, 0x2e, + 0x8c, 0x87, 0x9a, 0x91, 0xa0, 0xab, 0xb6, 0xbd, 0xd4, 0xdf, 0xc2, 0xc9, 0xf8, 0xf3, 0xee, 0xe5, + 0x3c, 0x37, 0x2a, 0x21, 0x10, 0x1b, 0x06, 0x0d, 0x64, 0x6f, 0x72, 0x79, 0x48, 0x43, 0x5e, 0x55, + 0x01, 0x0a, 0x17, 0x1c, 0x2d, 0x26, 0x3b, 0x30, 0x59, 0x52, 0x4f, 0x44, 0x75, 0x7e, 0x63, 0x68, + 0xb1, 0xba, 0xa7, 0xac, 0x9d, 0x96, 0x8b, 0x80, 0xe9, 0xe2, 0xff, 0xf4, 0xc5, 0xce, 0xd3, 0xd8, + 0x7a, 0x71, 0x6c, 0x67, 0x56, 0x5d, 0x40, 0x4b, 0x22, 0x29, 0x34, 0x3f, 0x0e, 0x05, 0x18, 0x13, + 0xca, 0xc1, 0xdc, 0xd7, 0xe6, 0xed, 0xf0, 0xfb, 0x92, 0x99, 0x84, 0x8f, 0xbe, 0xb5, 0xa8, 0xa3, + ], + /* C */ [0u8; 256], + /* D */ + [ + 0x00, 0x0d, 0x1a, 0x17, 0x34, 0x39, 0x2e, 0x23, 0x68, 0x65, 0x72, 0x7f, 0x5c, 0x51, 0x46, 0x4b, + 0xd0, 0xdd, 0xca, 0xc7, 0xe4, 0xe9, 0xfe, 0xf3, 0xb8, 0xb5, 
0xa2, 0xaf, 0x8c, 0x81, 0x96, 0x9b, + 0xbb, 0xb6, 0xa1, 0xac, 0x8f, 0x82, 0x95, 0x98, 0xd3, 0xde, 0xc9, 0xc4, 0xe7, 0xea, 0xfd, 0xf0, + 0x6b, 0x66, 0x71, 0x7c, 0x5f, 0x52, 0x45, 0x48, 0x03, 0x0e, 0x19, 0x14, 0x37, 0x3a, 0x2d, 0x20, + 0x6d, 0x60, 0x77, 0x7a, 0x59, 0x54, 0x43, 0x4e, 0x05, 0x08, 0x1f, 0x12, 0x31, 0x3c, 0x2b, 0x26, + 0xbd, 0xb0, 0xa7, 0xaa, 0x89, 0x84, 0x93, 0x9e, 0xd5, 0xd8, 0xcf, 0xc2, 0xe1, 0xec, 0xfb, 0xf6, + 0xd6, 0xdb, 0xcc, 0xc1, 0xe2, 0xef, 0xf8, 0xf5, 0xbe, 0xb3, 0xa4, 0xa9, 0x8a, 0x87, 0x90, 0x9d, + 0x06, 0x0b, 0x1c, 0x11, 0x32, 0x3f, 0x28, 0x25, 0x6e, 0x63, 0x74, 0x79, 0x5a, 0x57, 0x40, 0x4d, + 0xda, 0xd7, 0xc0, 0xcd, 0xee, 0xe3, 0xf4, 0xf9, 0xb2, 0xbf, 0xa8, 0xa5, 0x86, 0x8b, 0x9c, 0x91, + 0x0a, 0x07, 0x10, 0x1d, 0x3e, 0x33, 0x24, 0x29, 0x62, 0x6f, 0x78, 0x75, 0x56, 0x5b, 0x4c, 0x41, + 0x61, 0x6c, 0x7b, 0x76, 0x55, 0x58, 0x4f, 0x42, 0x09, 0x04, 0x13, 0x1e, 0x3d, 0x30, 0x27, 0x2a, + 0xb1, 0xbc, 0xab, 0xa6, 0x85, 0x88, 0x9f, 0x92, 0xd9, 0xd4, 0xc3, 0xce, 0xed, 0xe0, 0xf7, 0xfa, + 0xb7, 0xba, 0xad, 0xa0, 0x83, 0x8e, 0x99, 0x94, 0xdf, 0xd2, 0xc5, 0xc8, 0xeb, 0xe6, 0xf1, 0xfc, + 0x67, 0x6a, 0x7d, 0x70, 0x53, 0x5e, 0x49, 0x44, 0x0f, 0x02, 0x15, 0x18, 0x3b, 0x36, 0x21, 0x2c, + 0x0c, 0x01, 0x16, 0x1b, 0x38, 0x35, 0x22, 0x2f, 0x64, 0x69, 0x7e, 0x73, 0x50, 0x5d, 0x4a, 0x47, + 0xdc, 0xd1, 0xc6, 0xcb, 0xe8, 0xe5, 0xf2, 0xff, 0xb4, 0xb9, 0xae, 0xa3, 0x80, 0x8d, 0x9a, 0x97 + ], + /* E */ + [ + 0x00, 0x0e, 0x1c, 0x12, 0x38, 0x36, 0x24, 0x2a, 0x70, 0x7e, 0x6c, 0x62, 0x48, 0x46, 0x54, 0x5a, + 0xe0, 0xee, 0xfc, 0xf2, 0xd8, 0xd6, 0xc4, 0xca, 0x90, 0x9e, 0x8c, 0x82, 0xa8, 0xa6, 0xb4, 0xba, + 0xdb, 0xd5, 0xc7, 0xc9, 0xe3, 0xed, 0xff, 0xf1, 0xab, 0xa5, 0xb7, 0xb9, 0x93, 0x9d, 0x8f, 0x81, + 0x3b, 0x35, 0x27, 0x29, 0x03, 0x0d, 0x1f, 0x11, 0x4b, 0x45, 0x57, 0x59, 0x73, 0x7d, 0x6f, 0x61, + 0xad, 0xa3, 0xb1, 0xbf, 0x95, 0x9b, 0x89, 0x87, 0xdd, 0xd3, 0xc1, 0xcf, 0xe5, 0xeb, 0xf9, 0xf7, + 0x4d, 0x43, 0x51, 0x5f, 0x75, 0x7b, 0x69, 0x67, 0x3d, 0x33, 0x21, 0x2f, 0x05, 
0x0b, 0x19, 0x17, + 0x76, 0x78, 0x6a, 0x64, 0x4e, 0x40, 0x52, 0x5c, 0x06, 0x08, 0x1a, 0x14, 0x3e, 0x30, 0x22, 0x2c, + 0x96, 0x98, 0x8a, 0x84, 0xae, 0xa0, 0xb2, 0xbc, 0xe6, 0xe8, 0xfa, 0xf4, 0xde, 0xd0, 0xc2, 0xcc, + 0x41, 0x4f, 0x5d, 0x53, 0x79, 0x77, 0x65, 0x6b, 0x31, 0x3f, 0x2d, 0x23, 0x09, 0x07, 0x15, 0x1b, + 0xa1, 0xaf, 0xbd, 0xb3, 0x99, 0x97, 0x85, 0x8b, 0xd1, 0xdf, 0xcd, 0xc3, 0xe9, 0xe7, 0xf5, 0xfb, + 0x9a, 0x94, 0x86, 0x88, 0xa2, 0xac, 0xbe, 0xb0, 0xea, 0xe4, 0xf6, 0xf8, 0xd2, 0xdc, 0xce, 0xc0, + 0x7a, 0x74, 0x66, 0x68, 0x42, 0x4c, 0x5e, 0x50, 0x0a, 0x04, 0x16, 0x18, 0x32, 0x3c, 0x2e, 0x20, + 0xec, 0xe2, 0xf0, 0xfe, 0xd4, 0xda, 0xc8, 0xc6, 0x9c, 0x92, 0x80, 0x8e, 0xa4, 0xaa, 0xb8, 0xb6, + 0x0c, 0x02, 0x10, 0x1e, 0x34, 0x3a, 0x28, 0x26, 0x7c, 0x72, 0x60, 0x6e, 0x44, 0x4a, 0x58, 0x56, + 0x37, 0x39, 0x2b, 0x25, 0x0f, 0x01, 0x13, 0x1d, 0x47, 0x49, 0x5b, 0x55, 0x7f, 0x71, 0x63, 0x6d, + 0xd7, 0xd9, 0xcb, 0xc5, 0xef, 0xe1, 0xf3, 0xfd, 0xa7, 0xa9, 0xbb, 0xb5, 0x9f, 0x91, 0x83, 0x8d + ], + /* F */ [0u8; 256], +]; + +pub enum AesKey { + AesKey128([Byte; 16]), + AesKey192([Byte; 24]), + AesKey256([Byte; 32]), +} + +#[derive(Clone, Copy)] +enum AesMode { + Encryption, + Decryption, +} + +pub fn aes_encrypt(plain_text: &[Byte], key: AesKey) -> Vec { + let (key, num_rounds) = match key { + AesKey::AesKey128(key) => (Vec::from(key), 10), + AesKey::AesKey192(key) => (Vec::from(key), 12), + AesKey::AesKey256(key) => (Vec::from(key), 14), + }; + + let round_keys = key_expansion(&key, num_rounds); + let mut data = padding::(plain_text, AES_BLOCK_SIZE); + + let round_key = &round_keys[0..AES_BLOCK_SIZE]; + add_round_key(&mut data, round_key); + + for round in 1..num_rounds { + sub_bytes_blocks(&mut data, AesMode::Encryption); + shift_rows_blocks(&mut data, AesMode::Encryption); + mix_column_blocks(&mut data, AesMode::Encryption); + let round_key = &round_keys[round * AES_BLOCK_SIZE..(round + 1) * AES_BLOCK_SIZE]; + add_round_key(&mut data, round_key); + } + + 
sub_bytes_blocks(&mut data, AesMode::Encryption); + shift_rows_blocks(&mut data, AesMode::Encryption); + let round_key = &round_keys[num_rounds * AES_BLOCK_SIZE..(num_rounds + 1) * AES_BLOCK_SIZE]; + add_round_key(&mut data, round_key); + + data +} + +pub fn aes_decrypt(cipher_text: &[Byte], key: AesKey) -> Vec { + let (key, num_rounds) = match key { + AesKey::AesKey128(key) => (Vec::from(key), 10), + AesKey::AesKey192(key) => (Vec::from(key), 12), + AesKey::AesKey256(key) => (Vec::from(key), 14), + }; + + let round_keys = key_expansion(&key, num_rounds); + let mut data = padding::(cipher_text, AES_BLOCK_SIZE); + + let round_key = &round_keys[num_rounds * AES_BLOCK_SIZE..(num_rounds + 1) * AES_BLOCK_SIZE]; + add_round_key(&mut data, round_key); + shift_rows_blocks(&mut data, AesMode::Decryption); + sub_bytes_blocks(&mut data, AesMode::Decryption); + + for round in (1..num_rounds).rev() { + let round_key = &round_keys[round * AES_BLOCK_SIZE..(round + 1) * AES_BLOCK_SIZE]; + add_round_key(&mut data, round_key); + mix_column_blocks(&mut data, AesMode::Decryption); + shift_rows_blocks(&mut data, AesMode::Decryption); + sub_bytes_blocks(&mut data, AesMode::Decryption); + } + + let round_key = &round_keys[0..AES_BLOCK_SIZE]; + add_round_key(&mut data, round_key); + + data +} + +fn key_expansion(init_key: &[Byte], num_rounds: usize) -> Vec { + let nr = num_rounds; + // number of words in initial key + let nk = init_key.len() / AES_WORD_SIZE; + let nb = AES_NUM_BLOCK_WORDS; + + let key = init_key + .chunks(AES_WORD_SIZE) + .map(bytes_to_word) + .collect::>(); + let mut key = padding::(&key, nk * (nr + 1)); + + for i in nk..nb * (nr + 1) { + let mut temp_word = key[i - 1]; + if i % nk == 0 { + temp_word = sub_word(rot_word(temp_word), AesMode::Encryption) ^ RCON[i / nk]; + } else if nk > 6 && i % nk == 4 { + temp_word = sub_word(temp_word, AesMode::Encryption); + } + key[i] = key[i - nk] ^ temp_word; + } + + key.iter() + .map(|&w| word_to_bytes(w)) + .collect::>() + 
.concat() +} + +fn add_round_key(data: &mut [Byte], round_key: &[Byte]) { + assert!(data.len() % AES_BLOCK_SIZE == 0 && round_key.len() == AES_BLOCK_SIZE); + let num_blocks = data.len() / AES_BLOCK_SIZE; + data.iter_mut() + .zip(round_key.repeat(num_blocks)) + .for_each(|(s, k)| *s ^= k); +} + +fn sub_bytes_blocks(data: &mut [Byte], mode: AesMode) { + for block in data.chunks_mut(AES_BLOCK_SIZE) { + sub_bytes(block, mode); + } +} + +fn shift_rows_blocks(blocks: &mut [Byte], mode: AesMode) { + for block in blocks.chunks_mut(AES_BLOCK_SIZE) { + transpose_block(block); + shift_rows(block, mode); + transpose_block(block); + } +} + +fn mix_column_blocks(data: &mut [Byte], mode: AesMode) { + for block in data.chunks_mut(AES_BLOCK_SIZE) { + transpose_block(block); + mix_column(block, mode); + transpose_block(block); + } +} + +fn padding(data: &[T], block_size: usize) -> Vec { + if data.len() % block_size == 0 { + Vec::from(data) + } else { + let num_blocks = data.len() / block_size + 1; + let mut padded = Vec::from(data); + padded.append(&mut vec![ + T::default(); + num_blocks * block_size - data.len() + ]); + padded + } +} + +fn sub_word(word: Word, mode: AesMode) -> Word { + let mut bytes = word_to_bytes(word); + sub_bytes(&mut bytes, mode); + bytes_to_word(&bytes) +} + +fn sub_bytes(data: &mut [Byte], mode: AesMode) { + let sbox = match mode { + AesMode::Encryption => &SBOX, + AesMode::Decryption => &INV_SBOX, + }; + for data_byte in data { + *data_byte = sbox[*data_byte as usize]; + } +} + +fn shift_rows(block: &mut [Byte], mode: AesMode) { + // skip the first row, index begin from 1 + for row in 1..4 { + let mut row_word: AesWord = [0u8; 4]; + row_word.copy_from_slice(&block[row * 4..row * 4 + 4]); + for col in 0..4 { + block[row * 4 + col] = match mode { + AesMode::Encryption => row_word[(col + row) % 4], + AesMode::Decryption => row_word[(col + 4 - row) % 4], + } + } + } +} + +fn mix_column(block: &mut [Byte], mode: AesMode) { + let mix_col_mat = match mode { + 
AesMode::Encryption => [ + [0x02, 0x03, 0x01, 0x01], + [0x01, 0x02, 0x03, 0x01], + [0x01, 0x01, 0x02, 0x03], + [0x03, 0x01, 0x01, 0x02], + ], + AesMode::Decryption => [ + [0x0e, 0x0b, 0x0d, 0x09], + [0x09, 0x0e, 0x0b, 0x0d], + [0x0d, 0x09, 0x0e, 0x0b], + [0x0b, 0x0d, 0x09, 0x0e], + ], + }; + + for col in 0..4 { + let col_word = block + .iter() + .zip(0..AES_BLOCK_SIZE) + .filter_map(|(&x, i)| if i % 4 == col { Some(x) } else { None }) + .collect::>(); + for row in 0..4 { + let mut word = 0; + for i in 0..4 { + word ^= GF_MUL_TABLE[mix_col_mat[row][i]][col_word[i] as usize] as Word + } + block[row * 4 + col] = word as Byte; + } + } +} + +fn transpose_block(block: &mut [u8]) { + let mut src_block = [0u8; AES_BLOCK_SIZE]; + src_block.copy_from_slice(block); + for row in 0..4 { + for col in 0..4 { + block[row * 4 + col] = src_block[col * 4 + row]; + } + } +} + +fn bytes_to_word(bytes: &[Byte]) -> Word { + assert!(bytes.len() == AES_WORD_SIZE); + let mut word = 0; + for (i, &byte) in bytes.iter().enumerate() { + word |= (byte as Word) << (8 * i); + } + word +} + +fn word_to_bytes(word: Word) -> AesWord { + let mut bytes = [0; AES_WORD_SIZE]; + for (i, byte) in bytes.iter_mut().enumerate() { + let bits_shift = 8 * i; + *byte = ((word & (0xff << bits_shift)) >> bits_shift) as Byte; + } + bytes +} + +fn rot_word(word: Word) -> Word { + let mut bytes = word_to_bytes(word); + let init = bytes[0]; + bytes[0] = bytes[1]; + bytes[1] = bytes[2]; + bytes[2] = bytes[3]; + bytes[3] = init; + bytes_to_word(&bytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_aes_128() { + let plain: [u8; 16] = [ + 0x32, 0x43, 0xf6, 0xa8, 0x88, 0x5a, 0x30, 0x8d, 0x31, 0x31, 0x98, 0xa2, 0xe0, 0x37, + 0x07, 0x34, + ]; + let key: [u8; 16] = [ + 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, + 0x4f, 0x3c, + ]; + let cipher: [u8; 16] = [ + 0x39, 0x25, 0x84, 0x1d, 0x02, 0xdc, 0x09, 0xfb, 0xdc, 0x11, 0x85, 0x97, 0x19, 0x6a, + 0x0b, 0x32, + ]; + 
let encrypted = aes_encrypt(&plain, AesKey::AesKey128(key)); + assert_eq!(cipher, encrypted[..]); + let decrypted = aes_decrypt(&encrypted, AesKey::AesKey128(key)); + assert_eq!(plain, decrypted[..]); + } + + #[test] + fn test_aes_192() { + let plain: [u8; 16] = [ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, + 0xee, 0xff, + ]; + let key: [u8; 24] = [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + ]; + let cipher: [u8; 16] = [ + 0xdd, 0xa9, 0x7c, 0xa4, 0x86, 0x4c, 0xdf, 0xe0, 0x6e, 0xaf, 0x70, 0xa0, 0xec, 0x0d, + 0x71, 0x91, + ]; + let encrypted = aes_encrypt(&plain, AesKey::AesKey192(key)); + assert_eq!(cipher, encrypted[..]); + let decrypted = aes_decrypt(&encrypted, AesKey::AesKey192(key)); + assert_eq!(plain, decrypted[..]); + } + + #[test] + fn test_aes_256() { + let plain: [u8; 16] = [ + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, + 0xee, 0xff, + ]; + let key: [u8; 32] = [ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, + ]; + let cipher: [u8; 16] = [ + 0x8e, 0xa2, 0xb7, 0xca, 0x51, 0x67, 0x45, 0xbf, 0xea, 0xfc, 0x49, 0x90, 0x4b, 0x49, + 0x60, 0x89, + ]; + let encrypted = aes_encrypt(&plain, AesKey::AesKey256(key)); + assert_eq!(cipher, encrypted[..]); + let decrypted = aes_decrypt(&encrypted, AesKey::AesKey256(key)); + assert_eq!(plain, decrypted[..]); + } + + #[test] + fn test_str() { + let str = "Hello, cipher world!"; + let plain = str.as_bytes(); + let key = [ + 0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf, + 0x4f, 0x3c, + ]; + let encrypted = aes_encrypt(plain, AesKey::AesKey128(key)); + let decrypted = aes_decrypt(&encrypted, AesKey::AesKey128(key)); + assert_eq!( + str, + 
/// Applies ROT13 to `text`.
///
/// ASCII letters are rotated by 13 positions within their own case; every other character
/// (digits, punctuation, non-ASCII text) is copied through unchanged. Applying the function
/// twice yields the original text.
pub fn another_rot13(text: &str) -> String {
    text.chars()
        .map(|c| match c {
            // First half of the alphabet moves forward, second half moves back; this avoids
            // scanning a lookup string for every character.
            'A'..='M' | 'a'..='m' => char::from(c as u8 + 13),
            'N'..='Z' | 'n'..='z' => char::from(c as u8 - 13),
            _ => c,
        })
        .collect()
}
/// Code table of the 26-letter variant of Bacon's cipher: entry `i` is the five-digit
/// binary form of `i` written with `A` for 0 and `B` for 1, so every letter `A`-`Z` has
/// its own code.
const BACONIAN_CODES: [&str; 26] = [
    "AAAAA", "AAAAB", "AAABA", "AAABB", "AABAA", "AABAB", "AABBA", "AABBB", "ABAAA", "ABAAB",
    "ABABA", "ABABB", "ABBAA", "ABBAB", "ABBBA", "ABBBB", "BAAAA", "BAAAB", "BAABA", "BAABB",
    "BABAA", "BABAB", "BABBA", "BABBB", "BBAAA", "BBAAB",
];

/// Encodes `message` with Bacon's cipher.
///
/// ASCII letters (either case) are replaced by their five-letter code; every other
/// character is copied through unchanged.
pub fn baconian_encode(message: &str) -> String {
    let mut encoded = String::with_capacity(message.len() * 5);
    for c in message.chars() {
        if c.is_ascii_alphabetic() {
            // The table covers the whole alphabet, so 'Y' and 'Z' no longer index past its end.
            let index = usize::from(c.to_ascii_uppercase() as u8 - b'A');
            encoded.push_str(BACONIAN_CODES[index]);
        } else {
            encoded.push(c);
        }
    }
    encoded
}

/// Decodes a Bacon's-cipher string produced by [`baconian_encode`].
///
/// The input is read in chunks of five bytes. A chunk that is a known code becomes the
/// matching uppercase letter; any other chunk (including a short trailing one) becomes a
/// single space.
pub fn baconian_decode(encoded: &str) -> String {
    encoded
        .as_bytes()
        .chunks(5)
        .map(|chunk| {
            BACONIAN_CODES
                .iter()
                .position(|code| code.as_bytes() == chunk)
                // `index` is below 26, so the sum stays within `b'A'..=b'Z'`.
                .map_or(' ', |index| char::from(b'A' + index as u8))
        })
        .collect()
}
/// The 64 symbols of the standard base64 alphabet, indexed by their 6-bit value.
const CHARSET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Character appended so that the encoded length is a multiple of four.
const PADDING: char = '=';

/*
    Combines the two provided bytes into an u16,
    and collects 6 bits from it using an AND mask:

    Example:
        Bytes: X and Y
        (Bits of those bytes will be signified using the names of their byte)
        Offset: 4

        `combined` = 0bXXXXXXXXYYYYYYYY
        AND mask:
            0b1111110000000000 >> offset (4) = 0b0000111111000000
        `combined` with mask applied:
            0b0000XXXXYY000000
        Shift the value right by (16 bit number) - (6 bit mask) - (4 offset) = 6:
            0b0000000000XXXXYY
        And then turn it into an u8:
            0b00XXXXYY (Return value)
*/
fn collect_six_bits(from: (u8, u8), offset: u8) -> u8 {
    let combined: u16 = ((from.0 as u16) << 8) | (from.1 as u16);
    ((combined & (0b1111110000000000u16 >> offset)) >> (10 - offset)) as u8
}

/// Encodes `data` as standard, `=`-padded base64.
pub fn base64_encode(data: &[u8]) -> String {
    let total_bits = data.len() * 8;
    // Using modulo twice to prevent an underflow: yields 0, 1 or 2 padding characters.
    let padding_needed = ((6 - total_bits % 6) / 2) % 3;
    let mut encoded_string = String::with_capacity((data.len() + 2) / 3 * 4);

    // Walk the input six bits at a time; each step reads from the byte holding the first
    // bit of the group and from its successor (zero past the end of the input).
    for bit_offset in (0..total_bits).step_by(6) {
        let index = bit_offset / 8;
        let lower_byte = data[index];
        let upper_byte = data.get(index + 1).copied().unwrap_or(0);
        let sextet = collect_six_bits((lower_byte, upper_byte), (bit_offset % 8) as u8);
        encoded_string.push(CHARSET[sextet as usize] as char);
    }
    for _ in 0..padding_needed {
        encoded_string.push(PADDING);
    }
    encoded_string
}

/// Decodes standard base64 text into bytes (the inverse of [`base64_encode`]).
///
/// # Errors
///
/// Returns a message together with the offending value when
/// * a byte is neither part of the alphabet nor the padding character (the value is that
///   byte), or
/// * the padding does not match the number of pending bits (the value is the pending bit
///   count).
pub fn base64_decode(data: &str) -> Result<Vec<u8>, (&str, u8)> {
    // Number of not-yet-emitted bits currently held at the top of `byte_buffer`.
    let mut collected_bits: u8 = 0;
    let mut byte_buffer = 0u16;
    let mut databytes = data.bytes();
    let mut outputbytes = Vec::<u8>::new();
    'decodeloop: loop {
        while collected_bits < 8 {
            if let Some(nextbyte) = databytes.next() {
                // Finds the first occurrence of the latest byte
                if let Some(idx) = CHARSET.iter().position(|&x| x == nextbyte) {
                    byte_buffer |= ((idx & 0b00111111) as u16) << (10 - collected_bits);
                    collected_bits += 6;
                } else if nextbyte == (PADDING as u8) {
                    // Each padding character cancels two filler bits. Padding with fewer
                    // than two bits pending is malformed and used to underflow here.
                    match collected_bits.checked_sub(2) {
                        Some(remaining) => collected_bits = remaining,
                        None => {
                            return Err((
                                "Failed to decode base64: Invalid padding.",
                                collected_bits,
                            ))
                        }
                    }
                } else {
                    return Err((
                        "Failed to decode base64: Expected byte from charset, found invalid byte.",
                        nextbyte,
                    ));
                }
            } else {
                break 'decodeloop;
            }
        }
        // Emit the completed top byte and move the leftover bits up.
        outputbytes.push(((0b1111111100000000 & byte_buffer) >> 8) as u8);
        byte_buffer &= 0b0000000011111111;
        byte_buffer <<= 8;
        collected_bits -= 8;
    }
    if collected_bits != 0 {
        return Err(("Failed to decode base64: Invalid padding.", collected_bits));
    }
    Ok(outputbytes)
}
test_encode { + ($left: expr, $right: expr) => { + assert_eq!(base64_encode(&$left.to_vec()), $right); + }; + } + test_encode!( + b"\xd31\xc9\x87D\xfe\xaa\xb3\xff\xef\x8c\x0eoD", + "0zHJh0T+qrP/74wOb0Q=" + ); + test_encode!( + b"\x9f\x0e8\xbc\xf5\xd0-\xb4.\xd4\xf0?\x8f\xe7\t{.\xff/6\xcbTY!\xae9\x82", + "nw44vPXQLbQu1PA/j+cJey7/LzbLVFkhrjmC" + ); + test_encode!(b"\x7f3\x15\x1a\xd3\xf91\x9bS\xa44=", "fzMVGtP5MZtTpDQ9"); + test_encode!( + b"7:\xf5\xd1[\xbfV/P\x18\x03\x00\xdc\xcd\xa1\xecG", + "Nzr10Vu/Vi9QGAMA3M2h7Ec=" + ); + test_encode!( + b"\xc3\xc9\x18={\xc4\x08\x97wN\xda\x81\x84?\x94\xe6\x9e", + "w8kYPXvECJd3TtqBhD+U5p4=" + ); + test_encode!( + b"\x8cJ\xf8e\x13\r\x8fw\xa8\xe6G\xce\x93c*\xe7M\xb6\xd7", + "jEr4ZRMNj3eo5kfOk2Mq50221w==" + ); + test_encode!( + b"\xde\xc4~\xb2}\xb1\x14F.~\xa1z|s\x90\x8dd\x9b\x04\x81\xf2\x92{", + "3sR+sn2xFEYufqF6fHOQjWSbBIHykns=" + ); + test_encode!( + b"\xf0y\t\x14\xd161n\x03e\xed\x0e\x05\xdf\xc1\xb9\xda", + "8HkJFNE2MW4DZe0OBd/Budo=" + ); + test_encode!( + b"*.\x8e\x1d@\x1ac\xdd;\x9a\xcc \x0c\xc2KI", + "Ki6OHUAaY907mswgDMJLSQ==" + ); + test_encode!(b"\xd6\x829\x82\xbc\x00\xc9\xfe\x03", "1oI5grwAyf4D"); + test_encode!( + b"\r\xf2\xb4\xd4\xa1g\x8fhl\xaa@\x98\x00\xda\x95", + "DfK01KFnj2hsqkCYANqV" + ); + test_encode!( + b"\x1a\xfaV\x1a\xc2e\xc0\xad\xef|\x07\xcf\xa9\xb7O", + "GvpWGsJlwK3vfAfPqbdP" + ); + test_encode!(b"\xc20{_\x81\xac", "wjB7X4Gs"); + test_encode!( + b"B\xa85\xac\xe9\x0ev-\x8bT\xb3|\xde", + "Qqg1rOkOdi2LVLN83g==" + ); + test_encode!( + b"\x05\xe0\xeeSs\xfdY9\x0b7\x84\xfc-\xec", + "BeDuU3P9WTkLN4T8Lew=" + ); + test_encode!( + b"Qj\x92\xfa?\xa5\xe3_[\xde\x82\x97{$\xb2\xf9\xd5\x98\x0cy\x15\xe4R\x8d", + "UWqS+j+l419b3oKXeySy+dWYDHkV5FKN" + ); + test_encode!(b"\x853\xe0\xc0\x1d\xc1", "hTPgwB3B"); + test_encode!(b"}2\xd0\x13m\x8d\x8f#\x9c\xf5,\xc7", "fTLQE22NjyOc9SzH"); + } + + #[test] + fn pregenerated_random_bytes_decode() { + macro_rules! 
test_decode { + ($left: expr, $right: expr) => { + assert_eq!( + base64_decode(&String::from($left)).unwrap(), + $right.to_vec() + ); + }; + } + test_decode!( + "0zHJh0T+qrP/74wOb0Q=", + b"\xd31\xc9\x87D\xfe\xaa\xb3\xff\xef\x8c\x0eoD" + ); + test_decode!( + "nw44vPXQLbQu1PA/j+cJey7/LzbLVFkhrjmC", + b"\x9f\x0e8\xbc\xf5\xd0-\xb4.\xd4\xf0?\x8f\xe7\t{.\xff/6\xcbTY!\xae9\x82" + ); + test_decode!("fzMVGtP5MZtTpDQ9", b"\x7f3\x15\x1a\xd3\xf91\x9bS\xa44="); + test_decode!( + "Nzr10Vu/Vi9QGAMA3M2h7Ec=", + b"7:\xf5\xd1[\xbfV/P\x18\x03\x00\xdc\xcd\xa1\xecG" + ); + test_decode!( + "w8kYPXvECJd3TtqBhD+U5p4=", + b"\xc3\xc9\x18={\xc4\x08\x97wN\xda\x81\x84?\x94\xe6\x9e" + ); + test_decode!( + "jEr4ZRMNj3eo5kfOk2Mq50221w==", + b"\x8cJ\xf8e\x13\r\x8fw\xa8\xe6G\xce\x93c*\xe7M\xb6\xd7" + ); + test_decode!( + "3sR+sn2xFEYufqF6fHOQjWSbBIHykns=", + b"\xde\xc4~\xb2}\xb1\x14F.~\xa1z|s\x90\x8dd\x9b\x04\x81\xf2\x92{" + ); + test_decode!( + "8HkJFNE2MW4DZe0OBd/Budo=", + b"\xf0y\t\x14\xd161n\x03e\xed\x0e\x05\xdf\xc1\xb9\xda" + ); + test_decode!( + "Ki6OHUAaY907mswgDMJLSQ==", + b"*.\x8e\x1d@\x1ac\xdd;\x9a\xcc \x0c\xc2KI" + ); + test_decode!("1oI5grwAyf4D", b"\xd6\x829\x82\xbc\x00\xc9\xfe\x03"); + test_decode!( + "DfK01KFnj2hsqkCYANqV", + b"\r\xf2\xb4\xd4\xa1g\x8fhl\xaa@\x98\x00\xda\x95" + ); + test_decode!( + "GvpWGsJlwK3vfAfPqbdP", + b"\x1a\xfaV\x1a\xc2e\xc0\xad\xef|\x07\xcf\xa9\xb7O" + ); + test_decode!("wjB7X4Gs", b"\xc20{_\x81\xac"); + test_decode!( + "Qqg1rOkOdi2LVLN83g==", + b"B\xa85\xac\xe9\x0ev-\x8bT\xb3|\xde" + ); + test_decode!( + "BeDuU3P9WTkLN4T8Lew=", + b"\x05\xe0\xeeSs\xfdY9\x0b7\x84\xfc-\xec" + ); + test_decode!( + "UWqS+j+l419b3oKXeySy+dWYDHkV5FKN", + b"Qj\x92\xfa?\xa5\xe3_[\xde\x82\x97{$\xb2\xf9\xd5\x98\x0cy\x15\xe4R\x8d" + ); + test_decode!("hTPgwB3B", b"\x853\xe0\xc0\x1d\xc1"); + test_decode!("fTLQE22NjyOc9SzH", b"}2\xd0\x13m\x8d\x8f#\x9c\xf5,\xc7"); + } + + #[test] + fn encode_decode() { + macro_rules! 
// For specification go to https://www.rfc-editor.org/rfc/rfc7693

type Word = u64;

/// Block size in bytes.
const BB: usize = 128;

const U64BYTES: usize = (u64::BITS as usize) / 8;

/// One input block viewed as little-endian 64-bit words.
type Block = [Word; BB / U64BYTES];

/// Maximum key length in bytes.
const KK_MAX: usize = 64;
/// Maximum digest length in bytes.
const NN_MAX: u8 = 64;

// Array of round constants used in mixing function G
const RC: [u32; 4] = [32, 24, 16, 63];

// IV[i] = floor(2**64 * frac(sqrt(prime(i+1)))) where prime(i) is the ith prime number
const IV: [Word; 8] = [
    0x6A09E667F3BCC908,
    0xBB67AE8584CAA73B,
    0x3C6EF372FE94F82B,
    0xA54FF53A5F1D36F1,
    0x510E527FADE682D1,
    0x9B05688C2B3E6C1F,
    0x1F83D9ABFB41BD6B,
    0x5BE0CD19137E2179,
];

/// Message word schedule; round `i` uses row `i % 10`.
const SIGMA: [[usize; 16]; 10] = [
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
    [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
    [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
    [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
    [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
    [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
    [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
    [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
    [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
    [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
];

/// Returns an all-zero block.
#[inline]
const fn blank_block() -> Block {
    [0u64; BB / U64BYTES]
}

/// Adds `b` to `a` modulo 2^64.
#[inline]
fn add(a: &mut Word, b: Word) {
    *a = a.wrapping_add(b);
}

/// Integer division of `dividend` by `divisor`, rounded up.
#[inline]
const fn ceil(dividend: usize, divisor: usize) -> usize {
    (dividend / divisor) + ((dividend % divisor != 0) as usize)
}

/// Mixing function G: mixes the message words `x` and `y` into the four state words
/// `v[a]`, `v[b]`, `v[c]`, `v[d]`.
fn g(v: &mut [Word; 16], a: usize, b: usize, c: usize, d: usize, x: Word, y: Word) {
    // Two half-rounds: `x` with rotations (32, 24), then `y` with rotations (16, 63).
    for (m, r) in [x, y].iter().copied().zip(RC.chunks(2)) {
        let v_b = v[b];
        add(&mut v[a], v_b);
        add(&mut v[a], m);

        v[d] = (v[d] ^ v[a]).rotate_right(r[0]);

        let v_d = v[d];
        add(&mut v[c], v_d);

        v[b] = (v[b] ^ v[c]).rotate_right(r[1]);
    }
}

/// Compression function F: folds message block `m` into the chaining value `h`.
///
/// `t` is the number of input bytes consumed so far (including this block) and `flag`
/// marks the final block.
fn f(h: &mut [Word; 8], m: Block, t: u128, flag: bool) {
    let mut v: [Word; 16] = [0; 16];
    v[..8].copy_from_slice(&h[..]);
    v[8..].copy_from_slice(&IV);

    // Low and high 64-bit words of the offset counter. The low word is `t mod 2^64`, which
    // is exactly the truncating cast; reducing modulo `u64::MAX` would map 2^64 - 1 to 0.
    v[12] ^= t as u64;
    v[13] ^= (t >> 64) as u64;

    if flag {
        v[14] = !v[14];
    }

    for round in 0..12 {
        let s = &SIGMA[round % 10];

        // Column step: the first eight scheduled words.
        for j in 0..4 {
            g(&mut v, j, j + 4, j + 8, j + 12, m[s[2 * j]], m[s[2 * j + 1]]);
        }

        // Diagonal step: the diagonal starting at column `j` of the 4x4 state, fed with the
        // last eight scheduled words.
        for j in 0..4 {
            g(
                &mut v,
                j,
                4 + (j + 1) % 4,
                8 + (j + 2) % 4,
                12 + (j + 3) % 4,
                m[s[8 + 2 * j]],
                m[s[8 + 2 * j + 1]],
            );
        }
    }

    for (i, n) in h.iter_mut().enumerate() {
        *n ^= v[i] ^ v[i + 8];
    }
}

/// Hashes the prepared blocks `d`.
///
/// `ll` is the message length in bytes, `kk` the key length in bytes and `nn` the number of
/// digest bytes to return. When a key is present it is expected to occupy the first block.
///
/// # Panics
///
/// Panics if `d` is empty.
fn blake2(d: Vec<Block>, ll: u128, kk: Word, nn: Word) -> Vec<u8> {
    let mut h: [Word; 8] = IV;

    // Parameter block: digest length, key length, fanout = depth = 1.
    h[0] ^= 0x01010000u64 ^ (kk << 8) ^ nn;

    let (last, rest) = d
        .split_last()
        .expect("blake2 needs at least one block to compress");

    for (i, w) in rest.iter().enumerate() {
        f(&mut h, *w, (i as u128 + 1) * BB as u128, false);
    }

    // The key block counts towards the total number of bytes compressed.
    let ll = if kk > 0 { ll + BB as u128 } else { ll };
    f(&mut h, *last, ll, true);

    h.iter()
        .flat_map(|n| n.to_le_bytes())
        .take(nn as usize)
        .collect()
}

/// Reads up to the first eight bytes of `bytes` as a little-endian word; missing high
/// bytes are zero.
fn bytes_to_word(bytes: &[u8]) -> Word {
    let mut arr = [0u8; U64BYTES];
    for (a_i, b_i) in arr.iter_mut().zip(bytes) {
        *a_i = *b_i;
    }
    Word::from_le_bytes(arr)
}

/// Computes the BLAKE2b digest of message `m`, optionally keyed with `k`.
///
/// `nn` is the digest length in bytes and is capped at 64. Only the first 64 bytes of `k`
/// are used; pass an empty slice for an unkeyed hash.
pub fn blake2b(m: &[u8], k: &[u8], nn: u8) -> Vec<u8> {
    const WORDS_PER_BLOCK: usize = BB / U64BYTES;

    let kk = k.len().min(KK_MAX);
    let nn = nn.min(NN_MAX);

    // Prevent user from giving a key that is too long
    let k = &k[..kk];

    // One block for the key (if any) plus the message blocks; at least one block overall.
    let dd = (ceil(kk, BB) + ceil(m.len(), BB)).max(1);

    let mut blocks: Vec<Block> = vec![blank_block(); dd];

    // Copy key into blocks
    for (w, c) in blocks[0].iter_mut().zip(k.chunks(U64BYTES)) {
        *w = bytes_to_word(c);
    }

    // The message starts after the key block when a key is present.
    let first_index = (kk > 0) as usize;

    // Copy bytes from message into blocks
    for (i, c) in m.chunks(U64BYTES).enumerate() {
        blocks[first_index + i / WORDS_PER_BLOCK][i % WORDS_PER_BLOCK] = bytes_to_word(c);
    }

    blake2(blocks, m.len() as u128, kk as u64, nn as Word)
}
digest_test { + ($fname:ident, $message:expr, $key:expr, $nn:literal, $expected:expr) => { + #[test] + fn $fname() { + let digest = blake2b($message, $key, $nn); + + let expected = Vec::from($expected); + + assert_eq!(digest, expected); + } + }; + } + + digest_test!( + blake2b_from_rfc, + &[0x61, 0x62, 0x63], + &[0; 0], + 64, + [ + 0xBA, 0x80, 0xA5, 0x3F, 0x98, 0x1C, 0x4D, 0x0D, 0x6A, 0x27, 0x97, 0xB6, 0x9F, 0x12, + 0xF6, 0xE9, 0x4C, 0x21, 0x2F, 0x14, 0x68, 0x5A, 0xC4, 0xB7, 0x4B, 0x12, 0xBB, 0x6F, + 0xDB, 0xFF, 0xA2, 0xD1, 0x7D, 0x87, 0xC5, 0x39, 0x2A, 0xAB, 0x79, 0x2D, 0xC2, 0x52, + 0xD5, 0xDE, 0x45, 0x33, 0xCC, 0x95, 0x18, 0xD3, 0x8A, 0xA8, 0xDB, 0xF1, 0x92, 0x5A, + 0xB9, 0x23, 0x86, 0xED, 0xD4, 0x00, 0x99, 0x23 + ] + ); + + digest_test!( + blake2b_empty, + &[0; 0], + &[0; 0], + 64, + [ + 0x78, 0x6a, 0x02, 0xf7, 0x42, 0x01, 0x59, 0x03, 0xc6, 0xc6, 0xfd, 0x85, 0x25, 0x52, + 0xd2, 0x72, 0x91, 0x2f, 0x47, 0x40, 0xe1, 0x58, 0x47, 0x61, 0x8a, 0x86, 0xe2, 0x17, + 0xf7, 0x1f, 0x54, 0x19, 0xd2, 0x5e, 0x10, 0x31, 0xaf, 0xee, 0x58, 0x53, 0x13, 0x89, + 0x64, 0x44, 0x93, 0x4e, 0xb0, 0x4b, 0x90, 0x3a, 0x68, 0x5b, 0x14, 0x48, 0xb7, 0x55, + 0xd5, 0x6f, 0x70, 0x1a, 0xfe, 0x9b, 0xe2, 0xce + ] + ); + + digest_test!( + blake2b_empty_with_key, + &[0; 0], + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + 64, + [ + 0x10, 0xeb, 0xb6, 0x77, 0x00, 0xb1, 0x86, 0x8e, 0xfb, 0x44, 0x17, 0x98, 0x7a, 0xcf, + 0x46, 0x90, 0xae, 0x9d, 0x97, 0x2f, 0xb7, 0xa5, 0x90, 0xc2, 0xf0, 0x28, 0x71, 0x79, + 0x9a, 0xaa, 0x47, 0x86, 0xb5, 0xe9, 0x96, 0xe8, 0xf0, 0xf4, 0xeb, 0x98, 0x1f, 0xc2, + 0x14, 0xb0, 0x05, 0xf4, 0x2d, 0x2f, 0xf4, 0x23, 0x34, 0x99, 0x39, 
0x16, 0x53, 0xdf, + 0x7a, 0xef, 0xcb, 0xc1, 0x3f, 0xc5, 0x15, 0x68 + ] + ); + + digest_test!( + blake2b_key_shortin, + &[0], + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + 64, + [ + 0x96, 0x1f, 0x6d, 0xd1, 0xe4, 0xdd, 0x30, 0xf6, 0x39, 0x01, 0x69, 0x0c, 0x51, 0x2e, + 0x78, 0xe4, 0xb4, 0x5e, 0x47, 0x42, 0xed, 0x19, 0x7c, 0x3c, 0x5e, 0x45, 0xc5, 0x49, + 0xfd, 0x25, 0xf2, 0xe4, 0x18, 0x7b, 0x0b, 0xc9, 0xfe, 0x30, 0x49, 0x2b, 0x16, 0xb0, + 0xd0, 0xbc, 0x4e, 0xf9, 0xb0, 0xf3, 0x4c, 0x70, 0x03, 0xfa, 0xc0, 0x9a, 0x5e, 0xf1, + 0x53, 0x2e, 0x69, 0x43, 0x02, 0x34, 0xce, 0xbd + ] + ); + + digest_test!( + blake2b_keyed_filled, + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + 64, + [ + 0x65, 0x67, 0x6d, 0x80, 0x06, 0x17, 0x97, 0x2f, 0xbd, 0x87, 0xe4, 0xb9, 0x51, 0x4e, + 0x1c, 0x67, 0x40, 0x2b, 0x7a, 0x33, 0x10, 0x96, 0xd3, 0xbf, 0xac, 0x22, 0xf1, 0xab, + 0xb9, 0x53, 0x74, 0xab, 
const ERROR_MESSAGE: &str = "Rotation must be in the range [0, 25]";
const ALPHABET_LENGTH: u8 = b'z' - b'a' + 1;

/// Encrypts a given text using the Caesar cipher technique.
///
/// In cryptography, a Caesar cipher, also known as Caesar's cipher, the shift cipher, Caesar's code,
/// or Caesar shift, is one of the simplest and most widely known encryption techniques.
/// It is a type of substitution cipher in which each letter in the plaintext is replaced by a letter
/// some fixed number of positions down the alphabet.
///
/// ASCII letters keep their case; every other character is copied through unchanged.
///
/// # Arguments
///
/// * `text` - The text to be encrypted.
/// * `rotation` - The number of rotations (shift) to be applied. It should be within the range [0, 25].
///
/// # Returns
///
/// Returns a `Result` containing the encrypted string if successful, or an error message if the rotation
/// is out of the valid range.
///
/// # Errors
///
/// Returns an error if the rotation value is out of the valid range [0, 25]
pub fn caesar(text: &str, rotation: isize) -> Result<String, &'static str> {
    if !(0..ALPHABET_LENGTH as isize).contains(&rotation) {
        return Err(ERROR_MESSAGE);
    }

    let result = text
        .chars()
        .map(|c| {
            if c.is_ascii_alphabetic() {
                shift_char(c, rotation)
            } else {
                c
            }
        })
        .collect();

    Ok(result)
}

/// Shifts a single ASCII alphabetic character by a specified number of positions in the alphabet.
///
/// # Arguments
///
/// * `c` - The ASCII alphabetic character to be shifted.
/// * `rotation` - The number of positions to shift the character. Should be within the range [0, 25].
///
/// # Returns
///
/// Returns the shifted ASCII alphabetic character.
fn shift_char(c: char, rotation: isize) -> char {
    let first = if c.is_ascii_lowercase() { b'a' } else { b'A' };
    let rotation = rotation as u8; // Safe cast as rotation is within [0, 25]

    // Offset within the alphabet, rotated and wrapped, then moved back to the letter's case.
    (((c as u8 - first) + rotation) % ALPHABET_LENGTH + first) as char
}
/// ChaCha quarter round applied in place to four state words.
macro_rules! quarter_round {
    ($a:expr,$b:expr,$c:expr,$d:expr) => {
        $a = $a.wrapping_add($b);
        $d = ($d ^ $a).rotate_left(16);
        $c = $c.wrapping_add($d);
        $b = ($b ^ $c).rotate_left(12);
        $a = $a.wrapping_add($b);
        $d = ($d ^ $a).rotate_left(8);
        $c = $c.wrapping_add($d);
        $b = ($b ^ $c).rotate_left(7);
    };
}

#[allow(dead_code)]
// "expand 32-byte k", written in little-endian order
pub const C: [u32; 4] = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574];

/// ChaCha20 implementation based on RFC8439
///
/// ChaCha20 is a stream cipher developed independently by Daniel J. Bernstein.\
/// To use it, the `chacha20` function should be called with appropriate
/// parameters and the output of the function should be XORed with plain text.
///
/// `chacha20` function takes as input an array of 16 32-bit integers (512 bits)
/// of which 128 bits is the constant 'expand 32-byte k', 256 bits is the key,
/// and 128 bits are nonce and counter. According to RFC8439, the nonce should
/// be 96 bits long, which leaves 32 bits for the counter. Given that the block
/// length is 512 bits, this leaves enough counter values to encrypt 256GB of
/// data.
///
/// The 16 input numbers can be thought of as the elements of a 4x4 matrix like
/// the one below, on which we do the main operations of the cipher.
///
/// ```text
/// +----+----+----+----+
/// | 00 | 01 | 02 | 03 |
/// +----+----+----+----+
/// | 04 | 05 | 06 | 07 |
/// +----+----+----+----+
/// | 08 | 09 | 10 | 11 |
/// +----+----+----+----+
/// | 12 | 13 | 14 | 15 |
/// +----+----+----+----+
/// ```
///
/// As per the diagram below, `input[0, 1, 2, 3]` are the constants mentioned
/// above, `input[4..=11]` is filled with the key, and `input[12..=15]` should be
/// filled with counter and nonce values. The output of the function is stored
/// in `output` variable and can be XORed with the plain text to produce the
/// cipher text.
///
/// ```text
/// +------+------+------+------+
/// |      |      |      |      |
/// | C[0] | C[1] | C[2] | C[3] |
/// |      |      |      |      |
/// +------+------+------+------+
/// |      |      |      |      |
/// | key0 | key1 | key2 | key3 |
/// |      |      |      |      |
/// +------+------+------+------+
/// |      |      |      |      |
/// | key4 | key5 | key6 | key7 |
/// |      |      |      |      |
/// +------+------+------+------+
/// |      |      |      |      |
/// | ctr0 | no.0 | no.1 | no.2 |
/// |      |      |      |      |
/// +------+------+------+------+
/// ```
///
/// Note that the constants, the key, and the nonce should be written in
/// little-endian order, meaning that for example if the key is 01:02:03:04
/// (in hex), it corresponds to the integer `0x04030201`. It is important to
/// know that the hex value of the counter is meaningless, and only its integer
/// value matters, and it should start with (for example) `0x00000000`, and then
/// `0x00000001` and so on until `0xffffffff`. Keep in mind that as soon as we get
/// from bytes to words, we stop caring about their representation in memory,
/// and we only need the math to be correct.
///
/// The output of the function can be used without any change, as long as the
/// plain text has the same endianness. For example if the plain text is
/// "hello world", and the first word of the output is `0x01020304`, then the
/// first byte of plain text ('h') should be XORed with the least-significant
/// byte of `0x01020304`, which is `0x04`.
pub fn chacha20(input: &[u32; 16], output: &mut [u32; 16]) {
    // Work on a copy of the input state; the original is added back at the end.
    *output = *input;
    // 20 rounds, executed as 10 pairs of one column round and one diagonal round.
    for _ in 0..10 {
        // Odd round (column round)
        quarter_round!(output[0], output[4], output[8], output[12]); // column 1
        quarter_round!(output[1], output[5], output[9], output[13]); // column 2
        quarter_round!(output[2], output[6], output[10], output[14]); // column 3
        quarter_round!(output[3], output[7], output[11], output[15]); // column 4

        // Even round (diagonal round)
        quarter_round!(output[0], output[5], output[10], output[15]); // diag 1
        quarter_round!(output[1], output[6], output[11], output[12]); // diag 2
        quarter_round!(output[2], output[7], output[8], output[13]); // diag 3
        quarter_round!(output[3], output[4], output[9], output[14]); // diag 4
    }
    // Feed-forward: add the input state word by word, modulo 2^32.
    for (a, &b) in output.iter_mut().zip(input.iter()) {
        *a = a.wrapping_add(b);
    }
}
Nonce: + inp[13] = 0x09000000; // 00:00:00:09 + inp[14] = 0x4a000000; // 00:00:00:4a + inp[15] = 0x00000000; // 00:00:00:00 + chacha20(&inp, &mut out); + assert_eq!( + output_hex(&out), + concat!( + "e4e7f11015593bd11fdd0f50c47120a3c7f4d1c70368c0339aaa22044e6cd4c3", + "466482d209aa9f0705d7c214a2028bd9d19c12b5b94e16dee883d0cb4e3c50a2" + ) + ); + } +} diff --git a/src/ciphers/diffie_hellman.rs b/src/ciphers/diffie_hellman.rs new file mode 100644 index 00000000000..3cfe53802bb --- /dev/null +++ b/src/ciphers/diffie_hellman.rs @@ -0,0 +1,347 @@ +// Based on the TheAlgorithms/Python +// RFC 3526 - More Modular Exponential (MODP) Diffie-Hellman groups for +// Internet Key Exchange (IKE) https://tools.ietf.org/html/rfc3526 + +use num_bigint::BigUint; +use num_traits::{Num, Zero}; +use std::{ + collections::HashMap, + sync::LazyLock, + time::{SystemTime, UNIX_EPOCH}, +}; + +// A map of predefined prime numbers for different bit lengths, as specified in RFC 3526 +static PRIMES: LazyLock> = LazyLock::new(|| { + let mut m: HashMap = HashMap::new(); + m.insert( + // 1536-bit + 5, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m.insert( + // 2048-bit + 14, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 
670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AACAA68FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + + m.insert( + // 3072-bit + 15, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64\ + ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7\ + ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B\ + F12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31\ + 43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m.insert( + // 4096-bit + 16, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64\ + ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7\ + ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B\ + F12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31\ + 
43DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D7\ + 88719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA\ + 2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6\ + 287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED\ + 1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA9\ + 93B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199\ + FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m.insert( + // 6144-bit + 17, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E08\ + 8A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B\ + 302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9\ + A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE6\ + 49286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8\ + FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C\ + 180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF695581718\ + 3995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D\ + 04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7D\ + B3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D226\ + 1AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFC\ + E0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B26\ + 99C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB\ + 04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2\ + 233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127\ + D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934028492\ + 36C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406\ + AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918\ + DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B33205151\ + 2BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03\ + F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97F\ + BEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AA\ + CC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58B\ + B7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632\ + 
387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E\ + 6DCC4024FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + + m.insert( + // 8192-bit + 18, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64\ + ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7\ + ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B\ + F12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31\ + 43DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D7\ + 88719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA\ + 2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6\ + 287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED\ + 1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA9\ + 93B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934028492\ + 36C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BD\ + F8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831\ + 179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1B\ + DB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF\ + 5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6\ + D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F3\ + 23A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AA\ + CC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE328\ + 06A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55C\ + DA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE\ + 12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E4\ + 38777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300\ + 741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F568\ + 3423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD9\ + 
22222E04A4037C0713EB57A81A23F0C73473FC646CEA306B\ + 4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A\ + 062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A36\ + 4597E899A0255DC164F31CC50846851DF9AB48195DED7EA1\ + B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F92\ + 4009438B481C6CD7889A002ED5EE382BC9190DA6FC026E47\ + 9558E4475677E9AA9E3050E2765694DFC81F56E880B96E71\ + 60C980DD98EDD3DFFFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m +}); + +/// Generating random number, should use num_bigint::RandomBits if possible. +fn rand() -> usize { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos() as usize +} + +pub struct DiffieHellman { + prime: BigUint, + private_key: BigUint, + public_key: BigUint, + generator: u8, +} + +impl DiffieHellman { + // Diffie-Hellman key exchange algorithm is based on the following mathematical concepts: + + // - A large prime number p (known as the prime modulus) is chosen and shared by both parties. + + // - A base number g (known as the generator) is chosen and shared by both parties. + + // - Each party generates a private key a or b (which are secret and only known to that party) and calculates a corresponding public key A or B using the following formulas: + // - A = g^a mod p + // - B = g^b mod p + + // - Each party then exchanges their public keys with each other. + + // - Each party then calculates the shared secret key s using the following formula: + // - s = B^a mod p or s = A^b mod p + + // Both parties now have the same shared secret key s which can be used for encryption or authentication. 
+ + pub fn new(group: Option) -> Self { + let mut _group: u8 = 14; + if let Some(x) = group { + _group = x; + } + + if !PRIMES.contains_key(&_group) { + panic!("group not in primes") + } + + // generate private key + let private_key: BigUint = BigUint::from(rand()); + + Self { + prime: PRIMES[&_group].clone(), + private_key, + generator: 2, // the generator is 2 for all the primes if this would not be the case it can be added to hashmap + public_key: BigUint::default(), + } + } + + /// get private key as hexadecimal String + pub fn get_private_key(&self) -> String { + self.private_key.to_str_radix(16) + } + + /// Generate public key A = g**a mod p + pub fn generate_public_key(&mut self) -> String { + self.public_key = BigUint::from(self.generator).modpow(&self.private_key, &self.prime); + self.public_key.to_str_radix(16) + } + + pub fn is_valid_public_key(&self, key_str: &str) -> bool { + // the unwrap_or_else will make sure it is false, because 2 <= 0 and therefor False is returned + let key = BigUint::from_str_radix(key_str, 16) + .unwrap_or_else(|_| BigUint::parse_bytes(b"0", 16).unwrap()); + + // Check if the other public key is valid based on NIST SP800-56 + if BigUint::from(2_u8) <= key + && key <= &self.prime - BigUint::from(2_u8) + && !key + .modpow( + &((&self.prime - BigUint::from(1_u8)) / BigUint::from(2_u8)), + &self.prime, + ) + .is_zero() + { + return true; + } + false + } + + /// Generate the shared key + pub fn generate_shared_key(self, other_key_str: &str) -> Option { + let other_key = BigUint::from_str_radix(other_key_str, 16) + .unwrap_or_else(|_| BigUint::parse_bytes(b"0", 16).unwrap()); + if !self.is_valid_public_key(&other_key.to_str_radix(16)) { + return None; + } + + let shared_key = other_key.modpow(&self.private_key, &self.prime); + Some(shared_key.to_str_radix(16)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_invalid_pub_key() { + let diffie = DiffieHellman::new(Some(14)); + 
assert!(!diffie.is_valid_public_key("0000")); + } + + #[test] + fn verify_valid_pub_key() { + let diffie = DiffieHellman::new(Some(14)); + assert!(diffie.is_valid_public_key("EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245")); + } + + #[test] + fn verify_invalid_pub_key_same_as_prime() { + let diffie = DiffieHellman::new(Some(14)); + assert!(!diffie.is_valid_public_key( + "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AACAA68FFFFFFFFFFFFFFFF" + )); + } + + #[test] + fn verify_key_exchange() { + let mut alice = DiffieHellman::new(Some(16)); + let mut bob = DiffieHellman::new(Some(16)); + + // Private key not used, showed for illustrative purpose + let _alice_private = alice.get_private_key(); + let alice_public = alice.generate_public_key(); + + // Private key not used, showed for illustrative purpose + let _bob_private = bob.get_private_key(); + let bob_public = bob.generate_public_key(); + + // generating shared key using the struct implemenations + let alice_shared = alice.generate_shared_key(bob_public.as_str()).unwrap(); + let bob_shared = bob.generate_shared_key(alice_public.as_str()).unwrap(); + assert_eq!(alice_shared, bob_shared); + } +} diff --git a/src/ciphers/hashing_traits.rs b/src/ciphers/hashing_traits.rs new file mode 100644 index 00000000000..af8b8391f09 --- /dev/null +++ b/src/ciphers/hashing_traits.rs @@ -0,0 +1,89 @@ +pub trait Hasher { + /// return a new instance with default parameters + fn new_default() -> Self; + + /// Add new data + fn update(&mut self, data: &[u8]); 
+ + /// Returns the hash of current data. If it is necessary does finalization + /// work on the instance, thus it may no longer make sense to do `update` + /// after calling this. + fn get_hash(&mut self) -> [u8; DIGEST_BYTES]; +} + +/// HMAC based on RFC2104, applicable to many cryptographic hash functions +pub struct HMAC> { + pub inner_internal_state: H, + pub outer_internal_state: H, +} + +impl> + HMAC +{ + pub fn new_default() -> Self { + HMAC { + inner_internal_state: H::new_default(), + outer_internal_state: H::new_default(), + } + } + + /// Note that `key` must be no longer than `KEY_BYTES`. According to RFC, + /// if it is so, you should replace it with its hash. We do not do this + /// automatically due to fear of `DIGEST_BYTES` not being the same as + /// `KEY_BYTES` or even being longer than it + pub fn add_key(&mut self, key: &[u8]) -> Result<(), &'static str> { + match key.len().cmp(&KEY_BYTES) { + std::cmp::Ordering::Less | std::cmp::Ordering::Equal => { + let mut tmp_key = [0u8; KEY_BYTES]; + for (d, s) in tmp_key.iter_mut().zip(key.iter()) { + *d = *s; + } + // key ^ 0x363636.. should be used as inner key + for b in tmp_key.iter_mut() { + *b ^= 0x36; + } + self.inner_internal_state.update(&tmp_key); + // key ^ 0x5c5c5c.. should be used as outer key, but the key is + // already XORed with 0x363636.. , so it must now be XORed with + // 0x6a6a6a.. 
+ for b in tmp_key.iter_mut() { + *b ^= 0x6a; + } + self.outer_internal_state.update(&tmp_key); + Ok(()) + } + _ => Err("Key is longer than `KEY_BYTES`."), + } + } + + pub fn update(&mut self, data: &[u8]) { + self.inner_internal_state.update(data); + } + + pub fn finalize(&mut self) -> [u8; DIGEST_BYTES] { + self.outer_internal_state + .update(&self.inner_internal_state.get_hash()); + self.outer_internal_state.get_hash() + } +} + +#[cfg(test)] +mod tests { + use super::super::sha256::tests::get_hash_string; + use super::super::SHA256; + use super::HMAC; + + #[test] + fn sha256_basic() { + // To test this, use the following command on linux: + // echo -n "Hello World" | openssl sha256 -hex -mac HMAC -macopt hexkey:"deadbeef" + let mut hmac: HMAC<64, 32, SHA256> = HMAC::new_default(); + hmac.add_key(&[0xde, 0xad, 0xbe, 0xef]).unwrap(); + hmac.update(b"Hello World"); + let hash = hmac.finalize(); + assert_eq!( + get_hash_string(&hash), + "f585fc4536e8e7f378437465b65b6c2eb79036409b18a7d28b6d4c46d3a156f8" + ); + } +} diff --git a/src/ciphers/kerninghan.rs b/src/ciphers/kerninghan.rs new file mode 100644 index 00000000000..4263850ff3f --- /dev/null +++ b/src/ciphers/kerninghan.rs @@ -0,0 +1,23 @@ +pub fn kerninghan(n: u32) -> i32 { + let mut count = 0; + let mut n = n; + + while n > 0 { + n = n & (n - 1); + count += 1; + } + + count +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn count_set_bits() { + assert_eq!(kerninghan(0b0000_0000_0000_0000_0000_0000_0000_1011), 3); + assert_eq!(kerninghan(0b0000_0000_0000_0000_0000_0000_1000_0000), 1); + assert_eq!(kerninghan(0b1111_1111_1111_1111_1111_1111_1111_1101), 31); + } +} diff --git a/src/ciphers/mod.rs b/src/ciphers/mod.rs index bc628eca1f4..f7a55b0014d 100644 --- a/src/ciphers/mod.rs +++ b/src/ciphers/mod.rs @@ -1,7 +1,45 @@ +mod aes; +mod another_rot13; +mod baconian_cipher; +mod base64; +mod blake2b; mod caesar; +mod chacha; +mod diffie_hellman; +mod hashing_traits; +mod kerninghan; +mod morse_code; 
+mod polybius; +mod rail_fence; mod rot13; +mod salsa; +mod sha256; +mod sha3; +mod tea; +mod theoretical_rot13; +mod transposition; mod vigenere; - +mod xor; +pub use self::aes::{aes_decrypt, aes_encrypt, AesKey}; +pub use self::another_rot13::another_rot13; +pub use self::baconian_cipher::{baconian_decode, baconian_encode}; +pub use self::base64::{base64_decode, base64_encode}; +pub use self::blake2b::blake2b; pub use self::caesar::caesar; +pub use self::chacha::chacha20; +pub use self::diffie_hellman::DiffieHellman; +pub use self::hashing_traits::Hasher; +pub use self::hashing_traits::HMAC; +pub use self::kerninghan::kerninghan; +pub use self::morse_code::{decode, encode}; +pub use self::polybius::{decode_ascii, encode_ascii}; +pub use self::rail_fence::{rail_fence_decrypt, rail_fence_encrypt}; pub use self::rot13::rot13; +pub use self::salsa::salsa20; +pub use self::sha256::SHA256; +pub use self::sha3::{sha3_224, sha3_256, sha3_384, sha3_512}; +pub use self::tea::{tea_decrypt, tea_encrypt}; +pub use self::theoretical_rot13::theoretical_rot13; +pub use self::transposition::transposition; pub use self::vigenere::vigenere; +pub use self::xor::xor; diff --git a/src/ciphers/morse_code.rs b/src/ciphers/morse_code.rs new file mode 100644 index 00000000000..c1ecaa5b2ad --- /dev/null +++ b/src/ciphers/morse_code.rs @@ -0,0 +1,188 @@ +use std::collections::HashMap; +use std::io; + +const UNKNOWN_CHARACTER: &str = "........"; +const _UNKNOWN_MORSE_CHARACTER: &str = "_"; + +pub fn encode(message: &str) -> String { + let dictionary = _morse_dictionary(); + message + .chars() + .map(|char| char.to_uppercase().to_string()) + .map(|letter| dictionary.get(letter.as_str())) + .map(|option| (*option.unwrap_or(&UNKNOWN_CHARACTER)).to_string()) + .collect::>() + .join(" ") +} + +// Declarative macro for creating readable map declarations, for more info see https://doc.rust-lang.org/book/ch19-06-macros.html +macro_rules! map { + ($($key:expr => $value:expr),* $(,)?) 
=> { + std::iter::Iterator::collect(IntoIterator::into_iter([$(($key, $value),)*])) + }; +} + +fn _morse_dictionary() -> HashMap<&'static str, &'static str> { + map! { + "A" => ".-", "B" => "-...", "C" => "-.-.", + "D" => "-..", "E" => ".", "F" => "..-.", + "G" => "--.", "H" => "....", "I" => "..", + "J" => ".---", "K" => "-.-", "L" => ".-..", + "M" => "--", "N" => "-.", "O" => "---", + "P" => ".--.", "Q" => "--.-", "R" => ".-.", + "S" => "...", "T" => "-", "U" => "..-", + "V" => "...-", "W" => ".--", "X" => "-..-", + "Y" => "-.--", "Z" => "--..", + + "1" => ".----", "2" => "..---", "3" => "...--", + "4" => "....-", "5" => ".....", "6" => "-....", + "7" => "--...", "8" => "---..", "9" => "----.", + "0" => "-----", + + "&" => ".-...", "@" => ".--.-.", ":" => "---...", + "," => "--..--", "." => ".-.-.-", "'" => ".----.", + "\"" => ".-..-.", "?" => "..--..", "/" => "-..-.", + "=" => "-...-", "+" => ".-.-.", "-" => "-....-", + "(" => "-.--.", ")" => "-.--.-", " " => "/", + "!" => "-.-.--", + } +} + +fn _morse_to_alphanumeric_dictionary() -> HashMap<&'static str, &'static str> { + map! { + ".-" => "A", "-..." => "B", "-.-." => "C", + "-.." => "D", "." => "E", "..-." => "F", + "--." => "G", "...." => "H", ".." => "I", + ".---" => "J", "-.-" => "K", ".-.." => "L", + "--" => "M", "-." => "N", "---" => "O", + ".--." => "P", "--.-" => "Q", ".-." => "R", + "..." => "S", "-" => "T", "..-" => "U", + "...-" => "V", ".--" => "W", "-..-" => "X", + "-.--" => "Y", "--.." => "Z", + + ".----" => "1", "..---" => "2", "...--" => "3", + "....-" => "4", "....." => "5", "-...." => "6", + "--..." => "7", "---.." => "8", "----." => "9", + "-----" => "0", + + ".-..." => "&", ".--.-." => "@", "---..." => ":", + "--..--" => ",", ".-.-.-" => ".", ".----." => "'", + ".-..-." => "\"", "..--.." => "?", "-..-." => "/", + "-...-" => "=", ".-.-." => "+", "-....-" => "-", + "-.--." 
=> "(", "-.--.-" => ")", "/" => " ", + "-.-.--" => "!", " " => " ", "" => "" + } +} + +fn _check_part(string: &str) -> bool { + for c in string.chars() { + match c { + '.' | '-' | ' ' => (), + _ => return false, + } + } + true +} + +fn _check_all_parts(string: &str) -> bool { + string.split('/').all(_check_part) +} + +fn _decode_token(string: &str) -> String { + (*_morse_to_alphanumeric_dictionary() + .get(string) + .unwrap_or(&_UNKNOWN_MORSE_CHARACTER)) + .to_string() +} + +fn _decode_part(string: &str) -> String { + string.split(' ').map(_decode_token).collect::() +} + +/// Convert morse code to ascii. +/// +/// Given a morse code, return the corresponding message. +/// If the code is invalid, the undecipherable part of the code is replaced by `_`. +pub fn decode(string: &str) -> Result { + if !_check_all_parts(string) { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid morse code", + )); + } + + let mut partitions: Vec = vec![]; + + for part in string.split('/') { + partitions.push(_decode_part(part)); + } + + Ok(partitions.join(" ")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encrypt_only_letters() { + let message = "Hello Morse"; + let cipher = encode(message); + assert_eq!( + cipher, + ".... . .-.. .-.. --- / -- --- .-. ... .".to_string() + ) + } + + #[test] + fn encrypt_letters_and_special_characters() { + let message = "What's a great day!"; + let cipher = encode(message); + assert_eq!( + cipher, + ".-- .... .- - .----. ... / .- / --. .-. . .- - / -.. .- -.-- -.-.--".to_string() + ) + } + + #[test] + fn encrypt_message_with_unsupported_character() { + let message = "Error?? {}"; + let cipher = encode(message); + assert_eq!( + cipher, + ". .-. .-. --- .-. ..--.. ..--.. / ........ ........".to_string() + ) + } + + #[test] + fn decrypt_valid_morsecode_with_spaces() { + let expected = "Hello Morse! How's it goin, \"eh\"?" 
+ .to_string() + .to_uppercase(); + let encypted = encode(&expected); + let result = decode(&encypted).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn decrypt_valid_character_set_invalid_morsecode() { + let expected = format!( + "{_UNKNOWN_MORSE_CHARACTER}{_UNKNOWN_MORSE_CHARACTER}{_UNKNOWN_MORSE_CHARACTER}{_UNKNOWN_MORSE_CHARACTER} {_UNKNOWN_MORSE_CHARACTER}", + ); + + let encypted = ".-.-.--.-.-. --------. ..---.-.-. .-.-.--.-.-. / .-.-.--.-.-.".to_string(); + let result = decode(&encypted).unwrap(); + + assert_eq!(expected, result); + } + + #[test] + fn decrypt_invalid_morsecode_with_spaces() { + let encypted = "1... . .-.. .-.. --- / -- --- .-. ... ."; + let result = decode(encypted).map_err(|e| e.kind()); + let expected = Err(io::ErrorKind::InvalidData); + + assert_eq!(expected, result); + } +} diff --git a/src/ciphers/polybius.rs b/src/ciphers/polybius.rs new file mode 100644 index 00000000000..9afe81d106f --- /dev/null +++ b/src/ciphers/polybius.rs @@ -0,0 +1,149 @@ +/// Encode an ASCII string into its location in a Polybius square. +/// Only alphabetical characters are encoded. +pub fn encode_ascii(string: &str) -> String { + string + .chars() + .map(|c| match c { + 'a' | 'A' => "11", + 'b' | 'B' => "12", + 'c' | 'C' => "13", + 'd' | 'D' => "14", + 'e' | 'E' => "15", + 'f' | 'F' => "21", + 'g' | 'G' => "22", + 'h' | 'H' => "23", + 'i' | 'I' | 'j' | 'J' => "24", + 'k' | 'K' => "25", + 'l' | 'L' => "31", + 'm' | 'M' => "32", + 'n' | 'N' => "33", + 'o' | 'O' => "34", + 'p' | 'P' => "35", + 'q' | 'Q' => "41", + 'r' | 'R' => "42", + 's' | 'S' => "43", + 't' | 'T' => "44", + 'u' | 'U' => "45", + 'v' | 'V' => "51", + 'w' | 'W' => "52", + 'x' | 'X' => "53", + 'y' | 'Y' => "54", + 'z' | 'Z' => "55", + _ => "", + }) + .collect() +} + +/// Decode a string of ints into their corresponding +/// letters in a Polybius square. +/// +/// Any invalid characters, or whitespace will be ignored. 
+pub fn decode_ascii(string: &str) -> String { + string + .chars() + .filter(|c| !c.is_whitespace()) + .collect::() + .as_bytes() + .chunks(2) + .map(|s| match std::str::from_utf8(s) { + Ok(v) => v.parse::().unwrap_or(0), + Err(_) => 0, + }) + .map(|i| match i { + 11 => 'A', + 12 => 'B', + 13 => 'C', + 14 => 'D', + 15 => 'E', + 21 => 'F', + 22 => 'G', + 23 => 'H', + 24 => 'I', + 25 => 'K', + 31 => 'L', + 32 => 'M', + 33 => 'N', + 34 => 'O', + 35 => 'P', + 41 => 'Q', + 42 => 'R', + 43 => 'S', + 44 => 'T', + 45 => 'U', + 51 => 'V', + 52 => 'W', + 53 => 'X', + 54 => 'Y', + 55 => 'Z', + _ => ' ', + }) + .collect::() + .replace(' ', "") +} + +#[cfg(test)] +mod tests { + use super::{decode_ascii, encode_ascii}; + + #[test] + fn encode_empty() { + assert_eq!(encode_ascii(""), ""); + } + + #[test] + fn encode_valid_string() { + assert_eq!(encode_ascii("This is a test"), "4423244324431144154344"); + } + + #[test] + fn encode_emoji() { + assert_eq!(encode_ascii("🙂"), ""); + } + + #[test] + fn decode_empty() { + assert_eq!(decode_ascii(""), ""); + } + + #[test] + fn decode_valid_string() { + assert_eq!( + decode_ascii("44 23 24 43 24 43 11 44 15 43 44 "), + "THISISATEST" + ); + } + + #[test] + fn decode_emoji() { + assert_eq!(decode_ascii("🙂"), ""); + } + + #[test] + fn decode_string_with_whitespace() { + assert_eq!( + decode_ascii("44\n23\t\r24\r\n43 2443\n 11 \t 44\r \r15 \n43 44"), + "THISISATEST" + ); + } + + #[test] + fn decode_unknown_string() { + assert_eq!(decode_ascii("94 63 64 83 64 48 77 00 05 47 48 "), ""); + } + + #[test] + fn decode_odd_length() { + assert_eq!(decode_ascii("11 22 33 4"), "AGN"); + } + + #[test] + fn encode_and_decode() { + let string = "Do you ever wonder why we're here?"; + let encode = encode_ascii(string); + assert_eq!( + "1434543445155115425234331415425223545215421523154215", + encode, + ); + assert_eq!("DOYOUEVERWONDERWHYWEREHERE", decode_ascii(&encode)); + } +} diff --git a/src/ciphers/rail_fence.rs b/src/ciphers/rail_fence.rs new file 
mode 100644 index 00000000000..aedff07ea31 --- /dev/null +++ b/src/ciphers/rail_fence.rs @@ -0,0 +1,41 @@ +// wiki: https://en.wikipedia.org/wiki/Rail_fence_cipher +pub fn rail_fence_encrypt(plain_text: &str, key: usize) -> String { + let mut cipher = vec![Vec::new(); key]; + + for (c, i) in plain_text.chars().zip(zigzag(key)) { + cipher[i].push(c); + } + + return cipher.iter().flatten().collect::(); +} + +pub fn rail_fence_decrypt(cipher: &str, key: usize) -> String { + let mut indices: Vec<_> = zigzag(key).zip(1..).take(cipher.len()).collect(); + indices.sort(); + + let mut cipher_text: Vec<_> = cipher + .chars() + .zip(indices) + .map(|(c, (_, i))| (i, c)) + .collect(); + + cipher_text.sort(); + return cipher_text.iter().map(|(_, c)| c).collect(); +} + +fn zigzag(n: usize) -> impl Iterator { + (0..n - 1).chain((1..n).rev()).cycle() +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn rails_basic() { + assert_eq!(rail_fence_encrypt("attack at once", 2), "atc toctaka ne"); + assert_eq!(rail_fence_decrypt("atc toctaka ne", 2), "attack at once"); + + assert_eq!(rail_fence_encrypt("rust is cool", 3), "r cuti olsso"); + assert_eq!(rail_fence_decrypt("r cuti olsso", 3), "rust is cool"); + } +} diff --git a/src/ciphers/salsa.rs b/src/ciphers/salsa.rs new file mode 100644 index 00000000000..83b37556ff1 --- /dev/null +++ b/src/ciphers/salsa.rs @@ -0,0 +1,128 @@ +macro_rules! quarter_round { + ($v1:expr,$v2:expr,$v3:expr,$v4:expr) => { + $v2 ^= ($v1.wrapping_add($v4).rotate_left(7)); + $v3 ^= ($v2.wrapping_add($v1).rotate_left(9)); + $v4 ^= ($v3.wrapping_add($v2).rotate_left(13)); + $v1 ^= ($v4.wrapping_add($v3).rotate_left(18)); + }; +} + +/// This is a `Salsa20` implementation based on \ +/// `Salsa20` is a stream cipher developed by Daniel J. Bernstein.\ +/// To use it, the `salsa20` function should be called with appropriate parameters and the +/// output of the function should be XORed with plain text. 
+/// +/// `salsa20` function takes as input an array of 16 32-bit integers (512 bits) +/// of which 128 bits is the constant 'expand 32-byte k', 256 bits is the key, +/// and 128 bits are nonce and counter. It is up to the user to determine how +/// many bits each of nonce and counter take, but a default of 64 bits each +/// seems to be a sane choice. +/// +/// The 16 input numbers can be thought of as the elements of a 4x4 matrix like +/// the one below, on which we do the main operations of the cipher. +/// +/// ```text +/// +----+----+----+----+ +/// | 00 | 01 | 02 | 03 | +/// +----+----+----+----+ +/// | 04 | 05 | 06 | 07 | +/// +----+----+----+----+ +/// | 08 | 09 | 10 | 11 | +/// +----+----+----+----+ +/// | 12 | 13 | 14 | 15 | +/// +----+----+----+----+ +/// ``` +/// +/// As per the diagram below, `input[0, 5, 10, 15]` are the constants mentioned +/// above, `input[1, 2, 3, 4, 11, 12, 13, 14]` is filled with the key, and +/// `input[6, 7, 8, 9]` should be filled with nonce and counter values. The output +/// of the function is stored in `output` variable and can be XORed with the +/// plain text to produce the cipher text.
+/// +/// ```text +/// +------+------+------+------+ +/// | | | | | +/// | C[0] | key1 | key2 | key3 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | key4 | C[1] | no1 | no2 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | ctr1 | ctr2 | C[2] | key5 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | key6 | key7 | key8 | C[3] | +/// | | | | | +/// +------+------+------+------+ +/// ``` +pub fn salsa20(input: &[u32; 16], output: &mut [u32; 16]) { + output.copy_from_slice(&input[..]); + for _ in 0..10 { + // Odd round + quarter_round!(output[0], output[4], output[8], output[12]); // column 1 + quarter_round!(output[5], output[9], output[13], output[1]); // column 2 + quarter_round!(output[10], output[14], output[2], output[6]); // column 3 + quarter_round!(output[15], output[3], output[7], output[11]); // column 4 + + // Even round + quarter_round!(output[0], output[1], output[2], output[3]); // row 1 + quarter_round!(output[5], output[6], output[7], output[4]); // row 2 + quarter_round!(output[10], output[11], output[8], output[9]); // row 3 + quarter_round!(output[15], output[12], output[13], output[14]); // row 4 + } + for (a, &b) in output.iter_mut().zip(input.iter()) { + *a = a.wrapping_add(b); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fmt::Write; + + const C: [u32; 4] = [0x65787061, 0x6e642033, 0x322d6279, 0x7465206b]; + + fn output_hex(inp: &[u32; 16]) -> String { + let mut res = String::new(); + res.reserve(512 / 4); + for &x in inp { + write!(&mut res, "{x:08x}").unwrap(); + } + res + } + #[test] + // test vector 1 + fn basic_tv1() { + let mut inp = [0u32; 16]; + let mut out = [0u32; 16]; + inp[0] = C[0]; + inp[1] = 0x01020304; // 1, 2, 3, 4 + inp[2] = 0x05060708; // 5, 6, 7, 8, ... + inp[3] = 0x090a0b0c; + inp[4] = 0x0d0e0f10; + inp[5] = C[1]; + inp[6] = 0x65666768; // 101, 102, 103, 104 + inp[7] = 0x696a6b6c; // 105, 106, 107, 108, ... 
+ inp[8] = 0x6d6e6f70; + inp[9] = 0x71727374; + inp[10] = C[2]; + inp[11] = 0xc9cacbcc; // 201, 202, 203, 204 + inp[12] = 0xcdcecfd0; // 205, 206, 207, 208, ... + inp[13] = 0xd1d2d3d4; + inp[14] = 0xd5d6d7d8; + inp[15] = C[3]; + salsa20(&inp, &mut out); + // Checked with wikipedia implementation, does not agree with + // "https://cr.yp.to/snuffle/spec.pdf" + assert_eq!( + output_hex(&out), + concat!( + "de1d6f8d91dbf69d0db4b70c8b4320d236694432896d98b05aa7b76d5738ca13", + "04e5a170c8e479af1542ed2f30f26ba57da20203cfe955c66f4cc7a06dd34359" + ) + ); + } +} diff --git a/src/ciphers/sha256.rs b/src/ciphers/sha256.rs new file mode 100644 index 00000000000..af6ce814434 --- /dev/null +++ b/src/ciphers/sha256.rs @@ -0,0 +1,340 @@ +/*! + * SHA-2 256 bit implementation + * This implementation is based on RFC6234 + * Keep in mind that the amount of data (in bits) processed should always be an + * integer multiple of 8 + */ + +// The constants are tested to make sure they are correct +pub const H0: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub const K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; + +// The following functions are implemented according to page 10 
of RFC6234 +#[inline] +fn ch(x: u32, y: u32, z: u32) -> u32 { + (x & y) ^ ((!x) & z) +} + +#[inline] +fn maj(x: u32, y: u32, z: u32) -> u32 { + (x & y) ^ (x & z) ^ (y & z) +} + +#[inline] +fn bsig0(x: u32) -> u32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) +} + +#[inline] +fn bsig1(x: u32) -> u32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) +} + +#[inline] +fn ssig0(x: u32) -> u32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) +} + +#[inline] +fn ssig1(x: u32) -> u32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) +} + +pub struct SHA256 { + /// The current block to be processed, 512 bits long + buffer: [u32; 16], + /// Length (bits) of the message, should always be a multiple of 8 + length: u64, + /// The current hash value. Note: this value is invalid unless `finalize` + /// is called + pub h: [u32; 8], + /// Message schedule + w: [u32; 64], + pub finalized: bool, + // Temporary values: + round: [u32; 8], +} + +fn process_block(h: &mut [u32; 8], w: &mut [u32; 64], round: &mut [u32; 8], buf: &[u32; 16]) { + // Prepare the message schedule: + w[..buf.len()].copy_from_slice(&buf[..]); + for i in buf.len()..w.len() { + w[i] = ssig1(w[i - 2]) + .wrapping_add(w[i - 7]) + .wrapping_add(ssig0(w[i - 15])) + .wrapping_add(w[i - 16]); + } + round.copy_from_slice(h); + for i in 0..w.len() { + let t1 = round[7] + .wrapping_add(bsig1(round[4])) + .wrapping_add(ch(round[4], round[5], round[6])) + .wrapping_add(K[i]) + .wrapping_add(w[i]); + let t2 = bsig0(round[0]).wrapping_add(maj(round[0], round[1], round[2])); + round[7] = round[6]; + round[6] = round[5]; + round[5] = round[4]; + round[4] = round[3].wrapping_add(t1); + round[3] = round[2]; + round[2] = round[1]; + round[1] = round[0]; + round[0] = t1.wrapping_add(t2); + } + for i in 0..h.len() { + h[i] = h[i].wrapping_add(round[i]); + } +} + +impl SHA256 { + pub fn new_default() -> Self { + SHA256 { + buffer: [0u32; 16], + length: 0, + h: H0, + w: [0u32; 64], + round: 
/// Feeds `data` into the hash.
///
/// Bytes are packed big-endian into the 32-bit words of `self.buffer`; each
/// time all sixteen words are filled the buffer is compressed into `self.h`
/// with `process_block`. `self.length` (in bits) grows by `8 * data.len()`.
/// Calling this with an empty slice is a no-op.
pub fn update(&mut self, data: &[u8]) {
    if data.is_empty() {
        return;
    }
    // Number of bytes (0..=3) still missing from the word that the previous
    // call left partially filled; 0 when `length` is word-aligned.
    let offset = (((32 - (self.length & 31)) & 31) >> 3) as usize;
    // Index of the buffer word the next byte belongs to.
    let mut buf_ind = ((self.length & 511) >> 5) as usize;
    // Top up the partial word: the i-th incoming byte lands in the
    // (offset - i - 1)-th byte lane counted from the least significant end.
    for (i, &byte) in data.iter().enumerate().take(offset) {
        self.buffer[buf_ind] ^= (byte as u32) << ((offset - i - 1) << 3);
    }
    self.length += (data.len() as u64) << 3;
    // Not enough input to complete the partial word: nothing more to do.
    if offset > data.len() {
        return;
    }
    // The partial word has been completed, move on to the next one.
    if offset > 0 {
        buf_ind += 1;
    }
    if data.len() > 3 {
        // Consume the input four bytes (one whole word) at a time.
        for i in (offset..(data.len() - 3)).step_by(4) {
            // All sixteen words are filled: compress and start a new block.
            if buf_ind & 16 == 16 {
                process_block(&mut self.h, &mut self.w, &mut self.round, &self.buffer);
                buf_ind = 0;
            }
            self.buffer[buf_ind] = ((data[i] as u32) << 24)
                ^ ((data[i + 1] as u32) << 16)
                ^ ((data[i + 2] as u32) << 8)
                ^ data[i + 3] as u32;
            buf_ind += 1;
        }
    }
    // The last whole word may have completed a block.
    if buf_ind & 16 == 16 {
        process_block(&mut self.h, &mut self.w, &mut self.round, &self.buffer);
        buf_ind = 0;
    }
    // Clear the next word so trailing bytes (and later top-ups) can be XOR-ed in.
    self.buffer[buf_ind] = 0;
    // First byte that did not fit into a whole word.
    let rem_ind = offset + ((data.len() - offset) & !0b11);
    // Store the 0..=3 leftover bytes starting at the most significant lane.
    for (i, &byte) in data[rem_ind..].iter().enumerate() {
        self.buffer[buf_ind] ^= (byte as u32) << ((3 - i) << 3);
    }
}
/// Renders a 32-byte digest as a 64-character lowercase hexadecimal string.
pub fn get_hash_string(hash: &[u8; 32]) -> String {
    // Two hex digits per byte, so the final size is known up front.
    hash.iter()
        .fold(String::with_capacity(2 * hash.len()), |mut hex, ch| {
            write!(hex, "{ch:02x}").unwrap();
            hex
        })
}
exp = bits >> float_len; // The sign bit is already 0 + //(exp - 1023) can be bigger than 0, we must include more bits. + let k_ref = ((bits & ((1_u64 << float_len) - 1)) + >> (float_len - constant_len + 1023 - exp)) as u32; + assert_eq!(k, k_ref); + } + + for (pos, &h) in H0.iter().enumerate() { + let a: f64 = ls.primes[pos] as f64; + let bits = a.sqrt().to_bits(); + let exp = bits >> float_len; + let h_ref = ((bits & ((1_u64 << float_len) - 1)) + >> (float_len - constant_len + 1023 - exp)) as u32; + assert_eq!(h, h_ref); + } + } + + // To test the hashes, you can use the following command on linux: + // echo -n 'STRING' | sha256sum + // the `-n` is because by default, echo adds a `\n` to its output + + #[test] + fn empty() { + let mut res = SHA256::new_default(); + assert_eq!( + res.get_hash(), + [ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, + 0xb9, 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, + 0x78, 0x52, 0xb8, 0x55 + ] + ); + } + + #[test] + fn ascii() { + let mut res = SHA256::new_default(); + res.update(b"The quick brown fox jumps over the lazy dog"); + assert_eq!( + res.get_hash(), + [ + 0xD7, 0xA8, 0xFB, 0xB3, 0x07, 0xD7, 0x80, 0x94, 0x69, 0xCA, 0x9A, 0xBC, 0xB0, 0x08, + 0x2E, 0x4F, 0x8D, 0x56, 0x51, 0xE4, 0x6D, 0x3C, 0xDB, 0x76, 0x2D, 0x02, 0xD0, 0xBF, + 0x37, 0xC9, 0xE5, 0x92 + ] + ) + } + + #[test] + fn ascii_avalanche() { + let mut res = SHA256::new_default(); + res.update(b"The quick brown fox jumps over the lazy dog."); + assert_eq!( + res.get_hash(), + [ + 0xEF, 0x53, 0x7F, 0x25, 0xC8, 0x95, 0xBF, 0xA7, 0x82, 0x52, 0x65, 0x29, 0xA9, 0xB6, + 0x3D, 0x97, 0xAA, 0x63, 0x15, 0x64, 0xD5, 0xD7, 0x89, 0xC2, 0xB7, 0x65, 0x44, 0x8C, + 0x86, 0x35, 0xFB, 0x6C + ] + ); + // Test if finalization is not repeated twice + assert_eq!( + res.get_hash(), + [ + 0xEF, 0x53, 0x7F, 0x25, 0xC8, 0x95, 0xBF, 0xA7, 0x82, 0x52, 0x65, 0x29, 0xA9, 0xB6, + 0x3D, 0x97, 0xAA, 0x63, 0x15, 0x64, 0xD5, 0xD7, 
0x89, 0xC2, 0xB7, 0x65, 0x44, 0x8C, + 0x86, 0x35, 0xFB, 0x6C + ] + ) + } + #[test] + fn long_ascii() { + let mut res = SHA256::new_default(); + let val = b"The quick brown fox jumps over the lazy dog."; + for _ in 0..1000 { + res.update(val); + } + let hash = res.get_hash(); + assert_eq!( + &get_hash_string(&hash), + "c264fca077807d391df72fadf39dd63be21f1823f65ca530c9637760eabfc18c" + ); + let mut res = SHA256::new_default(); + let val = b"a"; + for _ in 0..999 { + res.update(val); + } + let hash = res.get_hash(); + assert_eq!( + &get_hash_string(&hash), + "d9fe27f3d807a7c46467325f7189495e82b099ce2e14c5b16cc76697fa909f81" + ) + } + #[test] + fn short_ascii() { + let mut res = SHA256::new_default(); + let val = b"a"; + res.update(val); + let hash = res.get_hash(); + assert_eq!( + &get_hash_string(&hash), + "ca978112ca1bbdcafac231b39a23dc4da786eff8147c4e72b9807785afee48bb" + ); + } +} diff --git a/src/ciphers/sha3.rs b/src/ciphers/sha3.rs new file mode 100644 index 00000000000..f3791214f3f --- /dev/null +++ b/src/ciphers/sha3.rs @@ -0,0 +1,590 @@ +/// Size of the state array in bits +const B: usize = 1600; + +const W: usize = B / 25; +const L: usize = W.ilog2() as usize; + +const U8BITS: usize = u8::BITS as usize; + +// Macro for looping through the whole state array +macro_rules! 
iterate { + ( $x:ident, $y:ident, $z:ident => $b:block ) => { + for $y in 0..5 { + for $x in 0..5 { + for $z in 0..W { + $b + } + } + } + }; +} + +/// A function that produces a padding string such that the length of the padding + the length of +/// the string to be padded (2nd parameter) is divisible by the 1st parameter +type PadFn = fn(isize, isize) -> Vec; +type SpongeFn = fn(&[bool]) -> [bool; B]; + +type State = [[[bool; W]; 5]; 5]; + +fn state_new() -> State { + [[[false; W]; 5]; 5] +} + +fn state_fill(dest: &mut State, bits: &[bool]) { + let mut i = 0usize; + + iterate!(x, y, z => { + if i >= bits.len() { return; } + dest[x][y][z] = bits[i]; + i += 1; + }); +} + +fn state_copy(dest: &mut State, src: &State) { + iterate!(x, y, z => { + dest[x][y][z] = src[x][y][z]; + }); +} + +fn state_dump(state: &State) -> [bool; B] { + let mut bits = [false; B]; + + let mut i = 0usize; + + iterate!(x, y, z => { + bits[i] = state[x][y][z]; + i += 1; + }); + + bits +} + +/// XORs the state with the parities of two columns in the state array +fn theta(state: &mut State) { + let mut c = [[false; W]; 5]; + let mut d = [[false; W]; 5]; + + // Assign values of C[x,z] + for x in 0..5 { + for z in 0..W { + c[x][z] = state[x][0][z]; + + for y in 1..5 { + c[x][z] ^= state[x][y][z]; + } + } + } + + // Assign values of D[x,z] + for x in 0..5 { + for z in 0..W { + let x1 = (x as isize - 1).rem_euclid(5) as usize; + let z2 = (z as isize - 1).rem_euclid(W as isize) as usize; + + d[x][z] = c[x1][z] ^ c[(x + 1) % 5][z2]; + } + } + + // Xor values of D[x,z] into our state array + iterate!(x, y, z => { + state[x][y][z] ^= d[x][z]; + }); +} + +/// Rotates each lane by an offset depending of the x and y indeces +fn rho(state: &mut State) { + let mut new_state = state_new(); + + for z in 0..W { + new_state[0][0][z] = state[0][0][z]; + } + + let mut x = 1; + let mut y = 0; + + for t in 0..=23isize { + for z in 0..W { + let z_offset: isize = ((t + 1) * (t + 2)) / 2; + let new_z = (z as isize - 
/// Computes one bit of the round constant from an 8-bit linear feedback
/// shift register that is stepped `t % 255` times.
///
/// The register starts as `0x80`. On every step the bit that is about to be
/// shifted out (bit 0) is fed back into bits 8, 4, 3 and 2, after which the
/// register moves one position to the right. The result is bit 7 of the final
/// register value.
fn rc(t: u8) -> bool {
    // Feedback taps: bits 8, 4, 3 and 2.
    const TAPS: u16 = 0x11C;

    let register = (0..(t % 255)).fold(0x80u16, |r, _| {
        // The register fits in 8 bits here, so bit 8 is always clear and
        // XOR-ing the taps is the same as setting/toggling them one by one.
        let fed_back = if r & 1 == 1 { r ^ TAPS } else { r };
        fed_back >> 1
    });

    (register >> 7) != 0
}
+ j %= x; + + let mut ret = vec![false; (j as usize) + 2]; + *ret.first_mut().unwrap() = true; + *ret.last_mut().unwrap() = true; + + ret +} + +/// Sponge construction is a method of compression needing 1) a function on fixed-length bit +/// strings( here we use keccak_f), 2) a padding function (pad10*1), and 3) a rate. The input and +/// output of this method can be arbitrarily long +fn sponge(f: SpongeFn, pad: PadFn, r: usize, n: &[bool], d: usize) -> Vec { + let mut p = Vec::from(n); + p.append(&mut pad(r as isize, n.len() as isize)); + + assert!(r < B); + + let mut s = [false; B]; + for chunk in p.chunks(r) { + for (s_i, c_i) in s.iter_mut().zip(chunk) { + *s_i ^= c_i; + } + + s = f(&s); + } + + let mut z = Vec::::new(); + while z.len() < d { + z.extend(&s); + + s = f(&s); + } + + z.truncate(d); + z +} + +fn keccak(c: usize, n: &[bool], d: usize) -> Vec { + sponge(keccak_f, pad101, B - c, n, d) +} + +fn h2b(h: &[u8], n: usize) -> Vec { + let mut bits = Vec::with_capacity(h.len() * U8BITS); + + for byte in h { + for i in 0..u8::BITS { + let mask: u8 = 1 << i; + + bits.push((byte & mask) != 0); + } + } + + assert!(bits.len() == h.len() * U8BITS); + + bits.truncate(n); + bits +} + +fn b2h(s: &[bool]) -> Vec { + let m = if s.len() % U8BITS != 0 { + (s.len() / 8) + 1 + } else { + s.len() / 8 + }; + let mut bytes = vec![0u8; m]; + + for (i, bit) in s.iter().enumerate() { + let byte_index = i / U8BITS; + let mask = (*bit as u8) << (i % U8BITS); + + bytes[byte_index] |= mask; + } + + bytes +} +/// Macro to implement all sha3 hash functions as they only differ in digest size +macro_rules! 
sha3 { + ($name:ident, $n:literal) => { + pub fn $name(m: &[u8]) -> [u8; ($n / U8BITS)] { + let mut temp = h2b(m, m.len() * U8BITS); + temp.append(&mut vec![false, true]); + + temp = keccak($n * 2, &temp, $n); + + let mut ret = [0u8; ($n / U8BITS)]; + + let temp = b2h(&temp); + assert!(temp.len() == $n / U8BITS); + + for (i, byte) in temp.iter().enumerate() { + ret[i] = *byte; + } + + ret + } + }; +} + +sha3!(sha3_224, 224); +sha3!(sha3_256, 256); +sha3!(sha3_384, 384); +sha3!(sha3_512, 512); + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! digest_test { + ($fname:ident, $hash:ident, $size:literal, $message:expr, $expected:expr) => { + #[test] + fn $fname() { + let digest = $hash(&$message); + + let expected: [u8; $size / U8BITS] = $expected; + + assert_eq!(digest, expected); + } + }; + } + + digest_test!( + sha3_224_0, + sha3_224, + 224, + [0; 0], + [ + 0x6b, 0x4e, 0x03, 0x42, 0x36, 0x67, 0xdb, 0xb7, 0x3b, 0x6e, 0x15, 0x45, 0x4f, 0x0e, + 0xb1, 0xab, 0xd4, 0x59, 0x7f, 0x9a, 0x1b, 0x07, 0x8e, 0x3f, 0x5b, 0x5a, 0x6b, 0xc7, + ] + ); + + digest_test!( + sha3_224_8, + sha3_224, + 224, + [1u8], + [ + 0x48, 0x82, 0x86, 0xd9, 0xd3, 0x27, 0x16, 0xe5, 0x88, 0x1e, 0xa1, 0xee, 0x51, 0xf3, + 0x6d, 0x36, 0x60, 0xd7, 0x0f, 0x0d, 0xb0, 0x3b, 0x3f, 0x61, 0x2c, 0xe9, 0xed, 0xa4, + ] + ); + + // Done on large input to verify sponge function is working properly + digest_test!( + sha3_224_2312, + sha3_224, + 224, + [ + 0x31, 0xc8, 0x2d, 0x71, 0x78, 0x5b, 0x7c, 0xa6, 0xb6, 0x51, 0xcb, 0x6c, 0x8c, 0x9a, + 0xd5, 0xe2, 0xac, 0xeb, 0x0b, 0x06, 0x33, 0xc0, 0x88, 0xd3, 0x3a, 0xa2, 0x47, 0xad, + 0xa7, 0xa5, 0x94, 0xff, 0x49, 0x36, 0xc0, 0x23, 0x25, 0x13, 0x19, 0x82, 0x0a, 0x9b, + 0x19, 0xfc, 0x6c, 0x48, 0xde, 0x8a, 0x6f, 0x7a, 0xda, 0x21, 0x41, 0x76, 0xcc, 0xda, + 0xad, 0xae, 0xef, 0x51, 0xed, 0x43, 0x71, 0x4a, 0xc0, 0xc8, 0x26, 0x9b, 0xbd, 0x49, + 0x7e, 0x46, 0xe7, 0x8b, 0xb5, 0xe5, 0x81, 0x96, 0x49, 0x4b, 0x24, 0x71, 0xb1, 0x68, + 0x0e, 0x2d, 0x4c, 0x6d, 0xbd, 0x24, 0x98, 0x31, 
0xbd, 0x83, 0xa4, 0xd3, 0xbe, 0x06, + 0xc8, 0xa2, 0xe9, 0x03, 0x93, 0x39, 0x74, 0xaa, 0x05, 0xee, 0x74, 0x8b, 0xfe, 0x6e, + 0xf3, 0x59, 0xf7, 0xa1, 0x43, 0xed, 0xf0, 0xd4, 0x91, 0x8d, 0xa9, 0x16, 0xbd, 0x6f, + 0x15, 0xe2, 0x6a, 0x79, 0x0c, 0xff, 0x51, 0x4b, 0x40, 0xa5, 0xda, 0x7f, 0x72, 0xe1, + 0xed, 0x2f, 0xe6, 0x3a, 0x05, 0xb8, 0x14, 0x95, 0x87, 0xbe, 0xa0, 0x56, 0x53, 0x71, + 0x8c, 0xc8, 0x98, 0x0e, 0xad, 0xbf, 0xec, 0xa8, 0x5b, 0x7c, 0x9c, 0x28, 0x6d, 0xd0, + 0x40, 0x93, 0x65, 0x85, 0x93, 0x8b, 0xe7, 0xf9, 0x82, 0x19, 0x70, 0x0c, 0x83, 0xa9, + 0x44, 0x3c, 0x28, 0x56, 0xa8, 0x0f, 0xf4, 0x68, 0x52, 0xb2, 0x6d, 0x1b, 0x1e, 0xdf, + 0x72, 0xa3, 0x02, 0x03, 0xcf, 0x6c, 0x44, 0xa1, 0x0f, 0xa6, 0xea, 0xf1, 0x92, 0x01, + 0x73, 0xce, 0xdf, 0xb5, 0xc4, 0xcf, 0x3a, 0xc6, 0x65, 0xb3, 0x7a, 0x86, 0xed, 0x02, + 0x15, 0x5b, 0xbb, 0xf1, 0x7d, 0xc2, 0xe7, 0x86, 0xaf, 0x94, 0x78, 0xfe, 0x08, 0x89, + 0xd8, 0x6c, 0x5b, 0xfa, 0x85, 0xa2, 0x42, 0xeb, 0x08, 0x54, 0xb1, 0x48, 0x2b, 0x7b, + 0xd1, 0x6f, 0x67, 0xf8, 0x0b, 0xef, 0x9c, 0x7a, 0x62, 0x8f, 0x05, 0xa1, 0x07, 0x93, + 0x6a, 0x64, 0x27, 0x3a, 0x97, 0xb0, 0x08, 0x8b, 0x0e, 0x51, 0x54, 0x51, 0xf9, 0x16, + 0xb5, 0x65, 0x62, 0x30, 0xa1, 0x2b, 0xa6, 0xdc, 0x78 + ], + [ + 0xaa, 0xb2, 0x3c, 0x9e, 0x7f, 0xb9, 0xd7, 0xda, 0xce, 0xfd, 0xfd, 0x0b, 0x1a, 0xe8, + 0x5a, 0xb1, 0x37, 0x4a, 0xbf, 0xf7, 0xc4, 0xe3, 0xf7, 0x55, 0x6e, 0xca, 0xe4, 0x12 + ] + ); + + digest_test!( + sha3_256_0, + sha3_256, + 256, + [0; 0], + [ + 0xa7, 0xff, 0xc6, 0xf8, 0xbf, 0x1e, 0xd7, 0x66, 0x51, 0xc1, 0x47, 0x56, 0xa0, 0x61, + 0xd6, 0x62, 0xf5, 0x80, 0xff, 0x4d, 0xe4, 0x3b, 0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b, + 0x80, 0xf8, 0x43, 0x4a, + ] + ); + + digest_test!( + sha3_256_8, + sha3_256, + 256, + [0xe9u8], + [ + 0xf0, 0xd0, 0x4d, 0xd1, 0xe6, 0xcf, 0xc2, 0x9a, 0x44, 0x60, 0xd5, 0x21, 0x79, 0x68, + 0x52, 0xf2, 0x5d, 0x9e, 0xf8, 0xd2, 0x8b, 0x44, 0xee, 0x91, 0xff, 0x5b, 0x75, 0x9d, + 0x72, 0xc1, 0xe6, 0xd6, + ] + ); + + digest_test!( + sha3_256_2184, + sha3_256, + 256, + 
[ + 0xb1, 0xca, 0xa3, 0x96, 0x77, 0x1a, 0x09, 0xa1, 0xdb, 0x9b, 0xc2, 0x05, 0x43, 0xe9, + 0x88, 0xe3, 0x59, 0xd4, 0x7c, 0x2a, 0x61, 0x64, 0x17, 0xbb, 0xca, 0x1b, 0x62, 0xcb, + 0x02, 0x79, 0x6a, 0x88, 0x8f, 0xc6, 0xee, 0xff, 0x5c, 0x0b, 0x5c, 0x3d, 0x50, 0x62, + 0xfc, 0xb4, 0x25, 0x6f, 0x6a, 0xe1, 0x78, 0x2f, 0x49, 0x2c, 0x1c, 0xf0, 0x36, 0x10, + 0xb4, 0xa1, 0xfb, 0x7b, 0x81, 0x4c, 0x05, 0x78, 0x78, 0xe1, 0x19, 0x0b, 0x98, 0x35, + 0x42, 0x5c, 0x7a, 0x4a, 0x0e, 0x18, 0x2a, 0xd1, 0xf9, 0x15, 0x35, 0xed, 0x2a, 0x35, + 0x03, 0x3a, 0x5d, 0x8c, 0x67, 0x0e, 0x21, 0xc5, 0x75, 0xff, 0x43, 0xc1, 0x94, 0xa5, + 0x8a, 0x82, 0xd4, 0xa1, 0xa4, 0x48, 0x81, 0xdd, 0x61, 0xf9, 0xf8, 0x16, 0x1f, 0xc6, + 0xb9, 0x98, 0x86, 0x0c, 0xbe, 0x49, 0x75, 0x78, 0x0b, 0xe9, 0x3b, 0x6f, 0x87, 0x98, + 0x0b, 0xad, 0x0a, 0x99, 0xaa, 0x2c, 0xb7, 0x55, 0x6b, 0x47, 0x8c, 0xa3, 0x5d, 0x1f, + 0x37, 0x46, 0xc3, 0x3e, 0x2b, 0xb7, 0xc4, 0x7a, 0xf4, 0x26, 0x64, 0x1c, 0xc7, 0xbb, + 0xb3, 0x42, 0x5e, 0x21, 0x44, 0x82, 0x03, 0x45, 0xe1, 0xd0, 0xea, 0x5b, 0x7d, 0xa2, + 0xc3, 0x23, 0x6a, 0x52, 0x90, 0x6a, 0xcd, 0xc3, 0xb4, 0xd3, 0x4e, 0x47, 0x4d, 0xd7, + 0x14, 0xc0, 0xc4, 0x0b, 0xf0, 0x06, 0xa3, 0xa1, 0xd8, 0x89, 0xa6, 0x32, 0x98, 0x38, + 0x14, 0xbb, 0xc4, 0xa1, 0x4f, 0xe5, 0xf1, 0x59, 0xaa, 0x89, 0x24, 0x9e, 0x7c, 0x73, + 0x8b, 0x3b, 0x73, 0x66, 0x6b, 0xac, 0x2a, 0x61, 0x5a, 0x83, 0xfd, 0x21, 0xae, 0x0a, + 0x1c, 0xe7, 0x35, 0x2a, 0xde, 0x7b, 0x27, 0x8b, 0x58, 0x71, 0x58, 0xfd, 0x2f, 0xab, + 0xb2, 0x17, 0xaa, 0x1f, 0xe3, 0x1d, 0x0b, 0xda, 0x53, 0x27, 0x20, 0x45, 0x59, 0x80, + 0x15, 0xa8, 0xae, 0x4d, 0x8c, 0xec, 0x22, 0x6f, 0xef, 0xa5, 0x8d, 0xaa, 0x05, 0x50, + 0x09, 0x06, 0xc4, 0xd8, 0x5e, 0x75, 0x67 + ], + [ + 0xcb, 0x56, 0x48, 0xa1, 0xd6, 0x1c, 0x6c, 0x5b, 0xda, 0xcd, 0x96, 0xf8, 0x1c, 0x95, + 0x91, 0xde, 0xbc, 0x39, 0x50, 0xdc, 0xf6, 0x58, 0x14, 0x5b, 0x8d, 0x99, 0x65, 0x70, + 0xba, 0x88, 0x1a, 0x05 + ] + ); + + digest_test!( + sha3_384_0, + sha3_384, + 384, + [0; 0], + [ + 0x0c, 0x63, 0xa7, 0x5b, 0x84, 0x5e, 
0x4f, 0x7d, 0x01, 0x10, 0x7d, 0x85, 0x2e, 0x4c, + 0x24, 0x85, 0xc5, 0x1a, 0x50, 0xaa, 0xaa, 0x94, 0xfc, 0x61, 0x99, 0x5e, 0x71, 0xbb, + 0xee, 0x98, 0x3a, 0x2a, 0xc3, 0x71, 0x38, 0x31, 0x26, 0x4a, 0xdb, 0x47, 0xfb, 0x6b, + 0xd1, 0xe0, 0x58, 0xd5, 0xf0, 0x04, + ] + ); + + digest_test!( + sha3_384_8, + sha3_384, + 384, + [0x80u8], + [ + 0x75, 0x41, 0x38, 0x48, 0x52, 0xe1, 0x0f, 0xf1, 0x0d, 0x5f, 0xb6, 0xa7, 0x21, 0x3a, + 0x4a, 0x6c, 0x15, 0xcc, 0xc8, 0x6d, 0x8b, 0xc1, 0x06, 0x8a, 0xc0, 0x4f, 0x69, 0x27, + 0x71, 0x42, 0x94, 0x4f, 0x4e, 0xe5, 0x0d, 0x91, 0xfd, 0xc5, 0x65, 0x53, 0xdb, 0x06, + 0xb2, 0xf5, 0x03, 0x9c, 0x8a, 0xb7, + ] + ); + + digest_test!( + sha3_384_2512, + sha3_384, + 384, + [ + 0x03, 0x5a, 0xdc, 0xb6, 0x39, 0xe5, 0xf2, 0x8b, 0xb5, 0xc8, 0x86, 0x58, 0xf4, 0x5c, + 0x1c, 0xe0, 0xbe, 0x16, 0xe7, 0xda, 0xfe, 0x08, 0x3b, 0x98, 0xd0, 0xab, 0x45, 0xe8, + 0xdc, 0xdb, 0xfa, 0x38, 0xe3, 0x23, 0x4d, 0xfd, 0x97, 0x3b, 0xa5, 0x55, 0xb0, 0xcf, + 0x8e, 0xea, 0x3c, 0x82, 0xae, 0x1a, 0x36, 0x33, 0xfc, 0x56, 0x5b, 0x7f, 0x2c, 0xc8, + 0x39, 0x87, 0x6d, 0x39, 0x89, 0xf3, 0x57, 0x31, 0xbe, 0x37, 0x1f, 0x60, 0xde, 0x14, + 0x0e, 0x3c, 0x91, 0x62, 0x31, 0xec, 0x78, 0x0e, 0x51, 0x65, 0xbf, 0x5f, 0x25, 0xd3, + 0xf6, 0x7d, 0xc7, 0x3a, 0x1c, 0x33, 0x65, 0x5d, 0xfd, 0xf4, 0x39, 0xdf, 0xbf, 0x1c, + 0xbb, 0xa8, 0xb7, 0x79, 0x15, 0x8a, 0x81, 0x0a, 0xd7, 0x24, 0x4f, 0x06, 0xec, 0x07, + 0x81, 0x20, 0xcd, 0x18, 0x76, 0x0a, 0xf4, 0x36, 0xa2, 0x38, 0x94, 0x1c, 0xe1, 0xe6, + 0x87, 0x88, 0x0b, 0x5c, 0x87, 0x9d, 0xc9, 0x71, 0xa2, 0x85, 0xa7, 0x4e, 0xe8, 0x5c, + 0x6a, 0x74, 0x67, 0x49, 0xa3, 0x01, 0x59, 0xee, 0x84, 0x2e, 0x9b, 0x03, 0xf3, 0x1d, + 0x61, 0x3d, 0xdd, 0xd2, 0x29, 0x75, 0xcd, 0x7f, 0xed, 0x06, 0xbd, 0x04, 0x9d, 0x77, + 0x2c, 0xb6, 0xcc, 0x5a, 0x70, 0x5f, 0xaa, 0x73, 0x4e, 0x87, 0x32, 0x1d, 0xc8, 0xf2, + 0xa4, 0xea, 0x36, 0x6a, 0x36, 0x8a, 0x98, 0xbf, 0x06, 0xee, 0x2b, 0x0b, 0x54, 0xac, + 0x3a, 0x3a, 0xee, 0xa6, 0x37, 0xca, 0xeb, 0xe7, 0x0a, 0xd0, 0x9c, 0xcd, 0xa9, 0x3c, + 0xc0, 
0x6d, 0xe9, 0x5d, 0xf7, 0x33, 0x94, 0xa8, 0x7a, 0xc9, 0xbb, 0xb5, 0x08, 0x3a, + 0x4d, 0x8a, 0x24, 0x58, 0xe9, 0x1c, 0x7d, 0x5b, 0xf1, 0x13, 0xae, 0xca, 0xe0, 0xce, + 0x27, 0x9f, 0xdd, 0xa7, 0x6b, 0xa6, 0x90, 0x78, 0x7d, 0x26, 0x34, 0x5e, 0x94, 0xc3, + 0xed, 0xbc, 0x16, 0xa3, 0x5c, 0x83, 0xc4, 0xd0, 0x71, 0xb1, 0x32, 0xdd, 0x81, 0x18, + 0x7b, 0xcd, 0x99, 0x61, 0x32, 0x30, 0x11, 0x50, 0x9c, 0x8f, 0x64, 0x4a, 0x1c, 0x0a, + 0x3f, 0x14, 0xee, 0x40, 0xd7, 0xdd, 0x18, 0x6f, 0x80, 0x7f, 0x9e, 0xdc, 0x7c, 0x02, + 0xf6, 0x76, 0x10, 0x61, 0xbb, 0xb6, 0xdd, 0x91, 0xa6, 0xc9, 0x6e, 0xc0, 0xb9, 0xf1, + 0x0e, 0xdb, 0xbd, 0x29, 0xdc, 0x52 + ], + [ + 0x02, 0x53, 0x5d, 0x86, 0xcc, 0x75, 0x18, 0x48, 0x4a, 0x2a, 0x23, 0x8c, 0x92, 0x1b, + 0x73, 0x9b, 0x17, 0x04, 0xa5, 0x03, 0x70, 0xa2, 0x92, 0x4a, 0xbf, 0x39, 0x95, 0x8c, + 0x59, 0x76, 0xe6, 0x58, 0xdc, 0x5e, 0x87, 0x44, 0x00, 0x63, 0x11, 0x24, 0x59, 0xbd, + 0xdb, 0x40, 0x30, 0x8b, 0x1c, 0x70 + ] + ); + + digest_test!( + sha3_512_0, + sha3_512, + 512, + [0u8; 0], + [ + 0xa6, 0x9f, 0x73, 0xcc, 0xa2, 0x3a, 0x9a, 0xc5, 0xc8, 0xb5, 0x67, 0xdc, 0x18, 0x5a, + 0x75, 0x6e, 0x97, 0xc9, 0x82, 0x16, 0x4f, 0xe2, 0x58, 0x59, 0xe0, 0xd1, 0xdc, 0xc1, + 0x47, 0x5c, 0x80, 0xa6, 0x15, 0xb2, 0x12, 0x3a, 0xf1, 0xf5, 0xf9, 0x4c, 0x11, 0xe3, + 0xe9, 0x40, 0x2c, 0x3a, 0xc5, 0x58, 0xf5, 0x00, 0x19, 0x9d, 0x95, 0xb6, 0xd3, 0xe3, + 0x01, 0x75, 0x85, 0x86, 0x28, 0x1d, 0xcd, 0x26, + ] + ); + + digest_test!( + sha3_512_8, + sha3_512, + 512, + [0xe5u8], + [ + 0x15, 0x02, 0x40, 0xba, 0xf9, 0x5f, 0xb3, 0x6f, 0x8c, 0xcb, 0x87, 0xa1, 0x9a, 0x41, + 0x76, 0x7e, 0x7a, 0xed, 0x95, 0x12, 0x50, 0x75, 0xa2, 0xb2, 0xdb, 0xba, 0x6e, 0x56, + 0x5e, 0x1c, 0xe8, 0x57, 0x5f, 0x2b, 0x04, 0x2b, 0x62, 0xe2, 0x9a, 0x04, 0xe9, 0x44, + 0x03, 0x14, 0xa8, 0x21, 0xc6, 0x22, 0x41, 0x82, 0x96, 0x4d, 0x8b, 0x55, 0x7b, 0x16, + 0xa4, 0x92, 0xb3, 0x80, 0x6f, 0x4c, 0x39, 0xc1 + ] + ); + + digest_test!( + sha3_512_4080, + sha3_512, + 512, + [ + 0x43, 0x02, 0x56, 0x15, 0x52, 0x1d, 0x66, 0xfe, 0x8e, 
0xc3, 0xa3, 0xf8, 0xcc, 0xc5, + 0xab, 0xfa, 0xb8, 0x70, 0xa4, 0x62, 0xc6, 0xb3, 0xd1, 0x39, 0x6b, 0x84, 0x62, 0xb9, + 0x8c, 0x7f, 0x91, 0x0c, 0x37, 0xd0, 0xea, 0x57, 0x91, 0x54, 0xea, 0xf7, 0x0f, 0xfb, + 0xcc, 0x0b, 0xe9, 0x71, 0xa0, 0x32, 0xcc, 0xfd, 0x9d, 0x96, 0xd0, 0xa9, 0xb8, 0x29, + 0xa9, 0xa3, 0x76, 0x2e, 0x21, 0xe3, 0xfe, 0xfc, 0xc6, 0x0e, 0x72, 0xfe, 0xdf, 0x9a, + 0x7f, 0xff, 0xa5, 0x34, 0x33, 0xa4, 0xb0, 0x5e, 0x0f, 0x3a, 0xb0, 0x5d, 0x5e, 0xb2, + 0x5d, 0x52, 0xc5, 0xea, 0xb1, 0xa7, 0x1a, 0x2f, 0x54, 0xac, 0x79, 0xff, 0x58, 0x82, + 0x95, 0x13, 0x26, 0x39, 0x4d, 0x9d, 0xb8, 0x35, 0x80, 0xce, 0x09, 0xd6, 0x21, 0x9b, + 0xca, 0x58, 0x8e, 0xc1, 0x57, 0xf7, 0x1d, 0x06, 0xe9, 0x57, 0xf8, 0xc2, 0x0d, 0x24, + 0x2c, 0x9f, 0x55, 0xf5, 0xfc, 0x9d, 0x4d, 0x77, 0x7b, 0x59, 0xb0, 0xc7, 0x5a, 0x8e, + 0xdc, 0x1f, 0xfe, 0xdc, 0x84, 0xb5, 0xd5, 0xc8, 0xa5, 0xe0, 0xeb, 0x05, 0xbb, 0x7d, + 0xb8, 0xf2, 0x34, 0x91, 0x3d, 0x63, 0x25, 0x30, 0x4f, 0xa4, 0x3c, 0x9d, 0x32, 0xbb, + 0xf6, 0xb2, 0x69, 0xee, 0x11, 0x82, 0xcd, 0x85, 0x45, 0x3e, 0xdd, 0xd1, 0x2f, 0x55, + 0x55, 0x6d, 0x8e, 0xdf, 0x02, 0xc4, 0xb1, 0x3c, 0xd4, 0xd3, 0x30, 0xf8, 0x35, 0x31, + 0xdb, 0xf2, 0x99, 0x4c, 0xf0, 0xbe, 0x56, 0xf5, 0x91, 0x47, 0xb7, 0x1f, 0x74, 0xb9, + 0x4b, 0xe3, 0xdd, 0x9e, 0x83, 0xc8, 0xc9, 0x47, 0x7c, 0x42, 0x6c, 0x6d, 0x1a, 0x78, + 0xde, 0x18, 0x56, 0x4a, 0x12, 0xc0, 0xd9, 0x93, 0x07, 0xb2, 0xc9, 0xab, 0x42, 0xb6, + 0xe3, 0x31, 0x7b, 0xef, 0xca, 0x07, 0x97, 0x02, 0x9e, 0x9d, 0xd6, 0x7b, 0xd1, 0x73, + 0x4e, 0x6c, 0x36, 0xd9, 0x98, 0x56, 0x5b, 0xfa, 0xc9, 0x4d, 0x19, 0x18, 0xa3, 0x58, + 0x69, 0x19, 0x0d, 0x17, 0x79, 0x43, 0xc1, 0xa8, 0x00, 0x44, 0x45, 0xca, 0xce, 0x75, + 0x1c, 0x43, 0xa7, 0x5f, 0x3d, 0x80, 0x51, 0x7f, 0xc4, 0x7c, 0xec, 0x46, 0xe8, 0xe3, + 0x82, 0x64, 0x2d, 0x76, 0xdf, 0x46, 0xda, 0xb1, 0xa3, 0xdd, 0xae, 0xab, 0x95, 0xa2, + 0xcf, 0x3f, 0x3a, 0xd7, 0x03, 0x69, 0xa7, 0x0f, 0x22, 0xf2, 0x93, 0xf0, 0xcc, 0x50, + 0xb0, 0x38, 0x57, 0xc8, 0x3c, 0xfe, 0x0b, 0xd5, 0xd2, 0x3b, 0x92, 0xcd, 
0x87, 0x88, + 0xaa, 0xc2, 0x32, 0x29, 0x1d, 0xa6, 0x0b, 0x4b, 0xf3, 0xb3, 0x78, 0x8a, 0xe6, 0x0a, + 0x23, 0xb6, 0x16, 0x9b, 0x50, 0xd7, 0xfe, 0x44, 0x6e, 0x6e, 0xa7, 0x3d, 0xeb, 0xfe, + 0x1b, 0xb3, 0x4d, 0xcb, 0x1d, 0xb3, 0x7f, 0xe2, 0x17, 0x4a, 0x68, 0x59, 0x54, 0xeb, + 0xc2, 0xd8, 0x6f, 0x10, 0x2a, 0x59, 0x0c, 0x24, 0x73, 0x2b, 0xc5, 0xa1, 0x40, 0x3d, + 0x68, 0x76, 0xd2, 0x99, 0x5f, 0xab, 0x1e, 0x2f, 0x6f, 0x47, 0x23, 0xd4, 0xa6, 0x72, + 0x7a, 0x8a, 0x8e, 0xd7, 0x2f, 0x02, 0xa7, 0x4c, 0xcf, 0x5f, 0x14, 0xb5, 0xc2, 0x3d, + 0x95, 0x25, 0xdb, 0xf2, 0xb5, 0x47, 0x2e, 0x13, 0x45, 0xfd, 0x22, 0x3b, 0x08, 0x46, + 0xc7, 0x07, 0xb0, 0x65, 0x69, 0x65, 0x09, 0x40, 0x65, 0x0f, 0x75, 0x06, 0x3b, 0x52, + 0x98, 0x14, 0xe5, 0x14, 0x54, 0x1a, 0x67, 0x15, 0xf8, 0x79, 0xa8, 0x75, 0xb4, 0xf0, + 0x80, 0x77, 0x51, 0x78, 0x12, 0x84, 0x1e, 0x6c, 0x5c, 0x73, 0x2e, 0xed, 0x0c, 0x07, + 0xc0, 0x85, 0x95, 0xb9, 0xff, 0x0a, 0x83, 0xb8, 0xec, 0xc6, 0x0b, 0x2f, 0x98, 0xd4, + 0xe7, 0xc6, 0x96, 0xcd, 0x61, 0x6b, 0xb0, 0xa5, 0xad, 0x52, 0xd9, 0xcf, 0x7b, 0x3a, + 0x63, 0xa8, 0xcd, 0xf3, 0x72, 0x12 + ], + [ + 0xe7, 0xba, 0x73, 0x40, 0x7a, 0xa4, 0x56, 0xae, 0xce, 0x21, 0x10, 0x77, 0xd9, 0x20, + 0x87, 0xd5, 0xcd, 0x28, 0x3e, 0x38, 0x68, 0xd2, 0x84, 0xe0, 0x7e, 0xd1, 0x24, 0xb2, + 0x7c, 0xbc, 0x66, 0x4a, 0x6a, 0x47, 0x5a, 0x8d, 0x7b, 0x4c, 0xf6, 0xa8, 0xa4, 0x92, + 0x7e, 0xe0, 0x59, 0xa2, 0x62, 0x6a, 0x4f, 0x98, 0x39, 0x23, 0x36, 0x01, 0x45, 0xb2, + 0x65, 0xeb, 0xfd, 0x4f, 0x5b, 0x3c, 0x44, 0xfd + ] + ); +} diff --git a/src/ciphers/tea.rs b/src/ciphers/tea.rs new file mode 100644 index 00000000000..f482226e693 --- /dev/null +++ b/src/ciphers/tea.rs @@ -0,0 +1,143 @@ +use std::num::Wrapping as W; + +struct TeaContext { + key0: u64, + key1: u64, +} + +impl TeaContext { + pub fn new(key: &[u64; 2]) -> TeaContext { + TeaContext { + key0: key[0], + key1: key[1], + } + } + + pub fn encrypt_block(&self, block: u64) -> u64 { + let (mut b0, mut b1) = divide_u64(block); + let (k0, k1) = divide_u64(self.key0); + 
let (k2, k3) = divide_u64(self.key1); + let mut sum = W(0u32); + + for _ in 0..32 { + sum += W(0x9E3779B9); + b0 += ((b1 << 4) + k0) ^ (b1 + sum) ^ ((b1 >> 5) + k1); + b1 += ((b0 << 4) + k2) ^ (b0 + sum) ^ ((b0 >> 5) + k3); + } + + ((b1.0 as u64) << 32) | b0.0 as u64 + } + + pub fn decrypt_block(&self, block: u64) -> u64 { + let (mut b0, mut b1) = divide_u64(block); + let (k0, k1) = divide_u64(self.key0); + let (k2, k3) = divide_u64(self.key1); + let mut sum = W(0xC6EF3720u32); + + for _ in 0..32 { + b1 -= ((b0 << 4) + k2) ^ (b0 + sum) ^ ((b0 >> 5) + k3); + b0 -= ((b1 << 4) + k0) ^ (b1 + sum) ^ ((b1 >> 5) + k1); + sum -= W(0x9E3779B9); + } + + ((b1.0 as u64) << 32) | b0.0 as u64 + } +} + +#[inline] +fn divide_u64(n: u64) -> (W, W) { + (W(n as u32), W((n >> 32) as u32)) +} + +pub fn tea_encrypt(plain: &[u8], key: &[u8]) -> Vec { + let tea = TeaContext::new(&[to_block(&key[..8]), to_block(&key[8..16])]); + let mut result: Vec = Vec::new(); + + for i in (0..plain.len()).step_by(8) { + let block = to_block(&plain[i..i + 8]); + result.extend(from_block(tea.encrypt_block(block)).iter()); + } + + result +} + +pub fn tea_decrypt(cipher: &[u8], key: &[u8]) -> Vec { + let tea = TeaContext::new(&[to_block(&key[..8]), to_block(&key[8..16])]); + let mut result: Vec = Vec::new(); + + for i in (0..cipher.len()).step_by(8) { + let block = to_block(&cipher[i..i + 8]); + result.extend(from_block(tea.decrypt_block(block)).iter()); + } + + result +} + +#[inline] +fn to_block(data: &[u8]) -> u64 { + data[0] as u64 + | (data[1] as u64) << 8 + | (data[2] as u64) << 16 + | (data[3] as u64) << 24 + | (data[4] as u64) << 32 + | (data[5] as u64) << 40 + | (data[6] as u64) << 48 + | (data[7] as u64) << 56 +} + +fn from_block(block: u64) -> [u8; 8] { + [ + block as u8, + (block >> 8) as u8, + (block >> 16) as u8, + (block >> 24) as u8, + (block >> 32) as u8, + (block >> 40) as u8, + (block >> 48) as u8, + (block >> 56) as u8, + ] +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn 
test_block_convert() { + assert_eq!( + to_block(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef]), + 0xefcdab8967452301 + ); + + assert_eq!( + from_block(0xefcdab8967452301), + [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef] + ); + } + + #[test] + fn test_tea_encrypt() { + assert_eq!( + tea_encrypt( + &[0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], + &[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00 + ] + ), + [0x0A, 0x3A, 0xEA, 0x41, 0x40, 0xA9, 0xBA, 0x94] + ); + } + + #[test] + fn test_tea_encdec() { + let plain = &[0x1b, 0xcc, 0xd4, 0x31, 0xa0, 0xf6, 0x8a, 0x55]; + let key = &[ + 0x20, 0x45, 0x08, 0x10, 0xb0, 0x23, 0xe2, 0x17, 0xc3, 0x81, 0xd6, 0xf2, 0xee, 0x00, + 0xa4, 0x8a, + ]; + let cipher = tea_encrypt(plain, key); + + assert_eq!(tea_decrypt(&cipher[..], key), plain); + } +} diff --git a/src/ciphers/theoretical_rot13.rs b/src/ciphers/theoretical_rot13.rs new file mode 100644 index 00000000000..1077a38dbce --- /dev/null +++ b/src/ciphers/theoretical_rot13.rs @@ -0,0 +1,43 @@ +// in theory rot-13 only affects the lowercase characters in a cipher +pub fn theoretical_rot13(text: &str) -> String { + let mut pos: u8 = 0; + let mut npos: u8 = 0; + text.to_owned() + .chars() + .map(|mut c| { + if c.is_ascii_lowercase() { + // ((c as u8) + 13) as char + pos = c as u8 - b'a'; + npos = (pos + 13) % 26; + c = (npos + b'a') as char; + c + } else { + c + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_single_letter() { + assert_eq!("n", theoretical_rot13("a")); + } + #[test] + fn test_bunch_of_letters() { + assert_eq!("nop op", theoretical_rot13("abc bc")); + } + + #[test] + fn test_non_ascii() { + assert_eq!("😀ab", theoretical_rot13("😀no")); + } + + #[test] + fn test_twice() { + assert_eq!("abcd", theoretical_rot13(&theoretical_rot13("abcd"))); + } +} diff --git a/src/ciphers/transposition.rs b/src/ciphers/transposition.rs new file mode 100644 index 
00000000000..d5b2a75196e --- /dev/null +++ b/src/ciphers/transposition.rs @@ -0,0 +1,288 @@ +//! Transposition Cipher +//! +//! The Transposition Cipher is a method of encryption by which a message is shifted +//! according to a regular system, so that the ciphertext is a rearrangement of the +//! original message. The most commonly referred to Transposition Cipher is the +//! COLUMNAR TRANSPOSITION cipher, which is demonstrated below. + +use std::ops::RangeInclusive; + +/// Encrypts or decrypts a message, using multiple keys. The +/// encryption is based on the columnar transposition method. +pub fn transposition(decrypt_mode: bool, msg: &str, key: &str) -> String { + let key_uppercase = key.to_uppercase(); + let mut cipher_msg: String = msg.to_string(); + + let keys: Vec<&str> = if decrypt_mode { + key_uppercase.split_whitespace().rev().collect() + } else { + key_uppercase.split_whitespace().collect() + }; + + for cipher_key in keys.iter() { + let mut key_order: Vec = Vec::new(); + + // Removes any non-alphabet characters from 'msg' + cipher_msg = cipher_msg + .to_uppercase() + .chars() + .filter(|&c| c.is_ascii_alphabetic()) + .collect(); + + // Determines the sequence of the columns, as dictated by the + // alphabetical order of the keyword's letters + let mut key_ascii: Vec<(usize, u8)> = + cipher_key.bytes().enumerate().collect::>(); + + key_ascii.sort_by_key(|&(_, key)| key); + + for (counter, (_, key)) in key_ascii.iter_mut().enumerate() { + *key = counter as u8; + } + + key_ascii.sort_by_key(|&(index, _)| index); + + key_ascii + .into_iter() + .for_each(|(_, key)| key_order.push(key.into())); + + // Determines whether to encrypt or decrypt the message, + // and returns the result + cipher_msg = if decrypt_mode { + decrypt(cipher_msg, key_order) + } else { + encrypt(cipher_msg, key_order) + }; + } + + cipher_msg +} + +/// Performs the columnar transposition encryption +fn encrypt(mut msg: String, key_order: Vec) -> String { + let mut encrypted_msg: String = 
String::from(""); + let mut encrypted_vec: Vec = Vec::new(); + + let msg_len = msg.len(); + let key_len: usize = key_order.len(); + + let mut msg_index: usize = msg_len; + let mut key_index: usize = key_len; + + // Loop each column, pushing it to a Vec + while !msg.is_empty() { + let mut chars: String = String::from(""); + let mut index: usize = 0; + key_index -= 1; + + // Loop every nth character, determined by key length, to create a column + while index < msg_index { + let ch = msg.remove(index); + chars.push(ch); + + index += key_index; + msg_index -= 1; + } + + encrypted_vec.push(chars); + } + + // Concatenate the columns into a string, determined by the + // alphabetical order of the keyword's characters + let mut indexed_vec: Vec<(usize, &String)> = Vec::new(); + let mut indexed_msg: String = String::from(""); + + for (counter, key_index) in key_order.into_iter().enumerate() { + indexed_vec.push((key_index, &encrypted_vec[counter])); + } + + indexed_vec.sort(); + + for (_, column) in indexed_vec { + indexed_msg.push_str(column); + } + + // Split the message by a space every nth character, determined by + // 'message length divided by keyword length' to the next highest integer. + let msg_div: usize = (msg_len as f32 / key_len as f32).ceil() as usize; + let mut counter: usize = 0; + + indexed_msg.chars().for_each(|c| { + encrypted_msg.push(c); + counter += 1; + if counter == msg_div { + encrypted_msg.push(' '); + counter = 0; + } + }); + + encrypted_msg.trim_end().to_string() +} + +/// Performs the columnar transposition decryption +fn decrypt(mut msg: String, key_order: Vec) -> String { + let mut decrypted_msg: String = String::from(""); + let mut decrypted_vec: Vec = Vec::new(); + let mut indexed_vec: Vec<(usize, String)> = Vec::new(); + + let msg_len = msg.len(); + let key_len: usize = key_order.len(); + + // Split the message into columns, determined by 'message length divided by keyword length'. 
+ // Some columns are larger by '+1', where the prior calculation leaves a remainder. + let split_size: usize = (msg_len as f64 / key_len as f64) as usize; + let msg_mod: usize = msg_len % key_len; + let mut counter: usize = msg_mod; + + let mut key_split: Vec = key_order.clone(); + let (split_large, split_small) = key_split.split_at_mut(msg_mod); + + split_large.sort_unstable(); + split_small.sort_unstable(); + + split_large.iter_mut().rev().for_each(|key_index| { + counter -= 1; + let range: RangeInclusive = + ((*key_index * split_size) + counter)..=(((*key_index + 1) * split_size) + counter); + + let slice: String = msg[range.clone()].to_string(); + indexed_vec.push((*key_index, slice)); + + msg.replace_range(range, ""); + }); + + for key_index in split_small.iter_mut() { + let (slice, rest_of_msg) = msg.split_at(split_size); + indexed_vec.push((*key_index, (slice.to_string()))); + msg = rest_of_msg.to_string(); + } + + indexed_vec.sort(); + + for key in key_order { + if let Some((_, column)) = indexed_vec.iter().find(|(key_index, _)| key_index == &key) { + decrypted_vec.push(column.clone()); + } + } + + // Concatenate the columns into a string, determined by the + // alphabetical order of the keyword's characters + for _ in 0..split_size { + decrypted_vec.iter_mut().for_each(|column| { + decrypted_msg.push(column.remove(0)); + }) + } + + if !decrypted_vec.is_empty() { + decrypted_vec.into_iter().for_each(|chars| { + decrypted_msg.push_str(&chars); + }) + } + + decrypted_msg +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn encryption() { + assert_eq!( + transposition( + false, + "The quick brown fox jumps over the lazy dog", + "Archive", + ), + "TKOOL ERJEZ CFSEG QOURY UWMTD HBXVA INPHO" + ); + + assert_eq!( + transposition( + false, + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,/;'[]{}:|_+=-`~() ", + "Tenacious" + ), + "DMVENW ENWFOX BKTCLU FOXGPY CLUDMV GPYHQZ IRAJSA JSBKTH QZIR" + ); + + assert_eq!( + transposition(false, "WE ARE 
DISCOVERED. FLEE AT ONCE.", "ZEBRAS"), + "EVLNA CDTES EAROF ODEEC WIREE" + ); + } + + #[test] + fn decryption() { + assert_eq!( + transposition(true, "TKOOL ERJEZ CFSEG QOURY UWMTD HBXVA INPHO", "Archive"), + "THEQUICKBROWNFOXJUMPSOVERTHELAZYDOG" + ); + + assert_eq!( + transposition( + true, + "DMVENW ENWFOX BKTCLU FOXGPY CLUDMV GPYHQZ IRAJSA JSBKTH QZIR", + "Tenacious" + ), + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + ); + + assert_eq!( + transposition(true, "EVLNA CDTES EAROF ODEEC WIREE", "ZEBRAS"), + "WEAREDISCOVEREDFLEEATONCE" + ); + } + + #[test] + fn double_encryption() { + assert_eq!( + transposition( + false, + "The quick brown fox jumps over the lazy dog", + "Archive Snow" + ), + "KEZEUWHAH ORCGRMBIO TLESOUDVP OJFQYTXN" + ); + + assert_eq!( + transposition( + false, + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.,/;'[]{}:|_+=-`~() ", + "Tenacious Drink" + ), + "DWOCXLGZSKI VNBUPDYRJHN FTOCVQJBZEW KFYMHASQMEX LGUPIATR" + ); + + assert_eq!( + transposition(false, "WE ARE DISCOVERED. 
FLEE AT ONCE.", "ZEBRAS STRIPE"), + "CAEEN SOIAE DRLEF WEDRE EVTOC" + ); + } + + #[test] + fn double_decryption() { + assert_eq!( + transposition( + true, + "KEZEUWHAH ORCGRMBIO TLESOUDVP OJFQYTXN", + "Archive Snow" + ), + "THEQUICKBROWNFOXJUMPSOVERTHELAZYDOG" + ); + + assert_eq!( + transposition( + true, + "DWOCXLGZSKI VNBUPDYRJHN FTOCVQJBZEW KFYMHASQMEX LGUPIATR", + "Tenacious Drink", + ), + "ABCDEFGHIJKLMNOPQRSTUVWXYZABCDEFGHIJKLMNOPQRSTUVWXYZ" + ); + + assert_eq!( + transposition(true, "CAEEN SOIAE DRLEF WEDRE EVTOC", "ZEBRAS STRIPE"), + "WEAREDISCOVEREDFLEEATONCE" + ); + } +} diff --git a/src/ciphers/vigenere.rs b/src/ciphers/vigenere.rs index f01ad585714..2f1f5d6ab1e 100644 --- a/src/ciphers/vigenere.rs +++ b/src/ciphers/vigenere.rs @@ -10,7 +10,7 @@ pub fn vigenere(plain_text: &str, key: &str) -> String { // Remove all unicode and non-ascii characters from key let key: String = key.chars().filter(|&c| c.is_ascii_alphabetic()).collect(); - key.to_ascii_lowercase(); + let key = key.to_ascii_lowercase(); let key_len = key.len(); if key_len == 0 { diff --git a/src/ciphers/xor.rs b/src/ciphers/xor.rs new file mode 100644 index 00000000000..a01351611da --- /dev/null +++ b/src/ciphers/xor.rs @@ -0,0 +1,50 @@ +pub fn xor_bytes(text: &[u8], key: u8) -> Vec { + text.iter().map(|c| c ^ key).collect() +} + +pub fn xor(text: &str, key: u8) -> Vec { + xor_bytes(text.as_bytes(), key) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simple() { + let test_string = "test string"; + let ciphered_text = xor(test_string, 32); + assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, 32)); + } + + #[test] + fn test_every_alphabet_with_space() { + let test_string = "The quick brown fox jumps over the lazy dog"; + let ciphered_text = xor(test_string, 64); + assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, 64)); + } + + #[test] + fn test_multi_byte() { + let test_string = "日本語"; + let key = 42; + let ciphered_text = xor(test_string, key); + 
assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, key)); + } + + #[test] + fn test_zero_byte() { + let test_string = "The quick brown fox jumps over the lazy dog"; + let key = b' '; + let ciphered_text = xor(test_string, key); + assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, key)); + } + + #[test] + fn test_invalid_byte() { + let test_string = "The quick brown fox jumps over the lazy dog"; + let key = !0; + let ciphered_text = xor(test_string, key); + assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, key)); + } +} diff --git a/src/compression/mod.rs b/src/compression/mod.rs new file mode 100644 index 00000000000..7acbee56ec5 --- /dev/null +++ b/src/compression/mod.rs @@ -0,0 +1,5 @@ +mod move_to_front; +mod run_length_encoding; + +pub use self::move_to_front::{move_to_front_decode, move_to_front_encode}; +pub use self::run_length_encoding::{run_length_decode, run_length_encode}; diff --git a/src/compression/move_to_front.rs b/src/compression/move_to_front.rs new file mode 100644 index 00000000000..fe38b02ef7c --- /dev/null +++ b/src/compression/move_to_front.rs @@ -0,0 +1,60 @@ +// https://en.wikipedia.org/wiki/Move-to-front_transform + +fn blank_char_table() -> Vec { + (0..=255).map(|ch| ch as u8 as char).collect() +} + +pub fn move_to_front_encode(text: &str) -> Vec { + let mut char_table = blank_char_table(); + let mut result = Vec::new(); + + for ch in text.chars() { + if let Some(position) = char_table.iter().position(|&x| x == ch) { + result.push(position as u8); + char_table.remove(position); + char_table.insert(0, ch); + } + } + + result +} + +pub fn move_to_front_decode(encoded: &[u8]) -> String { + let mut char_table = blank_char_table(); + let mut result = String::new(); + + for &pos in encoded { + let ch = char_table[pos as usize]; + result.push(ch); + char_table.remove(pos as usize); + char_table.insert(0, ch); + } + + result +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! 
test_mtf { + ($($name:ident: ($text:expr, $encoded:expr),)*) => { + $( + #[test] + fn $name() { + assert_eq!(move_to_front_encode($text), $encoded); + assert_eq!(move_to_front_decode($encoded), $text); + } + )* + } + } + + test_mtf! { + empty: ("", &[]), + single_char: ("@", &[64]), + repeated_chars: ("aaba", &[97, 0, 98, 1]), + mixed_chars: ("aZ!", &[97, 91, 35]), + word: ("banana", &[98, 98, 110, 1, 1, 1]), + special_chars: ("\0\n\t", &[0, 10, 10]), + } +} diff --git a/src/compression/run_length_encoding.rs b/src/compression/run_length_encoding.rs new file mode 100644 index 00000000000..03b554b0a3e --- /dev/null +++ b/src/compression/run_length_encoding.rs @@ -0,0 +1,74 @@ +// https://en.wikipedia.org/wiki/Run-length_encoding + +pub fn run_length_encode(text: &str) -> Vec<(char, i32)> { + let mut count = 1; + let mut encoded: Vec<(char, i32)> = vec![]; + + for (i, c) in text.chars().enumerate() { + if i + 1 < text.len() && c == text.chars().nth(i + 1).unwrap() { + count += 1; + } else { + encoded.push((c, count)); + count = 1; + } + } + + encoded +} + +pub fn run_length_decode(encoded: &[(char, i32)]) -> String { + let res = encoded + .iter() + .map(|x| (x.0).to_string().repeat(x.1 as usize)) + .collect::(); + + res +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_run_length_decode() { + let res = run_length_decode(&[('A', 0)]); + assert_eq!(res, ""); + let res = run_length_decode(&[('B', 1)]); + assert_eq!(res, "B"); + let res = run_length_decode(&[('A', 5), ('z', 3), ('B', 1)]); + assert_eq!(res, "AAAAAzzzB"); + } + + #[test] + fn test_run_length_encode() { + let res = run_length_encode(""); + assert_eq!(res, []); + + let res = run_length_encode("A"); + assert_eq!(res, [('A', 1)]); + + let res = run_length_encode("AA"); + assert_eq!(res, [('A', 2)]); + + let res = run_length_encode("AAAABBBCCDAA"); + assert_eq!(res, [('A', 4), ('B', 3), ('C', 2), ('D', 1), ('A', 2)]); + + let res = run_length_encode("Rust-Trends"); + assert_eq!( + res, + [ + 
('R', 1), + ('u', 1), + ('s', 1), + ('t', 1), + ('-', 1), + ('T', 1), + ('r', 1), + ('e', 1), + ('n', 1), + ('d', 1), + ('s', 1) + ] + ); + } +} diff --git a/src/conversions/binary_to_decimal.rs b/src/conversions/binary_to_decimal.rs new file mode 100644 index 00000000000..2be7b0f9f07 --- /dev/null +++ b/src/conversions/binary_to_decimal.rs @@ -0,0 +1,90 @@ +use num_traits::CheckedAdd; + +pub fn binary_to_decimal(binary: &str) -> Option { + if binary.len() > 128 { + return None; + } + let mut num = 0; + let mut idx_val = 1; + for bit in binary.chars().rev() { + match bit { + '1' => { + if let Some(sum) = num.checked_add(&idx_val) { + num = sum; + } else { + return None; + } + } + '0' => {} + _ => return None, + } + idx_val <<= 1; + } + Some(num) +} + +#[cfg(test)] +mod tests { + use super::binary_to_decimal; + + #[test] + fn basic_binary_to_decimal() { + assert_eq!(binary_to_decimal("0000000110"), Some(6)); + assert_eq!(binary_to_decimal("1000011110"), Some(542)); + assert_eq!(binary_to_decimal("1111111111"), Some(1023)); + } + #[test] + fn big_binary_to_decimal() { + assert_eq!( + binary_to_decimal("111111111111111111111111"), + Some(16_777_215) + ); + // 32 bits + assert_eq!( + binary_to_decimal("11111111111111111111111111111111"), + Some(4_294_967_295) + ); + // 64 bits + assert_eq!( + binary_to_decimal("1111111111111111111111111111111111111111111111111111111111111111"), + Some(18_446_744_073_709_551_615u128) + ); + } + #[test] + fn very_big_binary_to_decimal() { + // 96 bits + assert_eq!( + binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 11111111111111111111111111111111" + ), + Some(79_228_162_514_264_337_593_543_950_335u128) + ); + + // 128 bits + assert_eq!( + binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111111" + ), + Some(340_282_366_920_938_463_463_374_607_431_768_211_455u128) + ); + // 129 bits, should overflow + 
assert!(binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 11111111111111111111111111111111111111111111111111111111111111111" + ) + .is_none()); + // obviously none + assert!(binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111" + ) + .is_none()); + } +} diff --git a/src/conversions/binary_to_hexadecimal.rs b/src/conversions/binary_to_hexadecimal.rs new file mode 100644 index 00000000000..0a4e52291a6 --- /dev/null +++ b/src/conversions/binary_to_hexadecimal.rs @@ -0,0 +1,106 @@ +// Author : cyrixninja +// Binary to Hex Converter : Converts Binary to Hexadecimal +// Wikipedia References : 1. https://en.wikipedia.org/wiki/Hexadecimal +// 2. https://en.wikipedia.org/wiki/Binary_number + +static BITS_TO_HEX: &[(u8, &str)] = &[ + (0b0000, "0"), + (0b0001, "1"), + (0b0010, "2"), + (0b0011, "3"), + (0b0100, "4"), + (0b0101, "5"), + (0b0110, "6"), + (0b0111, "7"), + (0b1000, "8"), + (0b1001, "9"), + (0b1010, "a"), + (0b1011, "b"), + (0b1100, "c"), + (0b1101, "d"), + (0b1110, "e"), + (0b1111, "f"), +]; + +pub fn binary_to_hexadecimal(binary_str: &str) -> String { + let binary_str = binary_str.trim(); + + if binary_str.is_empty() { + return String::from("Invalid Input"); + } + + let is_negative = binary_str.starts_with('-'); + let binary_str = if is_negative { + &binary_str[1..] 
+ } else { + binary_str + }; + + if !binary_str.chars().all(|c| c == '0' || c == '1') { + return String::from("Invalid Input"); + } + + let padded_len = (4 - (binary_str.len() % 4)) % 4; + let binary_str = format!( + "{:0width$}", + binary_str, + width = binary_str.len() + padded_len + ); + + // Convert binary to hexadecimal + let mut hexadecimal = String::with_capacity(binary_str.len() / 4 + 2); + hexadecimal.push_str("0x"); + + for chunk in binary_str.as_bytes().chunks(4) { + let mut nibble = 0; + for (i, &byte) in chunk.iter().enumerate() { + nibble |= (byte - b'0') << (3 - i); + } + + let hex_char = BITS_TO_HEX + .iter() + .find(|&&(bits, _)| bits == nibble) + .map(|&(_, hex)| hex) + .unwrap(); + hexadecimal.push_str(hex_char); + } + + if is_negative { + format!("-{hexadecimal}") + } else { + hexadecimal + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_string() { + let input = ""; + let expected = "Invalid Input"; + assert_eq!(binary_to_hexadecimal(input), expected); + } + + #[test] + fn test_invalid_binary() { + let input = "a"; + let expected = "Invalid Input"; + assert_eq!(binary_to_hexadecimal(input), expected); + } + + #[test] + fn test_binary() { + let input = "00110110"; + let expected = "0x36"; + assert_eq!(binary_to_hexadecimal(input), expected); + } + + #[test] + fn test_padded_binary() { + let input = " 1010 "; + let expected = "0xa"; + assert_eq!(binary_to_hexadecimal(input), expected); + } +} diff --git a/src/conversions/decimal_to_binary.rs b/src/conversions/decimal_to_binary.rs new file mode 100644 index 00000000000..40e180b2914 --- /dev/null +++ b/src/conversions/decimal_to_binary.rs @@ -0,0 +1,26 @@ +pub fn decimal_to_binary(base_num: u64) -> String { + let mut num = base_num; + let mut binary_num = String::new(); + loop { + let bit = (num % 2).to_string(); + binary_num.push_str(&bit); + num /= 2; + if num == 0 { + break; + } + } + + let bits = binary_num.chars(); + bits.rev().collect() +} + +#[cfg(test)] +mod 
tests { + use super::*; + + #[test] + fn converting_decimal_to_binary() { + assert_eq!(decimal_to_binary(542), "1000011110"); + assert_eq!(decimal_to_binary(92), "1011100"); + } +} diff --git a/src/conversions/decimal_to_hexadecimal.rs b/src/conversions/decimal_to_hexadecimal.rs new file mode 100644 index 00000000000..57ff3e11b71 --- /dev/null +++ b/src/conversions/decimal_to_hexadecimal.rs @@ -0,0 +1,56 @@ +pub fn decimal_to_hexadecimal(base_num: u64) -> String { + let mut num = base_num; + let mut hexadecimal_num = String::new(); + + loop { + let remainder = num % 16; + let hex_char = if remainder < 10 { + (remainder as u8 + b'0') as char + } else { + (remainder as u8 - 10 + b'A') as char + }; + + hexadecimal_num.insert(0, hex_char); + num /= 16; + if num == 0 { + break; + } + } + + hexadecimal_num +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_zero() { + assert_eq!(decimal_to_hexadecimal(0), "0"); + } + + #[test] + fn test_single_digit_decimal() { + assert_eq!(decimal_to_hexadecimal(9), "9"); + } + + #[test] + fn test_single_digit_hexadecimal() { + assert_eq!(decimal_to_hexadecimal(12), "C"); + } + + #[test] + fn test_multiple_digit_hexadecimal() { + assert_eq!(decimal_to_hexadecimal(255), "FF"); + } + + #[test] + fn test_big() { + assert_eq!(decimal_to_hexadecimal(u64::MAX), "FFFFFFFFFFFFFFFF"); + } + + #[test] + fn test_random() { + assert_eq!(decimal_to_hexadecimal(123456), "1E240"); + } +} diff --git a/src/conversions/hexadecimal_to_binary.rs b/src/conversions/hexadecimal_to_binary.rs new file mode 100644 index 00000000000..490b69e8fb0 --- /dev/null +++ b/src/conversions/hexadecimal_to_binary.rs @@ -0,0 +1,67 @@ +// Author : cyrixninja +// Hexadecimal to Binary Converter : Converts Hexadecimal to Binary +// Wikipedia References : 1. https://en.wikipedia.org/wiki/Hexadecimal +// 2. 
https://en.wikipedia.org/wiki/Binary_number +// Other References for Testing : https://www.rapidtables.com/convert/number/hex-to-binary.html + +pub fn hexadecimal_to_binary(hex_str: &str) -> Result { + let hex_chars = hex_str.chars().collect::>(); + let mut binary = String::new(); + + for c in hex_chars { + let bin_rep = match c { + '0' => "0000", + '1' => "0001", + '2' => "0010", + '3' => "0011", + '4' => "0100", + '5' => "0101", + '6' => "0110", + '7' => "0111", + '8' => "1000", + '9' => "1001", + 'a' | 'A' => "1010", + 'b' | 'B' => "1011", + 'c' | 'C' => "1100", + 'd' | 'D' => "1101", + 'e' | 'E' => "1110", + 'f' | 'F' => "1111", + _ => return Err("Invalid".to_string()), + }; + binary.push_str(bin_rep); + } + + Ok(binary) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_string() { + let input = ""; + let expected = Ok("".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } + + #[test] + fn test_hexadecimal() { + let input = "1a2"; + let expected = Ok("000110100010".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } + #[test] + fn test_hexadecimal2() { + let input = "1b3"; + let expected = Ok("000110110011".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } + + #[test] + fn test_invalid_hexadecimal() { + let input = "1g3"; + let expected = Err("Invalid".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } +} diff --git a/src/conversions/hexadecimal_to_decimal.rs b/src/conversions/hexadecimal_to_decimal.rs new file mode 100644 index 00000000000..5f71716d039 --- /dev/null +++ b/src/conversions/hexadecimal_to_decimal.rs @@ -0,0 +1,61 @@ +pub fn hexadecimal_to_decimal(hexadecimal_str: &str) -> Result { + if hexadecimal_str.is_empty() { + return Err("Empty input"); + } + + for hexadecimal_str in hexadecimal_str.chars() { + if !hexadecimal_str.is_ascii_hexdigit() { + return Err("Input was not a hexadecimal number"); + } + } + + match 
u64::from_str_radix(hexadecimal_str, 16) { + Ok(decimal) => Ok(decimal), + Err(_e) => Err("Failed to convert octal to hexadecimal"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hexadecimal_to_decimal_empty() { + assert_eq!(hexadecimal_to_decimal(""), Err("Empty input")); + } + + #[test] + fn test_hexadecimal_to_decimal_invalid() { + assert_eq!( + hexadecimal_to_decimal("xyz"), + Err("Input was not a hexadecimal number") + ); + assert_eq!( + hexadecimal_to_decimal("0xabc"), + Err("Input was not a hexadecimal number") + ); + } + + #[test] + fn test_hexadecimal_to_decimal_valid1() { + assert_eq!(hexadecimal_to_decimal("45"), Ok(69)); + assert_eq!(hexadecimal_to_decimal("2b3"), Ok(691)); + assert_eq!(hexadecimal_to_decimal("4d2"), Ok(1234)); + assert_eq!(hexadecimal_to_decimal("1267a"), Ok(75386)); + } + + #[test] + fn test_hexadecimal_to_decimal_valid2() { + assert_eq!(hexadecimal_to_decimal("1a"), Ok(26)); + assert_eq!(hexadecimal_to_decimal("ff"), Ok(255)); + assert_eq!(hexadecimal_to_decimal("a1b"), Ok(2587)); + assert_eq!(hexadecimal_to_decimal("7fffffff"), Ok(2147483647)); + } + + #[test] + fn test_hexadecimal_to_decimal_valid3() { + assert_eq!(hexadecimal_to_decimal("0"), Ok(0)); + assert_eq!(hexadecimal_to_decimal("7f"), Ok(127)); + assert_eq!(hexadecimal_to_decimal("80000000"), Ok(2147483648)); + } +} diff --git a/src/conversions/length_conversion.rs b/src/conversions/length_conversion.rs new file mode 100644 index 00000000000..4a056ed3052 --- /dev/null +++ b/src/conversions/length_conversion.rs @@ -0,0 +1,94 @@ +/// Author : https://github.com/ali77gh +/// Conversion of length units. 
+/// +/// Available Units: +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Millimeter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Centimeter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Meter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Kilometer +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Inch +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Foot +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Yard +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Mile + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub enum LengthUnit { + Millimeter, + Centimeter, + Meter, + Kilometer, + Inch, + Foot, + Yard, + Mile, +} + +fn unit_to_meter_multiplier(from: LengthUnit) -> f64 { + match from { + LengthUnit::Millimeter => 0.001, + LengthUnit::Centimeter => 0.01, + LengthUnit::Meter => 1.0, + LengthUnit::Kilometer => 1000.0, + LengthUnit::Inch => 0.0254, + LengthUnit::Foot => 0.3048, + LengthUnit::Yard => 0.9144, + LengthUnit::Mile => 1609.34, + } +} + +fn unit_to_meter(input: f64, from: LengthUnit) -> f64 { + input * unit_to_meter_multiplier(from) +} + +fn meter_to_unit(input: f64, to: LengthUnit) -> f64 { + input / unit_to_meter_multiplier(to) +} + +/// This function will convert a value in unit of [from] to value in unit of [to] +/// by first converting it to meter and than convert it to destination unit +pub fn length_conversion(input: f64, from: LengthUnit, to: LengthUnit) -> f64 { + meter_to_unit(unit_to_meter(input, from), to) +} + +#[cfg(test)] +mod length_conversion_tests { + use std::collections::HashMap; + + use super::LengthUnit::*; + use super::*; + + #[test] + fn zero_to_zero() { + let units = vec![ + Millimeter, Centimeter, Meter, Kilometer, Inch, Foot, Yard, Mile, + ]; + + for u1 in units.clone() { + for u2 in units.clone() { + assert_eq!(length_conversion(0f64, u1, u2), 0f64); + } + } + } + + #[test] + fn length_of_one_meter() { + let meter_in_different_units = 
HashMap::from([ + (Millimeter, 1000f64), + (Centimeter, 100f64), + (Kilometer, 0.001f64), + (Inch, 39.37007874015748f64), + (Foot, 3.280839895013123f64), + (Yard, 1.0936132983377078f64), + (Mile, 0.0006213727366498068f64), + ]); + for (input_unit, input_value) in &meter_in_different_units { + for (target_unit, target_value) in &meter_in_different_units { + assert!( + num_traits::abs( + length_conversion(*input_value, *input_unit, *target_unit) - *target_value + ) < 0.0000001 + ); + } + } + } +} diff --git a/src/conversions/mod.rs b/src/conversions/mod.rs new file mode 100644 index 00000000000..a83c46bf600 --- /dev/null +++ b/src/conversions/mod.rs @@ -0,0 +1,20 @@ +mod binary_to_decimal; +mod binary_to_hexadecimal; +mod decimal_to_binary; +mod decimal_to_hexadecimal; +mod hexadecimal_to_binary; +mod hexadecimal_to_decimal; +mod length_conversion; +mod octal_to_binary; +mod octal_to_decimal; +mod rgb_cmyk_conversion; +pub use self::binary_to_decimal::binary_to_decimal; +pub use self::binary_to_hexadecimal::binary_to_hexadecimal; +pub use self::decimal_to_binary::decimal_to_binary; +pub use self::decimal_to_hexadecimal::decimal_to_hexadecimal; +pub use self::hexadecimal_to_binary::hexadecimal_to_binary; +pub use self::hexadecimal_to_decimal::hexadecimal_to_decimal; +pub use self::length_conversion::length_conversion; +pub use self::octal_to_binary::octal_to_binary; +pub use self::octal_to_decimal::octal_to_decimal; +pub use self::rgb_cmyk_conversion::rgb_to_cmyk; diff --git a/src/conversions/octal_to_binary.rs b/src/conversions/octal_to_binary.rs new file mode 100644 index 00000000000..ba4a9ccebd2 --- /dev/null +++ b/src/conversions/octal_to_binary.rs @@ -0,0 +1,60 @@ +// Author : cyrixninja +// Octal to Binary Converter : Converts Octal to Binary +// Wikipedia References : 1. https://en.wikipedia.org/wiki/Octal +// 2. 
//                        https://en.wikipedia.org/wiki/Binary_number

/// Converts a string of octal digits into its binary representation.
///
/// Leading and trailing whitespace is trimmed first. Every octal digit is
/// expanded to exactly three binary digits, so leading zeros are preserved
/// (`"1"` becomes `"001"`).
///
/// # Errors
///
/// * `Err("Empty")` if the input is empty after trimming.
/// * `Err("Non-octal Value")` if any character is not one of `0`..=`7`.
pub fn octal_to_binary(octal_str: &str) -> Result<String, &'static str> {
    let octal_str = octal_str.trim();

    if octal_str.is_empty() {
        return Err("Empty");
    }

    // Validate and convert in a single pass: the first character outside
    // `0`..=`7` short-circuits the collection with an error. Note that `7`
    // itself is a valid octal digit and must be accepted.
    octal_str
        .chars()
        .map(|c| match c {
            '0' => Ok("000"),
            '1' => Ok("001"),
            '2' => Ok("010"),
            '3' => Ok("011"),
            '4' => Ok("100"),
            '5' => Ok("101"),
            '6' => Ok("110"),
            '7' => Ok("111"),
            _ => Err("Non-octal Value"),
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_string() {
        let input = "";
        let expected = Err("Empty");
        assert_eq!(octal_to_binary(input), expected);
    }

    #[test]
    fn test_invalid_octal() {
        let input = "89";
        let expected = Err("Non-octal Value");
        assert_eq!(octal_to_binary(input), expected);
    }

    #[test]
    fn test_valid_octal() {
        let input = "123";
        let expected = Ok("001010011".to_string());
        assert_eq!(octal_to_binary(input), expected);
    }

    #[test]
    fn test_digit_seven_is_valid() {
        let input = "7";
        let expected = Ok("111".to_string());
        assert_eq!(octal_to_binary(input), expected);
    }
}

// diff --git a/src/conversions/octal_to_decimal.rs b/src/conversions/octal_to_decimal.rs
// new file mode 100644
// index 00000000000..18ab5076916
// --- /dev/null
// +++ b/src/conversions/octal_to_decimal.rs
// @@ -0,0 +1,60 @@
// Author: cyrixninja
// Octal to Decimal Converter: Converts Octal to Decimal
// Wikipedia References:
// 1. https://en.wikipedia.org/wiki/Octal
// 2.
/// Parses a base-8 string into its `u64` value.
///
/// Leading and trailing whitespace is ignored.
///
/// # Errors
///
/// * `"Empty"` - nothing is left after trimming the input.
/// * `"Non-octal Value"` - a character outside `'0'..='7'` is present.
/// * `"Conversion error"` - `u64::from_str_radix` rejected the digits
///   (after the checks above this means the value does not fit in a `u64`).
pub fn octal_to_decimal(octal_str: &str) -> Result<u64, &'static str> {
    let digits = octal_str.trim();

    match digits {
        "" => Err("Empty"),
        _ if digits.chars().any(|c| !matches!(c, '0'..='7')) => Err("Non-octal Value"),
        _ => u64::from_str_radix(digits, 8).map_err(|_| "Conversion error"),
    }
}
/// Converts an RGB colour to the CMYK format.
///
/// Author: <https://github.com/ali77gh>
///
/// References:
/// * RGB: <https://en.wikipedia.org/wiki/RGB_color_model>
/// * CMYK: <https://en.wikipedia.org/wiki/CMYK_color_model>
///
/// # Arguments
///
/// * `rgb` - the `(red, green, blue)` channels. Being `u8`, each channel is
///   already confined to `0..=255`, so no range validation is needed.
///
/// # Returns
///
/// `(c, m, y, k)`, each component being a percentage that is truncated
/// (not rounded) to an integer.
pub fn rgb_to_cmyk(rgb: (u8, u8, u8)) -> (u8, u8, u8, u8) {
    // Pure black: key is 100% and the general formula would divide by zero.
    // Checking the integer input is equivalent to the former `k == 1.0` float
    // match, because `k` is exactly 1 only when every channel is 0.
    if rgb == (0, 0, 0) {
        return (0, 0, 0, 100);
    }

    // Change scale from [0, 255] to [0, 1].
    let (r, g, b) = (
        f64::from(rgb.0) / 255f64,
        f64::from(rgb.1) / 255f64,
        f64::from(rgb.2) / 255f64,
    );

    let k = 1f64 - r.max(g).max(b);
    // `as u8` truncates toward zero (and saturates), matching the documented
    // "truncated percentage" result.
    let percent = |channel: f64| (100f64 * (1f64 - channel - k) / (1f64 - k)) as u8;

    (percent(r), percent(g), percent(b), (100f64 * k) as u8)
}
The heights of any two sibling +nodes must differ by at most one; the tree may rebalance itself after insertion or +deletion to uphold this property. + +__Properties__ +* Worst/Average time complexity for basic operations: O(log n) +* Worst/Average space complexity: O(n) + +__Sources to read:__ +* [Wikipedia](https://en.wikipedia.org/wiki/AVL_tree) +* Geeksforgeeks +([Insertion](https://www.geeksforgeeks.org/avl-tree-set-1-insertion), +[Deletion](https://www.geeksforgeeks.org/avl-tree-set-2-deletion)) + + +### [Doubly linked list](./linked_list.rs) +![alt text][doubly-linked-list] + +A linked list is also a `linear` data structure, and each element in the linked list is actually a separate object while all the objects are `linked together by the reference field` in each element. In a `doubly linked list`, each node contains, besides the `next` node link, a second link field pointing to the `previous` node in the sequence. The two links may be called `next` and `prev`. And many modern operating systems use doubly linked lists to maintain references to active processes, threads and other dynamic objects. + +__Properties__ +* Indexing O(n) +* Insertion O(1) + * Beginning O(1) + * Middle (Indexing time+O(1)) + * End O(n) +* Deletion O(1) + * Beginning O(1) + * Middle (Indexing time+O(1)) + * End O(n) +* Search O(n) + +__Source to read:__ +* [Wikipedia](https://en.wikipedia.org/wiki/Linked_list) +* [LeetCode](https://leetcode.com/explore/learn/card/linked-list/) +* [Brilliant](https://brilliant.org/wiki/linked-lists/) +* [Rust API Docs](https://doc.rust-lang.org/std/collections/struct.LinkedList.html) + + +### [Stack Using Singly Linked List](./stack_using_singly_linked_list.rs) +![][stack] + +From Wikipedia, a stack is an abstract data type that serves as a collection of elements, with two main principal operations, `Push` and `Pop`.
/// An internal node of an `AVLTree`.
struct AVLNode<T: Ord> {
    value: T,
    /// Height of the subtree rooted at this node; a node without children has
    /// height 1 and a missing child counts as height 0.
    height: usize,
    left: Option<Box<AVLNode<T>>>,
    right: Option<Box<AVLNode<T>>>,
}

/// A set based on an AVL Tree.
///
/// An AVL Tree is a self-balancing binary search tree. It tracks the height of each node
/// and performs internal rotations to maintain a height difference of at most 1 between
/// each sibling pair.
pub struct AVLTree<T: Ord> {
    root: Option<Box<AVLNode<T>>>,
    /// Number of values currently stored.
    length: usize,
}

/// Refers to the left or right subtree of an `AVLNode`.
#[derive(Clone, Copy)]
enum Side {
    Left,
    Right,
}

impl<T: Ord> AVLTree<T> {
    /// Creates an empty `AVLTree`.
    pub fn new() -> AVLTree<T> {
        AVLTree {
            root: None,
            length: 0,
        }
    }

    /// Returns `true` if the tree contains a value.
    pub fn contains(&self, value: &T) -> bool {
        let mut current = &self.root;
        while let Some(node) = current {
            current = match value.cmp(&node.value) {
                Ordering::Equal => return true,
                Ordering::Less => &node.left,
                Ordering::Greater => &node.right,
            }
        }
        false
    }

    /// Adds a value to the tree.
    ///
    /// Returns `true` if the tree did not yet contain the value.
    pub fn insert(&mut self, value: T) -> bool {
        let inserted = insert(&mut self.root, value);
        if inserted {
            self.length += 1;
        }
        inserted
    }

    /// Removes a value from the tree.
    ///
    /// Returns `true` if the tree contained the value.
    pub fn remove(&mut self, value: &T) -> bool {
        let removed = remove(&mut self.root, value);
        if removed {
            self.length -= 1;
        }
        removed
    }

    /// Returns the number of values in the tree.
    pub fn len(&self) -> usize {
        self.length
    }

    /// Returns `true` if the tree contains no values.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }

    /// Returns an iterator that visits the nodes in the tree in order.
    fn node_iter(&self) -> NodeIter<T> {
        // The stack only ever holds one root-to-node path, so the height of
        // the root is enough capacity.
        let cap = self.root.as_ref().map_or(0, |n| n.height);
        let mut node_iter = NodeIter {
            stack: Vec::with_capacity(cap),
        };
        // Initialize stack with path to leftmost child
        let mut child = &self.root;
        while let Some(node) = child {
            node_iter.stack.push(node.as_ref());
            child = &node.left;
        }
        node_iter
    }

    /// Returns an iterator that visits the values in the tree in ascending order.
    pub fn iter(&self) -> Iter<T> {
        Iter {
            node_iter: self.node_iter(),
        }
    }
}

/// Recursive helper function for `AVLTree` insertion.
///
/// Returns `true` if `value` was added, `false` if an equal value was already present.
fn insert<T: Ord>(tree: &mut Option<Box<AVLNode<T>>>, value: T) -> bool {
    if let Some(node) = tree {
        let inserted = match value.cmp(&node.value) {
            Ordering::Equal => false,
            Ordering::Less => insert(&mut node.left, value),
            Ordering::Greater => insert(&mut node.right, value),
        };
        // Heights can only have changed on the path of a successful insert.
        if inserted {
            node.rebalance();
        }
        inserted
    } else {
        *tree = Some(Box::new(AVLNode {
            value,
            height: 1,
            left: None,
            right: None,
        }));
        true
    }
}

/// Recursive helper function for `AVLTree` deletion.
///
/// Returns `true` if a node holding `value` was found and removed.
fn remove<T: Ord>(tree: &mut Option<Box<AVLNode<T>>>, value: &T) -> bool {
    if let Some(node) = tree {
        let removed = match value.cmp(&node.value) {
            Ordering::Less => remove(&mut node.left, value),
            Ordering::Greater => remove(&mut node.right, value),
            Ordering::Equal => {
                // Replace the matching node by whatever its children collapse to.
                *tree = match (node.left.take(), node.right.take()) {
                    (None, None) => None,
                    (Some(b), None) | (None, Some(b)) => Some(b),
                    (Some(left), Some(right)) => Some(merge(left, right)),
                };
                return true;
            }
        };
        if removed {
            node.rebalance();
        }
        removed
    } else {
        false
    }
}

/// Merges two trees and returns the root of the merged tree.
///
/// The smallest node of `right` becomes the new root, with `left` as its left
/// subtree and the remainder of `right` as its right subtree.
fn merge<T: Ord>(left: Box<AVLNode<T>>, right: Box<AVLNode<T>>) -> Box<AVLNode<T>> {
    let mut op_right = Some(right);
    let mut root = take_min(&mut op_right).expect("right subtree holds at least one node");
    root.left = Some(left);
    root.right = op_right;
    root.rebalance();
    root
}

/// Removes the smallest node from the tree, if one exists.
fn take_min<T: Ord>(tree: &mut Option<Box<AVLNode<T>>>) -> Option<Box<AVLNode<T>>> {
    let mut node = tree.take()?;
    // Recurse along the left side
    if let Some(small) = take_min(&mut node.left) {
        // Took the smallest from below; update this node and put it back in the tree
        node.rebalance();
        *tree = Some(node);
        Some(small)
    } else {
        // Take this node and replace it with its right child
        *tree = node.right.take();
        Some(node)
    }
}

impl<T: Ord> AVLNode<T> {
    /// Returns a reference to the left or right child.
    fn child(&self, side: Side) -> &Option<Box<AVLNode<T>>> {
        match side {
            Side::Left => &self.left,
            Side::Right => &self.right,
        }
    }

    /// Returns a mutable reference to the left or right child.
    fn child_mut(&mut self, side: Side) -> &mut Option<Box<AVLNode<T>>> {
        match side {
            Side::Left => &mut self.left,
            Side::Right => &mut self.right,
        }
    }

    /// Returns the height of the left or right subtree.
    fn height(&self, side: Side) -> usize {
        self.child(side).as_ref().map_or(0, |n| n.height)
    }

    /// Returns the height difference between the left and right subtrees.
    ///
    /// Positive when the right subtree is taller, negative when the left one is.
    fn balance_factor(&self) -> i8 {
        let (left, right) = (self.height(Side::Left), self.height(Side::Right));
        // `rebalance` runs after every structural change and only lets
        // differences of -1, 0 and 1 stand, so the narrowing casts below
        // never see a value outside the `i8` range in a well-formed tree.
        if left < right {
            (right - left) as i8
        } else {
            -((left - right) as i8)
        }
    }

    /// Recomputes the `height` field.
    fn update_height(&mut self) {
        self.height = 1 + max(self.height(Side::Left), self.height(Side::Right));
    }

    /// Performs a left or right rotation.
    ///
    /// The child on the side opposite to `side` becomes the new root of this subtree.
    fn rotate(&mut self, side: Side) {
        let mut subtree = self
            .child_mut(!side)
            .take()
            .expect("rotation needs a child on the side opposite to the rotation");
        *self.child_mut(!side) = subtree.child_mut(side).take();
        self.update_height();
        // Swap root and child nodes in memory
        mem::swap(self, &mut *subtree);
        // Set old root (subtree) as child of new root (self)
        *self.child_mut(side) = Some(subtree);
        self.update_height();
    }

    /// Performs left or right tree rotations to balance this node.
    fn rebalance(&mut self) {
        self.update_height();
        let side = match self.balance_factor() {
            -2 => Side::Left,
            2 => Side::Right,
            _ => return,
        };
        let subtree = self
            .child_mut(side)
            .as_mut()
            .expect("the heavy side of an unbalanced node is never empty");
        // Left-Right and Right-Left require rotation of heavy subtree
        if let (Side::Left, 1) | (Side::Right, -1) = (side, subtree.balance_factor()) {
            subtree.rotate(side);
        }
        // Rotate in opposite direction of heavy side
        self.rotate(!side);
    }
}

impl<T: Ord> Default for AVLTree<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl Not for Side {
    type Output = Side;

    /// Returns the opposite side.
    fn not(self) -> Self::Output {
        match self {
            Side::Left => Side::Right,
            Side::Right => Side::Left,
        }
    }
}

impl<T: Ord> FromIterator<T> for AVLTree<T> {
    /// Builds a tree by inserting every item of `iter`; duplicates are dropped.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let mut tree = AVLTree::new();
        for value in iter {
            tree.insert(value);
        }
        tree
    }
}

impl<'a, T: Ord> IntoIterator for &'a AVLTree<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;

    /// Equivalent to calling `iter`; enables `for value in &tree`.
    fn into_iter(self) -> Iter<'a, T> {
        self.iter()
    }
}

/// An iterator over the nodes of an `AVLTree`.
///
/// This struct is created by the `node_iter` method of `AVLTree`.
struct NodeIter<'a, T: Ord> {
    /// Nodes still to be yielded; the top of the stack is the next node in order.
    stack: Vec<&'a AVLNode<T>>,
}

impl<'a, T: Ord> Iterator for NodeIter<'a, T> {
    type Item = &'a AVLNode<T>;

    fn next(&mut self) -> Option<Self::Item> {
        let node = self.stack.pop()?;
        // Push left path of right subtree to stack
        let mut child = &node.right;
        while let Some(subtree) = child {
            self.stack.push(subtree.as_ref());
            child = &subtree.left;
        }
        Some(node)
    }
}

/// An iterator over the items of an `AVLTree`.
///
/// This struct is created by the `iter` method of `AVLTree`.
pub struct Iter<'a, T: Ord> {
    node_iter: NodeIter<'a, T>,
}

impl<'a, T: Ord> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<&'a T> {
        self.node_iter.next().map(|node| &node.value)
    }
}
} pub fn traverse(&self) { - self.props.traverse_node(&self.root, 0); + BTreeProps::traverse_node(&self.root, 0); println!(); } pub fn search(&self, key: T) -> bool { let mut current_node = &self.root; - let mut index: isize; loop { - index = isize::try_from(current_node.keys.len()).ok().unwrap() - 1; - while index >= 0 && current_node.keys[index as usize] > key { - index -= 1; - } - - let u_index: usize = usize::try_from(index + 1).ok().unwrap(); - if index >= 0 && current_node.keys[u_index - 1] == key { - break true; - } else if current_node.is_leaf() { - break false; - } else { - current_node = ¤t_node.children[u_index]; + match current_node.keys.binary_search(&key) { + Ok(_) => return true, + Err(index) => { + if current_node.is_leaf() { + return false; + } + current_node = ¤t_node.children[index]; + } } } } @@ -170,19 +164,46 @@ where mod test { use super::BTree; - #[test] - fn test_search() { - let mut tree = BTree::new(2); - tree.insert(10); - tree.insert(20); - tree.insert(30); - tree.insert(5); - tree.insert(6); - tree.insert(7); - tree.insert(11); - tree.insert(12); - tree.insert(15); - assert!(tree.search(15)); - assert_eq!(tree.search(16), false); + macro_rules! test_search { + ($($name:ident: $number_of_children:expr,)*) => { + $( + #[test] + fn $name() { + let mut tree = BTree::new($number_of_children); + tree.insert(10); + tree.insert(20); + tree.insert(30); + tree.insert(5); + tree.insert(6); + tree.insert(7); + tree.insert(11); + tree.insert(12); + tree.insert(15); + assert!(!tree.search(4)); + assert!(tree.search(5)); + assert!(tree.search(6)); + assert!(tree.search(7)); + assert!(!tree.search(8)); + assert!(!tree.search(9)); + assert!(tree.search(10)); + assert!(tree.search(11)); + assert!(tree.search(12)); + assert!(!tree.search(13)); + assert!(!tree.search(14)); + assert!(tree.search(15)); + assert!(!tree.search(16)); + } + )* + } + } + + test_search! 
{ + children_2: 2, + children_3: 3, + children_4: 4, + children_5: 5, + children_10: 10, + children_60: 60, + children_101: 101, } } diff --git a/src/data_structures/binary_search_tree.rs b/src/data_structures/binary_search_tree.rs index 5766d222a44..193fb485408 100644 --- a/src/data_structures/binary_search_tree.rs +++ b/src/data_structures/binary_search_tree.rs @@ -34,7 +34,7 @@ where } } - /// Find a value in this tree. Returns True iff value is in this + /// Find a value in this tree. Returns True if value is in this /// tree, and false otherwise pub fn search(&self, value: &T) -> bool { match &self.value { @@ -71,26 +71,22 @@ where /// Insert a value into the appropriate location in this tree. pub fn insert(&mut self, value: T) { - if self.value.is_none() { - self.value = Some(value); - } else { - match &self.value { - None => (), - Some(key) => { - let target_node = if value < *key { - &mut self.left - } else { - &mut self.right - }; - match target_node { - Some(ref mut node) => { - node.insert(value); - } - None => { - let mut node = BinarySearchTree::new(); - node.insert(value); - *target_node = Some(Box::new(node)); - } + match &self.value { + None => self.value = Some(value), + Some(key) => { + let target_node = if value < *key { + &mut self.left + } else { + &mut self.right + }; + match target_node { + Some(ref mut node) => { + node.insert(value); + } + None => { + let mut node = BinarySearchTree::new(); + node.value = Some(value); + *target_node = Some(Box::new(node)); } } } @@ -101,10 +97,7 @@ where pub fn minimum(&self) -> Option<&T> { match &self.left { Some(node) => node.minimum(), - None => match &self.value { - Some(value) => Some(value), - None => None, - }, + None => self.value.as_ref(), } } @@ -112,10 +105,7 @@ where pub fn maximum(&self) -> Option<&T> { match &self.right { Some(node) => node.maximum(), - None => match &self.value { - Some(value) => Some(value), - None => None, - }, + None => self.value.as_ref(), } } @@ -194,7 +184,7 @@ where 
/// A Fenwick Tree (also known as a Binary Indexed Tree) that supports efficient
/// prefix sum, range sum and point queries, as well as point updates.
///
/// The Fenwick Tree uses **1-based** indexing internally but presents a **0-based** interface to the user.
pub struct FenwickTree<T>
where
    T: Add<Output = T> + AddAssign + Sub<Output = T> + SubAssign + Copy + Default,
{
    /// Internal storage of the Fenwick Tree. The first element (index 0) is unused
    /// to simplify index calculations, so the effective tree size is `data.len() - 1`.
    data: Vec<T>,
}

/// Enum representing the possible errors that can occur during FenwickTree operations.
#[derive(Debug, PartialEq, Eq)]
pub enum FenwickTreeError {
    /// Error indicating that an index was out of the valid range.
    IndexOutOfBounds,
    /// Error indicating that a provided range was invalid (e.g., left > right).
    InvalidRange,
}

impl<T> FenwickTree<T>
where
    T: Add<Output = T> + AddAssign + Sub<Output = T> + SubAssign + Copy + Default,
{
    /// Creates a new Fenwick Tree with a specified capacity.
    ///
    /// The tree will have `capacity + 1` elements, all initialized to the default
    /// value of type `T`. The additional element allows for 1-based indexing internally.
    ///
    /// # Arguments
    ///
    /// * `capacity` - The number of elements the tree can hold (excluding the extra element).
    pub fn with_capacity(capacity: usize) -> Self {
        FenwickTree {
            data: vec![T::default(); capacity + 1],
        }
    }

    /// Updates the tree by adding a value to the element at a specified index.
    ///
    /// # Arguments
    ///
    /// * `index` - The zero-based index where the value should be added.
    /// * `value` - The value to add to the element at the specified index.
    ///
    /// # Errors
    ///
    /// Returns `FenwickTreeError::IndexOutOfBounds` if `index` is not below the capacity.
    pub fn update(&mut self, index: usize, value: T) -> Result<(), FenwickTreeError> {
        if index >= self.data.len() - 1 {
            return Err(FenwickTreeError::IndexOutOfBounds);
        }

        // Switch to the internal 1-based position, then walk upwards through
        // every internal slot whose partial sum covers that position.
        let mut idx = index + 1;
        while idx < self.data.len() {
            self.data[idx] += value;
            idx += lowbit(idx);
        }

        Ok(())
    }

    /// Computes the sum of elements from the start of the tree up to and including `index`.
    ///
    /// # Arguments
    ///
    /// * `index` - The zero-based index up to which the sum should be computed.
    ///
    /// # Errors
    ///
    /// Returns `FenwickTreeError::IndexOutOfBounds` if `index` is not below the capacity.
    pub fn prefix_query(&self, index: usize) -> Result<T, FenwickTreeError> {
        if index >= self.data.len() - 1 {
            return Err(FenwickTreeError::IndexOutOfBounds);
        }

        let mut idx = index + 1;
        let mut result = T::default();
        // Stops once `idx` reaches the unused slot 0, so `lowbit` is only ever
        // called with a non-zero argument here.
        while idx > 0 {
            result += self.data[idx];
            idx -= lowbit(idx);
        }

        Ok(result)
    }

    /// Computes the sum of elements within the inclusive range `[left, right]`.
    ///
    /// The range sum is obtained from two prefix sum queries.
    ///
    /// # Arguments
    ///
    /// * `left` - The zero-based starting index of the range.
    /// * `right` - The zero-based ending index of the range.
    ///
    /// # Errors
    ///
    /// Returns `FenwickTreeError::InvalidRange` if `left > right` or `right` is
    /// not below the capacity.
    pub fn range_query(&self, left: usize, right: usize) -> Result<T, FenwickTreeError> {
        if left > right || right >= self.data.len() - 1 {
            return Err(FenwickTreeError::InvalidRange);
        }

        let right_query = self.prefix_query(right)?;
        // Nothing precedes index 0, so the "sum before `left`" is the default value.
        let left_query = if left == 0 {
            T::default()
        } else {
            self.prefix_query(left - 1)?
        };

        Ok(right_query - left_query)
    }

    /// Retrieves the value at a specific index by isolating it from the prefix sum.
    ///
    /// The value is the prefix sum up to `index` minus the prefix sum up to `index - 1`.
    ///
    /// # Arguments
    ///
    /// * `index` - The zero-based index of the element to retrieve.
    ///
    /// # Errors
    ///
    /// Returns `FenwickTreeError::IndexOutOfBounds` if `index` is not below the capacity.
    pub fn point_query(&self, index: usize) -> Result<T, FenwickTreeError> {
        if index >= self.data.len() - 1 {
            return Err(FenwickTreeError::IndexOutOfBounds);
        }

        let index_query = self.prefix_query(index)?;
        let prev_query = if index == 0 {
            T::default()
        } else {
            self.prefix_query(index - 1)?
        };

        Ok(index_query - prev_query)
    }

    /// Sets the value at a specific index in the tree, updating the structure accordingly.
    ///
    /// The difference between the desired and the current value is applied through `update`.
    ///
    /// # Arguments
    ///
    /// * `index` - The zero-based index of the element to set.
    /// * `value` - The new value to set at the specified index.
    ///
    /// # Errors
    ///
    /// Returns `FenwickTreeError::IndexOutOfBounds` if `index` is not below the capacity.
    pub fn set(&mut self, index: usize, value: T) -> Result<(), FenwickTreeError> {
        let current = self.point_query(index)?;
        self.update(index, value - current)
    }
}

/// Computes the lowest set bit (rightmost `1` bit) of a number.
///
/// It is used to navigate the Fenwick Tree by determining the next index to
/// update or query.
///
/// # Arguments
///
/// * `x` - The input number whose lowest set bit is to be determined.
///
/// # Returns
///
/// The value of the lowest set bit in `x`, or `0` when `x` is `0`.
const fn lowbit(x: usize) -> usize {
    // `wrapping_neg` is the two's-complement negation `!x + 1` without the
    // overflow that the literal form hits for `x == 0` in checked builds.
    x & x.wrapping_neg()
}
11), + Err(FenwickTreeError::IndexOutOfBounds) + ); + + assert_eq!(fenwick_tree.prefix_query(0), Ok(5)); + assert_eq!(fenwick_tree.prefix_query(1), Ok(8)); + assert_eq!(fenwick_tree.prefix_query(2), Ok(6)); + assert_eq!(fenwick_tree.prefix_query(3), Ok(16)); + assert_eq!(fenwick_tree.prefix_query(4), Ok(12)); + assert_eq!(fenwick_tree.prefix_query(5), Ok(12)); + assert_eq!(fenwick_tree.prefix_query(6), Ok(11)); + assert_eq!(fenwick_tree.prefix_query(7), Ok(13)); + assert_eq!(fenwick_tree.prefix_query(8), Ok(10)); + assert_eq!(fenwick_tree.prefix_query(9), Ok(14)); + assert_eq!( + fenwick_tree.prefix_query(10), + Err(FenwickTreeError::IndexOutOfBounds) + ); + + assert_eq!(fenwick_tree.range_query(0, 4), Ok(12)); + assert_eq!(fenwick_tree.range_query(3, 7), Ok(7)); + assert_eq!(fenwick_tree.range_query(2, 5), Ok(4)); + assert_eq!( + fenwick_tree.range_query(4, 3), + Err(FenwickTreeError::InvalidRange) + ); + assert_eq!( + fenwick_tree.range_query(2, 10), + Err(FenwickTreeError::InvalidRange) + ); + + assert_eq!(fenwick_tree.point_query(0), Ok(5)); + assert_eq!(fenwick_tree.point_query(4), Ok(-4)); + assert_eq!(fenwick_tree.point_query(9), Ok(4)); + assert_eq!( + fenwick_tree.point_query(10), + Err(FenwickTreeError::IndexOutOfBounds) + ); + } +} diff --git a/src/data_structures/floyds_algorithm.rs b/src/data_structures/floyds_algorithm.rs new file mode 100644 index 00000000000..b475d07d963 --- /dev/null +++ b/src/data_structures/floyds_algorithm.rs @@ -0,0 +1,95 @@ +// floyds_algorithm.rs +// https://github.com/rust-lang/rust/blob/master/library/alloc/src/collections/linked_list.rs#L113 +// use std::collections::linked_list::LinkedList; +// https://www.reddit.com/r/rust/comments/t7wquc/is_it_possible_to_solve_leetcode_problem141/ + +use crate::data_structures::linked_list::LinkedList; // Import the LinkedList from linked_list.rs + +pub fn detect_cycle(linked_list: &LinkedList) -> Option { + let mut current = linked_list.head; + let mut checkpoint = linked_list.head; + 
let mut steps_until_reset = 1; + let mut times_reset = 0; + + while let Some(node) = current { + steps_until_reset -= 1; + if steps_until_reset == 0 { + checkpoint = current; + times_reset += 1; + steps_until_reset = 1 << times_reset; // 2^times_reset + } + + unsafe { + let node_ptr = node.as_ptr(); + let next = (*node_ptr).next; + current = next; + } + if current == checkpoint { + return Some(linked_list.length as usize); + } + } + + None +} + +pub fn has_cycle(linked_list: &LinkedList) -> bool { + let mut slow = linked_list.head; + let mut fast = linked_list.head; + + while let (Some(slow_node), Some(fast_node)) = (slow, fast) { + unsafe { + slow = slow_node.as_ref().next; + fast = fast_node.as_ref().next; + + if let Some(fast_next) = fast { + // fast = (*fast_next.as_ptr()).next; + fast = fast_next.as_ref().next; + } else { + return false; // If fast reaches the end, there's no cycle + } + + if slow == fast { + return true; // Cycle detected + } + } + } + // println!("{}", flag); + false // No cycle detected +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_cycle_no_cycle() { + let mut linked_list = LinkedList::new(); + linked_list.insert_at_tail(1); + linked_list.insert_at_tail(2); + linked_list.insert_at_tail(3); + + assert!(!has_cycle(&linked_list)); + + assert_eq!(detect_cycle(&linked_list), None); + } + + #[test] + fn test_detect_cycle_with_cycle() { + let mut linked_list = LinkedList::new(); + linked_list.insert_at_tail(1); + linked_list.insert_at_tail(2); + linked_list.insert_at_tail(3); + + // Create a cycle for testing + unsafe { + if let Some(mut tail) = linked_list.tail { + if let Some(head) = linked_list.head { + tail.as_mut().next = Some(head); + } + } + } + + assert!(has_cycle(&linked_list)); + assert_eq!(detect_cycle(&linked_list), Some(3)); + } +} diff --git a/src/data_structures/graph.rs b/src/data_structures/graph.rs new file mode 100644 index 00000000000..2bf3e64046b --- /dev/null +++ b/src/data_structures/graph.rs 
use std::collections::{HashMap, HashSet};
use std::fmt;

/// Error returned when a lookup names a node that is not in the adjacency table.
#[derive(Debug, Clone)]
pub struct NodeNotInGraph;

impl fmt::Display for NodeNotInGraph {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "accessing a node that is not in the graph")
    }
}

// Lets callers propagate the error through `Box<dyn Error>` / `?`.
impl std::error::Error for NodeNotInGraph {}

/// Graph whose edges are stored only on their source node.
pub struct DirectedGraph {
    adjacency_table: HashMap<String, Vec<(String, i32)>>,
}

impl Graph for DirectedGraph {
    fn new() -> DirectedGraph {
        DirectedGraph {
            adjacency_table: HashMap::new(),
        }
    }
    fn adjacency_table_mutable(&mut self) -> &mut HashMap<String, Vec<(String, i32)>> {
        &mut self.adjacency_table
    }
    fn adjacency_table(&self) -> &HashMap<String, Vec<(String, i32)>> {
        &self.adjacency_table
    }
}

/// Graph whose edges are stored on both endpoints.
pub struct UndirectedGraph {
    adjacency_table: HashMap<String, Vec<(String, i32)>>,
}

impl Graph for UndirectedGraph {
    fn new() -> UndirectedGraph {
        UndirectedGraph {
            adjacency_table: HashMap::new(),
        }
    }
    fn adjacency_table_mutable(&mut self) -> &mut HashMap<String, Vec<(String, i32)>> {
        &mut self.adjacency_table
    }
    fn adjacency_table(&self) -> &HashMap<String, Vec<(String, i32)>> {
        &self.adjacency_table
    }

    /// Records `edge` as `(a, b, weight)` in both directions, creating either
    /// node on demand.
    fn add_edge(&mut self, edge: (&str, &str, i32)) {
        let (first, second, weight) = edge;

        // `entry().or_default()` creates the node if needed and appends in a
        // single lookup per endpoint.
        self.adjacency_table
            .entry(first.to_string())
            .or_default()
            .push((second.to_string(), weight));
        self.adjacency_table
            .entry(second.to_string())
            .or_default()
            .push((first.to_string(), weight));
    }
}

/// Weighted graph backed by an adjacency table that maps a node name to its
/// outgoing `(neighbour, weight)` pairs.
pub trait Graph {
    /// Creates an empty graph.
    fn new() -> Self;
    /// Mutable access to the underlying adjacency table.
    fn adjacency_table_mutable(&mut self) -> &mut HashMap<String, Vec<(String, i32)>>;
    /// Shared access to the underlying adjacency table.
    fn adjacency_table(&self) -> &HashMap<String, Vec<(String, i32)>>;

    /// Inserts `node` with no neighbours.
    ///
    /// Returns `true` if the node was added, `false` if it already existed
    /// (its neighbour list is left untouched in that case).
    fn add_node(&mut self, node: &str) -> bool {
        if self.contains(node) {
            return false;
        }
        self.adjacency_table_mutable()
            .insert(node.to_string(), Vec::new());
        true
    }

    /// Records the directed edge `(from, to, weight)`, creating both nodes on
    /// demand.
    fn add_edge(&mut self, edge: (&str, &str, i32)) {
        let (from, to, weight) = edge;

        // The target must exist as a node even if it has no outgoing edges.
        self.add_node(to);
        self.adjacency_table_mutable()
            .entry(from.to_string())
            .or_default()
            .push((to.to_string(), weight));
    }

    /// Returns the `(neighbour, weight)` list of `node`, or `NodeNotInGraph`
    /// when the node is unknown.
    fn neighbours(&self, node: &str) -> Result<&Vec<(String, i32)>, NodeNotInGraph> {
        self.adjacency_table().get(node).ok_or(NodeNotInGraph)
    }

    /// Returns `true` when `node` is a key of the adjacency table.
    fn contains(&self, node: &str) -> bool {
        self.adjacency_table().contains_key(node)
    }

    /// Returns the set of all node names.
    fn nodes(&self) -> HashSet<&String> {
        self.adjacency_table().keys().collect()
    }

    /// Returns every stored `(from, to, weight)` triple; the order follows the
    /// table's iteration order and is therefore unspecified.
    fn edges(&self) -> Vec<(&String, &String, i32)> {
        self.adjacency_table()
            .iter()
            .flat_map(|(from_node, from_node_neighbours)| {
                from_node_neighbours
                    .iter()
                    .map(move |(to_node, weight)| (from_node, to_node, *weight))
            })
            .collect()
    }
}
use std::collections::LinkedList;

/// Separate-chaining hash table: a vector of buckets, each bucket a linked
/// list of `(key, value)` pairs.
pub struct HashTable<K, V> {
    elements: Vec<LinkedList<(K, V)>>,
    count: usize,
}

impl<K: Hashable + std::cmp::PartialEq, V> Default for HashTable<K, V> {
    fn default() -> Self {
        Self::new()
    }
}

/// Keys supply their own hash value.
pub trait Hashable {
    fn hash(&self) -> usize;
}

impl<K: Hashable + std::cmp::PartialEq, V> HashTable<K, V> {
    /// Number of buckets a freshly created table starts with.
    const INITIAL_CAPACITY: usize = 3000;

    /// Creates an empty table with `INITIAL_CAPACITY` buckets.
    pub fn new() -> HashTable<K, V> {
        HashTable {
            elements: Self::empty_buckets(Self::INITIAL_CAPACITY),
            count: 0,
        }
    }

    /// Appends `(key, value)` to the key's bucket.
    ///
    /// Existing entries with an equal key are not replaced; because `search`
    /// returns the first match, the earliest inserted value stays visible.
    /// The bucket vector is doubled first when the entry count has reached
    /// three quarters of the bucket count.
    pub fn insert(&mut self, key: K, value: V) {
        if self.count >= self.elements.len() * 3 / 4 {
            self.resize();
        }
        let index = self.bucket_index(&key);
        self.elements[index].push_back((key, value));
        self.count += 1;
    }

    /// Returns a reference to the first value stored under a key equal to
    /// `key`, or `None` when there is none.
    pub fn search(&self, key: K) -> Option<&V> {
        self.elements[self.bucket_index(&key)]
            .iter()
            .find(|(k, _)| *k == key)
            .map(|(_, v)| v)
    }

    /// Doubles the number of buckets and redistributes every entry.
    fn resize(&mut self) {
        let new_size = self.elements.len() * 2;
        let old_elements = std::mem::replace(&mut self.elements, Self::empty_buckets(new_size));

        // Buckets are drained in order, so entries that land in the same new
        // bucket keep their relative insertion order.
        for (key, value) in old_elements.into_iter().flatten() {
            let new_index = self.bucket_index(&key);
            self.elements[new_index].push_back((key, value));
        }
    }

    /// Maps a key to the index of its bucket in the current bucket vector.
    fn bucket_index(&self, key: &K) -> usize {
        key.hash() % self.elements.len()
    }

    /// Allocates `size` empty buckets with capacity for exactly that request.
    fn empty_buckets(size: usize) -> Vec<LinkedList<(K, V)>> {
        let mut buckets = Vec::with_capacity(size);
        buckets.resize_with(size, LinkedList::new);
        buckets
    }
}
//! A generic heap data structure.
//!
//! This module provides a `Heap` implementation that can function as either a
//! min-heap or a max-heap. It supports common heap operations such as adding,
//! removing, and iterating over elements. The heap can also be created from
//! an unsorted vector and supports custom comparators for flexible sorting
//! behavior.

use std::{cmp::Ord, slice::Iter};

/// A binary heap stored in a `Vec`, ordered by a caller-supplied comparator.
///
/// `comparator(a, b)` returning `true` means `a` must sit above `b`; `|a, b| a < b`
/// therefore yields a min-heap and `|a, b| a > b` a max-heap.
pub struct Heap<T> {
    items: Vec<T>,
    comparator: fn(&T, &T) -> bool,
}

impl<T> Heap<T> {
    /// Creates a new, empty heap with a custom comparator function.
    ///
    /// # Parameters
    /// - `comparator`: A function that defines the heap's ordering.
    pub fn new(comparator: fn(&T, &T) -> bool) -> Self {
        Self {
            items: vec![],
            comparator,
        }
    }

    /// Creates a heap from a vector and a custom comparator function.
    ///
    /// # Parameters
    /// - `items`: A vector of items to be turned into a heap (may be empty).
    /// - `comparator`: A function that defines the heap's ordering.
    pub fn from_vec(items: Vec<T>, comparator: fn(&T, &T) -> bool) -> Self {
        let mut heap = Self { items, comparator };
        heap.build_heap();
        heap
    }

    /// Establishes the heap property over the whole vector.
    ///
    /// Only indices `0..len / 2` have children, so those are sifted down from
    /// the last parent to the root. The range is empty for 0 or 1 elements,
    /// which keeps the index arithmetic from wrapping around.
    fn build_heap(&mut self) {
        for idx in (0..self.len() / 2).rev() {
            self.heapify_down(idx);
        }
    }

    /// Returns the number of elements in the heap.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` if the heap holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Adds a new element to the heap and restores the heap property.
    ///
    /// # Parameters
    /// - `value`: The value to add to the heap.
    pub fn add(&mut self, value: T) {
        self.items.push(value);
        self.heapify_up(self.len() - 1);
    }

    /// Removes and returns the root element, or `None` if the heap is empty.
    pub fn pop(&mut self) -> Option<T> {
        if self.is_empty() {
            return None;
        }
        // `swap_remove` moves the last element into the root slot in O(1).
        let next = Some(self.items.swap_remove(0));
        if !self.is_empty() {
            self.heapify_down(0);
        }
        next
    }

    /// Returns an iterator over the elements in their internal (array) order,
    /// which is not the sorted order in general.
    pub fn iter(&self) -> Iter<'_, T> {
        self.items.iter()
    }

    /// Moves the element at `idx` upwards until its parent outranks it.
    fn heapify_up(&mut self, mut idx: usize) {
        while let Some(pdx) = self.parent_idx(idx) {
            if (self.comparator)(&self.items[idx], &self.items[pdx]) {
                self.items.swap(idx, pdx);
                idx = pdx;
            } else {
                break;
            }
        }
    }

    /// Moves the element at `idx` downwards until it outranks both children.
    fn heapify_down(&mut self, mut idx: usize) {
        while self.children_present(idx) {
            let ldx = self.left_child_idx(idx);
            let rdx = self.right_child_idx(idx);

            // Pick the child that should sit higher; the right child may not exist.
            let cdx = if rdx < self.len() && !(self.comparator)(&self.items[ldx], &self.items[rdx])
            {
                rdx
            } else {
                ldx
            };

            if (self.comparator)(&self.items[cdx], &self.items[idx]) {
                self.items.swap(idx, cdx);
                idx = cdx;
            } else {
                break;
            }
        }
    }

    /// Returns the parent index of `idx`, or `None` for the root.
    fn parent_idx(&self, idx: usize) -> Option<usize> {
        if idx > 0 {
            Some((idx - 1) / 2)
        } else {
            None
        }
    }

    /// Returns `true` if the element at `idx` has at least a left child.
    fn children_present(&self, idx: usize) -> bool {
        self.left_child_idx(idx) < self.len()
    }

    /// Returns the index of the left child of the element at `idx`.
    fn left_child_idx(&self, idx: usize) -> usize {
        idx * 2 + 1
    }

    /// Returns the index of the right child of the element at `idx`.
    fn right_child_idx(&self, idx: usize) -> usize {
        self.left_child_idx(idx) + 1
    }
}

impl<T> Heap<T>
where
    T: Ord,
{
    /// Creates a new, empty min-heap.
    pub fn new_min() -> Heap<T> {
        Self::new(|a, b| a < b)
    }

    /// Creates a new, empty max-heap.
    pub fn new_max() -> Heap<T> {
        Self::new(|a, b| a > b)
    }

    /// Creates a min-heap from an unsorted vector.
    ///
    /// # Parameters
    /// - `items`: A vector of items to be turned into a min-heap.
    pub fn from_vec_min(items: Vec<T>) -> Heap<T> {
        Self::from_vec(items, |a, b| a < b)
    }

    /// Creates a max-heap from an unsorted vector.
    ///
    /// # Parameters
    /// - `items`: A vector of items to be turned into a max-heap.
    pub fn from_vec_max(items: Vec<T>) -> Heap<T> {
        Self::from_vec(items, |a, b| a > b)
    }
}
use std::fmt::{Debug, Display};
use std::ops::{Add, AddAssign, Range};

/// Segment tree over a fixed-length sequence that supports range queries with
/// a caller-supplied `merge` function and range "add `val`" updates.
///
/// Nodes are stored 1-based in `tree` (children of `i` are `2 * i` and
/// `2 * i + 1`). `lazy[i]` holds an addition that still has to be applied to
/// every element below node `i`; `Some(T::default())` marks a node whose
/// descendants carry pending additions.
pub struct LazySegmentTree<T: Debug + Default + Ord + Copy + Display + AddAssign + Add<Output = T>>
{
    len: usize,
    tree: Vec<T>,
    lazy: Vec<Option<T>>,
    merge: fn(T, T) -> T,
}

impl<T: Debug + Default + Ord + Copy + Display + AddAssign + Add<Output = T>> LazySegmentTree<T> {
    /// Builds a tree over `arr`, combining child nodes with `merge`.
    ///
    /// An empty slice yields a tree whose queries always return `None`.
    pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self {
        let len = arr.len();
        let mut sgtr = LazySegmentTree {
            len,
            tree: vec![T::default(); 4 * len],
            lazy: vec![None; 4 * len],
            merge,
        };
        if len != 0 {
            sgtr.build_recursive(arr, 1, 0..len);
        }
        sgtr
    }

    /// Fills node `idx`, which covers `range` of `arr`, and everything below it.
    fn build_recursive(&mut self, arr: &[T], idx: usize, range: Range<usize>) {
        if range.end - range.start == 1 {
            self.tree[idx] = arr[range.start];
        } else {
            let mid = Self::midpoint(&range);
            self.build_recursive(arr, 2 * idx, range.start..mid);
            self.build_recursive(arr, 2 * idx + 1, mid..range.end);
            self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]);
        }
    }

    /// Merges all elements in `range` (half-open).
    ///
    /// Returns `None` when `range` does not overlap the stored elements.
    /// Takes `&mut self` because pending additions are applied on the way down.
    pub fn query(&mut self, range: Range<usize>) -> Option<T> {
        self.query_recursive(1, 0..self.len, &range)
    }

    fn query_recursive(
        &mut self,
        idx: usize,
        element_range: Range<usize>,
        query_range: &Range<usize>,
    ) -> Option<T> {
        if element_range.start >= query_range.end || element_range.end <= query_range.start {
            return None;
        }
        // Bring this subtree up to date before reading from it.
        if self.lazy[idx].is_some() {
            self.propagation(idx, &element_range, T::default());
        }
        if element_range.start >= query_range.start && element_range.end <= query_range.end {
            return Some(self.tree[idx]);
        }
        let mid = Self::midpoint(&element_range);
        let left = self.query_recursive(idx * 2, element_range.start..mid, query_range);
        let right = self.query_recursive(idx * 2 + 1, mid..element_range.end, query_range);
        match (left, right) {
            (Some(l), Some(r)) => Some((self.merge)(l, r)),
            (l, r) => l.or(r),
        }
    }

    /// Adds `val` to every element in `target_range` (half-open).
    pub fn update(&mut self, target_range: Range<usize>, val: T) {
        self.update_recursive(1, 0..self.len, &target_range, val);
    }

    fn update_recursive(
        &mut self,
        idx: usize,
        element_range: Range<usize>,
        target_range: &Range<usize>,
        val: T,
    ) {
        if element_range.start >= target_range.end || element_range.end <= target_range.start {
            return;
        }
        // Leaves are updated in place; they never carry a lazy value.
        if element_range.end - element_range.start == 1 {
            self.tree[idx] += val;
            return;
        }
        // Fully covered inner node: only record the pending addition.
        if element_range.start >= target_range.start && element_range.end <= target_range.end {
            self.lazy[idx] = Some(self.lazy[idx].map_or(val, |pending| pending + val));
            return;
        }
        // Partially covered node: push a real pending addition down first so
        // it is not lost when the marker below overwrites it.
        if matches!(self.lazy[idx], Some(pending) if pending != T::default()) {
            self.propagation(idx, &element_range, T::default());
        }
        let mid = Self::midpoint(&element_range);
        self.update_recursive(idx * 2, element_range.start..mid, target_range, val);
        self.update_recursive(idx * 2 + 1, mid..element_range.end, target_range, val);
        self.tree[idx] = (self.merge)(self.tree[idx * 2], self.tree[idx * 2 + 1]);
        // Mark the node so the next query refreshes it from its descendants.
        self.lazy[idx] = Some(T::default());
    }

    /// Applies `parent_lazy` plus every pending addition found on the way to
    /// all leaves below `idx`, clears those pending additions and recomputes
    /// the merged values bottom-up. Note that this visits the whole subtree.
    fn propagation(&mut self, idx: usize, element_range: &Range<usize>, parent_lazy: T) {
        if element_range.end - element_range.start == 1 {
            self.tree[idx] += parent_lazy;
            return;
        }

        let lazy = self.lazy[idx].unwrap_or_default();
        self.lazy[idx] = None;

        let mid = Self::midpoint(element_range);
        self.propagation(idx * 2, &(element_range.start..mid), parent_lazy + lazy);
        self.propagation(idx * 2 + 1, &(mid..element_range.end), parent_lazy + lazy);
        self.tree[idx] = (self.merge)(self.tree[idx * 2], self.tree[idx * 2 + 1]);
    }

    /// Index at which a node's range is split between its two children.
    fn midpoint(range: &Range<usize>) -> usize {
        range.start + (range.end - range.start) / 2
    }
}
/// Range sums over a fixed sample array, checked against hand-computed totals.
#[test]
fn test_sum_segments() {
    let values = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8];
    let mut sum_seg_tree = LazySegmentTree::from_vec(&values, |x, y| x + y);

    // (queried range, expected sum of the elements inside it)
    let cases = [
        // [-30, 2, -4, 7, (3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8]
        (4..7, 4),
        // [(-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8)]
        (0..values.len(), 7),
        // [(-30, 2), -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
        (0..2, -28),
        // [-30, (2, -4), 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
        (1..3, -2),
        // [-30, (2, -4, 7, 3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8]
        (1..7, 9),
    ];

    for (range, expected) in cases {
        assert_eq!(Some(expected), sum_seg_tree.query(range));
    }
}
update_seg_tree.update(1..5, 3); + + // [-30, 5, -1, 10, (6 -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(7), update_seg_tree.query(4..7)); + // [(-30, 5, -1, 10, 6 , -5, 6, 11, -20, 9, 14, 15, 5, 2, -8)] + assert_eq!(Some(19), update_seg_tree.query(0..vec.len())); + // [(-30, 5), -1, 10, 6, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-25), update_seg_tree.query(0..2)); + // [-30, (5, -1), 10, 6, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(4), update_seg_tree.query(1..3)); + // [-30, (5, -1, 10, 6, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(21), update_seg_tree.query(1..7)); + } + + // Some properties over segment trees: + // When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc. + // When asking for an interval containing a single value, return this value, no matter the merge function + + #[quickcheck] + fn check_overall_interval_min(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, min); + TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len())) + } + + #[quickcheck] + fn check_overall_interval_max(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) + } + + #[quickcheck] + fn check_overall_interval_sum(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) + } + + #[quickcheck] + fn check_single_interval_min(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, min); + for (i, value) in array.into_iter().enumerate() { + let res = seg_tree.query(Range { + start: i, + end: i + 1, + }); + if res != Some(value) { + return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); + } + } + TestResult::passed() + } + + #[quickcheck] + 
fn check_single_interval_max(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + for (i, value) in array.into_iter().enumerate() { + let res = seg_tree.query(Range { + start: i, + end: i + 1, + }); + if res != Some(value) { + return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); + } + } + TestResult::passed() + } + + #[quickcheck] + fn check_single_interval_sum(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + for (i, value) in array.into_iter().enumerate() { + let res = seg_tree.query(Range { + start: i, + end: i + 1, + }); + if res != Some(value) { + return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); + } + } + TestResult::passed() + } +} diff --git a/src/data_structures/linked_list.rs b/src/data_structures/linked_list.rs index 36060db87a9..5f782d82967 100644 --- a/src/data_structures/linked_list.rs +++ b/src/data_structures/linked_list.rs @@ -1,9 +1,10 @@ use std::fmt::{self, Display, Formatter}; +use std::marker::PhantomData; use std::ptr::NonNull; -struct Node { - val: T, - next: Option>>, +pub struct Node { + pub val: T, + pub next: Option>>, prev: Option>>, } @@ -18,9 +19,11 @@ impl Node { } pub struct LinkedList { - length: u32, - start: Option>>, - end: Option>>, + pub length: u32, + pub head: Option>>, + pub tail: Option>>, + // Act like we own boxed nodes since we construct and leak them + marker: PhantomData>>, } impl Default for LinkedList { @@ -33,48 +36,181 @@ impl LinkedList { pub fn new() -> Self { Self { length: 0, - start: None, - end: None, + head: None, + tail: None, + marker: PhantomData, } } - pub fn add(&mut self, obj: T) { + pub fn insert_at_head(&mut self, obj: T) { + let mut node = Box::new(Node::new(obj)); + node.next = self.head; + node.prev = None; + let node_ptr = NonNull::new(Box::into_raw(node)); + match self.head { + None => self.tail = node_ptr, + Some(head_ptr) => unsafe { (*head_ptr.as_ptr()).prev 
= node_ptr }, + } + self.head = node_ptr; + self.length += 1; + } + + pub fn insert_at_tail(&mut self, obj: T) { let mut node = Box::new(Node::new(obj)); - // Since we are adding node at the end, next will always be None node.next = None; - node.prev = self.end; - // Get a pointer to node - let node_ptr = Some(unsafe { NonNull::new_unchecked(Box::into_raw(node)) }); - match self.end { - // This is the case of empty list - None => self.start = node_ptr, - Some(end_ptr) => unsafe { (*end_ptr.as_ptr()).next = node_ptr }, - } - self.end = node_ptr; + node.prev = self.tail; + let node_ptr = NonNull::new(Box::into_raw(node)); + match self.tail { + None => self.head = node_ptr, + Some(tail_ptr) => unsafe { (*tail_ptr.as_ptr()).next = node_ptr }, + } + self.tail = node_ptr; self.length += 1; } - pub fn get(&mut self, index: i32) -> Option<&T> { - self.get_ith_node(self.start, index) + pub fn insert_at_ith(&mut self, index: u32, obj: T) { + if self.length < index { + panic!("Index out of bounds"); + } + + if index == 0 || self.head.is_none() { + self.insert_at_head(obj); + return; + } + + if self.length == index { + self.insert_at_tail(obj); + return; + } + + if let Some(mut ith_node) = self.head { + for _ in 0..index { + unsafe { + match (*ith_node.as_ptr()).next { + None => panic!("Index out of bounds"), + Some(next_ptr) => ith_node = next_ptr, + } + } + } + + let mut node = Box::new(Node::new(obj)); + unsafe { + node.prev = (*ith_node.as_ptr()).prev; + node.next = Some(ith_node); + if let Some(p) = (*ith_node.as_ptr()).prev { + let node_ptr = NonNull::new(Box::into_raw(node)); + println!("{:?}", (*p.as_ptr()).next); + (*p.as_ptr()).next = node_ptr; + (*ith_node.as_ptr()).prev = node_ptr; + self.length += 1; + } + } + } + } + + pub fn delete_head(&mut self) -> Option { + // Safety: head_ptr points to a leaked boxed node managed by this list + // We reassign pointers that pointed to the head node + if self.length == 0 { + return None; + } + + self.head.map(|head_ptr| 
unsafe { + let old_head = Box::from_raw(head_ptr.as_ptr()); + match old_head.next { + Some(mut next_ptr) => next_ptr.as_mut().prev = None, + None => self.tail = None, + } + self.head = old_head.next; + self.length = self.length.checked_add_signed(-1).unwrap_or(0); + old_head.val + }) + // None + } + + pub fn delete_tail(&mut self) -> Option { + // Safety: tail_ptr points to a leaked boxed node managed by this list + // We reassign pointers that pointed to the tail node + self.tail.map(|tail_ptr| unsafe { + let old_tail = Box::from_raw(tail_ptr.as_ptr()); + match old_tail.prev { + Some(mut prev) => prev.as_mut().next = None, + None => self.head = None, + } + self.tail = old_tail.prev; + self.length -= 1; + old_tail.val + }) + } + + pub fn delete_ith(&mut self, index: u32) -> Option { + if self.length <= index { + panic!("Index out of bounds"); + } + + if index == 0 || self.head.is_none() { + return self.delete_head(); + } + + if self.length - 1 == index { + return self.delete_tail(); + } + + if let Some(mut ith_node) = self.head { + for _ in 0..index { + unsafe { + match (*ith_node.as_ptr()).next { + None => panic!("Index out of bounds"), + Some(next_ptr) => ith_node = next_ptr, + } + } + } + + unsafe { + let old_ith = Box::from_raw(ith_node.as_ptr()); + if let Some(mut prev) = old_ith.prev { + prev.as_mut().next = old_ith.next; + } + if let Some(mut next) = old_ith.next { + next.as_mut().prev = old_ith.prev; + } + + self.length -= 1; + Some(old_ith.val) + } + } else { + None + } } - fn get_ith_node(&mut self, node: Option>>, index: i32) -> Option<&T> { + pub fn get(&self, index: i32) -> Option<&T> { + Self::get_ith_node(self.head, index).map(|ptr| unsafe { &(*ptr.as_ptr()).val }) + } + + fn get_ith_node(node: Option>>, index: i32) -> Option>> { match node { None => None, Some(next_ptr) => match index { - 0 => Some(unsafe { &(*next_ptr.as_ptr()).val }), - _ => self.get_ith_node(unsafe { (*next_ptr.as_ptr()).next }, index - 1), + 0 => Some(next_ptr), + _ => 
Self::get_ith_node(unsafe { (*next_ptr.as_ptr()).next }, index - 1), }, } } } +impl Drop for LinkedList { + fn drop(&mut self) { + // Pop items until there are none left + while self.delete_head().is_some() {} + } +} + impl Display for LinkedList where T: Display, { fn fmt(&self, f: &mut Formatter) -> fmt::Result { - match self.start { + match self.head { Some(node) => write!(f, "{}", unsafe { node.as_ref() }), None => Ok(()), } @@ -95,47 +231,282 @@ where #[cfg(test)] mod tests { + use std::convert::TryInto; + use super::LinkedList; + #[test] + fn insert_at_tail_works() { + let mut list = LinkedList::::new(); + let second_value = 2; + list.insert_at_tail(1); + list.insert_at_tail(second_value); + println!("Linked List is {list}"); + match list.get(1) { + Some(val) => assert_eq!(*val, second_value), + None => panic!("Expected to find {second_value} at index 1"), + } + } + #[test] + fn insert_at_head_works() { + let mut list = LinkedList::::new(); + let second_value = 2; + list.insert_at_head(1); + list.insert_at_head(second_value); + println!("Linked List is {list}"); + match list.get(0) { + Some(val) => assert_eq!(*val, second_value), + None => panic!("Expected to find {second_value} at index 0"), + } + } + + #[test] + fn insert_at_ith_can_add_to_tail() { + let mut list = LinkedList::::new(); + let second_value = 2; + list.insert_at_ith(0, 0); + list.insert_at_ith(1, second_value); + println!("Linked List is {list}"); + match list.get(1) { + Some(val) => assert_eq!(*val, second_value), + None => panic!("Expected to find {second_value} at index 1"), + } + } + + #[test] + fn insert_at_ith_can_add_to_head() { + let mut list = LinkedList::::new(); + let second_value = 2; + list.insert_at_ith(0, 1); + list.insert_at_ith(0, second_value); + println!("Linked List is {list}"); + match list.get(0) { + Some(val) => assert_eq!(*val, second_value), + None => panic!("Expected to find {second_value} at index 0"), + } + } + + #[test] + fn insert_at_ith_can_add_to_middle() { + 
let mut list = LinkedList::::new(); + let second_value = 2; + let third_value = 3; + list.insert_at_ith(0, 1); + list.insert_at_ith(1, second_value); + list.insert_at_ith(1, third_value); + println!("Linked List is {list}"); + match list.get(1) { + Some(val) => assert_eq!(*val, third_value), + None => panic!("Expected to find {third_value} at index 1"), + } + + match list.get(2) { + Some(val) => assert_eq!(*val, second_value), + None => panic!("Expected to find {second_value} at index 1"), + } + } + + #[test] + fn insert_at_ith_and_delete_at_ith_in_the_middle() { + // Insert and delete in the middle of the list to ensure pointers are updated correctly + let mut list = LinkedList::::new(); + let first_value = 0; + let second_value = 1; + let third_value = 2; + let fourth_value = 3; + + list.insert_at_ith(0, first_value); + list.insert_at_ith(1, fourth_value); + list.insert_at_ith(1, third_value); + list.insert_at_ith(1, second_value); + + list.delete_ith(2); + list.insert_at_ith(2, third_value); + + for (i, expected) in [ + (0, first_value), + (1, second_value), + (2, third_value), + (3, fourth_value), + ] { + match list.get(i) { + Some(val) => assert_eq!(*val, expected), + None => panic!("Expected to find {expected} at index {i}"), + } + } + } + + #[test] + fn insert_at_ith_and_delete_ith_work_over_many_iterations() { + let mut list = LinkedList::::new(); + for i in 0..100 { + list.insert_at_ith(i, i.try_into().unwrap()); + } + + // Pop even numbers to 50 + for i in 0..50 { + println!("list.length {}", list.length); + if i % 2 == 0 { + list.delete_ith(i); + } + } + + assert_eq!(list.length, 75); + + // Insert even numbers back + for i in 0..50 { + if i % 2 == 0 { + list.insert_at_ith(i, i.try_into().unwrap()); + } + } + + assert_eq!(list.length, 100); + + // Ensure numbers were adderd back and we're able to traverse nodes + if let Some(val) = list.get(78) { + assert_eq!(*val, 78); + } else { + panic!("Expected to find 78 at index 78"); + } + } + + #[test] + fn 
delete_tail_works() { + let mut list = LinkedList::::new(); + let first_value = 1; + let second_value = 2; + list.insert_at_tail(first_value); + list.insert_at_tail(second_value); + match list.delete_tail() { + Some(val) => assert_eq!(val, 2), + None => panic!("Expected to remove {second_value} at tail"), + } + + println!("Linked List is {list}"); + match list.get(0) { + Some(val) => assert_eq!(*val, first_value), + None => panic!("Expected to find {first_value} at index 0"), + } + } + + #[test] + fn delete_head_works() { + let mut list = LinkedList::::new(); + let first_value = 1; + let second_value = 2; + list.insert_at_tail(first_value); + list.insert_at_tail(second_value); + match list.delete_head() { + Some(val) => assert_eq!(val, 1), + None => panic!("Expected to remove {first_value} at head"), + } + + println!("Linked List is {list}"); + match list.get(0) { + Some(val) => assert_eq!(*val, second_value), + None => panic!("Expected to find {second_value} at index 0"), + } + } + + #[test] + fn delete_ith_can_delete_at_tail() { + let mut list = LinkedList::::new(); + let first_value = 1; + let second_value = 2; + list.insert_at_tail(first_value); + list.insert_at_tail(second_value); + match list.delete_ith(1) { + Some(val) => assert_eq!(val, 2), + None => panic!("Expected to remove {second_value} at tail"), + } + + assert_eq!(list.length, 1); + } + + #[test] + fn delete_ith_can_delete_at_head() { + let mut list = LinkedList::::new(); + let first_value = 1; + let second_value = 2; + list.insert_at_tail(first_value); + list.insert_at_tail(second_value); + match list.delete_ith(0) { + Some(val) => assert_eq!(val, 1), + None => panic!("Expected to remove {first_value} at tail"), + } + + assert_eq!(list.length, 1); + } + + #[test] + fn delete_ith_can_delete_in_middle() { + let mut list = LinkedList::::new(); + let first_value = 1; + let second_value = 2; + let third_value = 3; + list.insert_at_tail(first_value); + list.insert_at_tail(second_value); + 
list.insert_at_tail(third_value); + match list.delete_ith(1) { + Some(val) => assert_eq!(val, 2), + None => panic!("Expected to remove {second_value} at tail"), + } + + match list.get(1) { + Some(val) => assert_eq!(*val, third_value), + None => panic!("Expected to find {third_value} at index 1"), + } + } + #[test] fn create_numeric_list() { let mut list = LinkedList::::new(); - list.add(1); - list.add(2); - list.add(3); - println!("Linked List is {}", list); + list.insert_at_tail(1); + list.insert_at_tail(2); + list.insert_at_tail(3); + println!("Linked List is {list}"); assert_eq!(3, list.length); } #[test] fn create_string_list() { let mut list_str = LinkedList::::new(); - list_str.add("A".to_string()); - list_str.add("B".to_string()); - list_str.add("C".to_string()); - println!("Linked List is {}", list_str); + list_str.insert_at_tail("A".to_string()); + list_str.insert_at_tail("B".to_string()); + list_str.insert_at_tail("C".to_string()); + println!("Linked List is {list_str}"); assert_eq!(3, list_str.length); } #[test] fn get_by_index_in_numeric_list() { let mut list = LinkedList::::new(); - list.add(1); - list.add(2); - println!("Linked List is {}", list); + list.insert_at_tail(1); + list.insert_at_tail(2); + println!("Linked List is {list}"); let retrived_item = list.get(1); assert!(retrived_item.is_some()); - assert_eq!(2 as i32, *retrived_item.unwrap()); + assert_eq!(2, *retrived_item.unwrap()); } #[test] fn get_by_index_in_string_list() { let mut list_str = LinkedList::::new(); - list_str.add("A".to_string()); - list_str.add("B".to_string()); - println!("Linked List is {}", list_str); + list_str.insert_at_tail("A".to_string()); + list_str.insert_at_tail("B".to_string()); + println!("Linked List is {list_str}"); let retrived_item = list_str.get(1); assert!(retrived_item.is_some()); assert_eq!("B", *retrived_item.unwrap()); } + + #[test] + #[should_panic(expected = "Index out of bounds")] + fn delete_ith_panics_if_index_equals_length() { + let mut list = 
LinkedList::::new(); + list.insert_at_tail(1); + list.insert_at_tail(2); + // length is 2, so index 2 is out of bounds + list.delete_ith(2); + } } diff --git a/src/data_structures/mod.rs b/src/data_structures/mod.rs index 035deaaf0d8..621ff290360 100644 --- a/src/data_structures/mod.rs +++ b/src/data_structures/mod.rs @@ -1,9 +1,45 @@ +mod avl_tree; mod b_tree; mod binary_search_tree; +mod fenwick_tree; +mod floyds_algorithm; +pub mod graph; +mod hash_table; mod heap; +mod lazy_segment_tree; mod linked_list; +mod probabilistic; +mod queue; +mod range_minimum_query; +mod rb_tree; +mod segment_tree; +mod segment_tree_recursive; +mod stack_using_singly_linked_list; +mod treap; +mod trie; +mod union_find; +mod veb_tree; +pub use self::avl_tree::AVLTree; pub use self::b_tree::BTree; pub use self::binary_search_tree::BinarySearchTree; -pub use self::heap::{Heap, MaxHeap, MinHeap}; +pub use self::fenwick_tree::FenwickTree; +pub use self::floyds_algorithm::{detect_cycle, has_cycle}; +pub use self::graph::DirectedGraph; +pub use self::graph::UndirectedGraph; +pub use self::hash_table::HashTable; +pub use self::heap::Heap; +pub use self::lazy_segment_tree::LazySegmentTree; pub use self::linked_list::LinkedList; +pub use self::probabilistic::bloom_filter; +pub use self::probabilistic::count_min_sketch; +pub use self::queue::Queue; +pub use self::range_minimum_query::RangeMinimumQuery; +pub use self::rb_tree::RBTree; +pub use self::segment_tree::SegmentTree; +pub use self::segment_tree_recursive::SegmentTree as SegmentTreeRecursive; +pub use self::stack_using_singly_linked_list::Stack; +pub use self::treap::Treap; +pub use self::trie::Trie; +pub use self::union_find::UnionFind; +pub use self::veb_tree::VebTree; diff --git a/src/data_structures/probabilistic/bloom_filter.rs b/src/data_structures/probabilistic/bloom_filter.rs new file mode 100644 index 00000000000..b75fe5b1c90 --- /dev/null +++ b/src/data_structures/probabilistic/bloom_filter.rs @@ -0,0 +1,269 @@ +use 
std::collections::hash_map::{DefaultHasher, RandomState}; +use std::hash::{BuildHasher, Hash, Hasher}; + +/// A Bloom Filter is a probabilistic data structure testing whether an element belongs to a set or not +/// Therefore, its contract looks very close to the one of a set, for example a `HashSet` +pub trait BloomFilter { + fn insert(&mut self, item: Item); + fn contains(&self, item: &Item) -> bool; +} + +/// What is the point of using a Bloom Filter if it acts like a Set? +/// Let's imagine we have a huge number of elements to store (like an unbounded data stream): a Set storing every element will most likely take up too much space at some point. +/// Like other probabilistic data structures such as Count-min Sketch, the goal of a Bloom Filter is to trade off exactitude for constant space. +/// We won't have a strictly exact result of whether the value belongs to the set, but we'll use constant space instead + +/// Let's start with the basic idea behind the implementation +/// Let's start by trying to make a `HashSet` with constant space: +/// Instead of storing every element and growing the set infinitely, let's use a vector with constant capacity `CAPACITY` +/// Each element of this vector will be a boolean. 
+/// When a new element is inserted, we hash its value and set the element at index `hash(item) % CAPACITY` to `true` +/// When looking for an item, we hash its value and retrieve the boolean at index `hash(item) % CAPACITY` +/// If it's `false` it's absolutely sure the item isn't present +/// If it's `true` the item may be present, or maybe another one produces the same hash +#[derive(Debug)] +struct BasicBloomFilter { + vec: [bool; CAPACITY], +} + +impl Default for BasicBloomFilter { + fn default() -> Self { + Self { + vec: [false; CAPACITY], + } + } +} + +impl BloomFilter for BasicBloomFilter { + fn insert(&mut self, item: Item) { + let mut hasher = DefaultHasher::new(); + item.hash(&mut hasher); + let idx = (hasher.finish() % CAPACITY as u64) as usize; + self.vec[idx] = true; + } + + fn contains(&self, item: &Item) -> bool { + let mut hasher = DefaultHasher::new(); + item.hash(&mut hasher); + let idx = (hasher.finish() % CAPACITY as u64) as usize; + self.vec[idx] + } +} + +/// Can we improve it? Certainly, in different ways. +/// One pattern you may have identified here is that we use a "binary array" (a vector of binary values) +/// For instance, we might have `[0,1,0,0,1,0]`, which is actually the binary representation of 18 +/// This means we can immediately replace our `Vec<bool>` by an actual number +/// What would it mean to set a `1` at index `i`? +/// Imagine a `CAPACITY` of `6`. The initial value for our mask is `000000`. +/// We want to store `"Bloom"`. Its hash modulo `CAPACITY` is `5`. Which means we need to set `1` at the last index. 
+/// It can be performed by doing `000000 | 000001` +/// Meaning we can hash the item value, use a modulo to find the index, and do a bitwise `or` between the current number and the bit mask for that index +#[allow(dead_code)] +#[derive(Debug, Default)] +struct SingleBinaryBloomFilter { + fingerprint: u128, // let's use 128 bits, the equivalent of using CAPACITY=128 in the previous example +} + +/// Given a value and a hash function, compute the hash and return the bit mask +fn mask_128(hasher: &mut DefaultHasher, item: T) -> u128 { + item.hash(hasher); + let idx = (hasher.finish() % 128) as u32; + // idx is where we want to put a 1, let's convert this into a proper binary mask + 2_u128.pow(idx) +} + +impl BloomFilter for SingleBinaryBloomFilter { + fn insert(&mut self, item: T) { + self.fingerprint |= mask_128(&mut DefaultHasher::new(), &item); + } + + fn contains(&self, item: &T) -> bool { + (self.fingerprint & mask_128(&mut DefaultHasher::new(), item)) > 0 + } +} + +/// We may have made some progress in terms of CPU efficiency, using binary operators. +/// But we might still run into a lot of collisions with a single 128-bit number. +/// Can we use greater numbers then? Currently, our implementation is limited to 128 bits. +/// +/// Should we go back to using an array, then? +/// We could! But instead of using `Vec<bool>` we could use `Vec<u8>`. +/// Each `u8` can act as a mask as we've done before, and is actually 1 byte in memory (same as a boolean!) +/// That'd allow us to go over 128 bits, but would divide by 8 the memory footprint. +/// That's one thing, and will involve dividing / shifting by 8 in different places. +/// +/// But still, can we reduce the collisions even further? +/// +/// As we did with count-min-sketch, we could use multiple hash functions. 
+/// When inserting a value, we compute its hash with every hash function (`hash_i`) and perform the same operation as above (the OR with `fingerprint`) +/// Then when looking for a value, if **ANY** of the tests (`hash` then `AND`) returns 0 then this means the value is missing from the set, otherwise it would have returned 1 +/// If it returns `1`, it **may** be that the item is present, but could also be a collision +/// This is what a Bloom Filter is about: returning `false` means the value is necessarily absent, and returning true means it may be present +pub struct MultiBinaryBloomFilter { + filter_size: usize, + bytes: Vec, + hash_builders: Vec, +} + +impl MultiBinaryBloomFilter { + pub fn with_dimensions(filter_size: usize, hash_count: usize) -> Self { + let bytes_count = filter_size / 8 + usize::from(filter_size % 8 > 0); // we need 8 times less entries in the array, since we are using bytes. Careful that we have at least one element though + Self { + filter_size, + bytes: vec![0; bytes_count], + hash_builders: vec![RandomState::new(); hash_count], + } + } + + pub fn from_estimate( + estimated_count_of_items: usize, + max_false_positive_probability: f64, + ) -> Self { + // Check Wikipedia for these formulae + let optimal_filter_size = (-(estimated_count_of_items as f64) + * max_false_positive_probability.ln() + / (2.0_f64.ln().powi(2))) + .ceil() as usize; + let optimal_hash_count = ((optimal_filter_size as f64 / estimated_count_of_items as f64) + * 2.0_f64.ln()) + .ceil() as usize; + Self::with_dimensions(optimal_filter_size, optimal_hash_count) + } +} + +impl BloomFilter for MultiBinaryBloomFilter { + fn insert(&mut self, item: Item) { + for builder in &self.hash_builders { + let mut hasher = builder.build_hasher(); + item.hash(&mut hasher); + let hash = builder.hash_one(&item); + let index = hash % self.filter_size as u64; + let byte_index = index as usize / 8; // this is this byte that we need to modify + let bit_index = (index % 8) as u8; // we cannot 
only OR with value 1 this time, since we have 8 bits + self.bytes[byte_index] |= 1 << bit_index; + } + } + + fn contains(&self, item: &Item) -> bool { + for builder in &self.hash_builders { + let mut hasher = builder.build_hasher(); + item.hash(&mut hasher); + let hash = builder.hash_one(item); + let index = hash % self.filter_size as u64; + let byte_index = index as usize / 8; // this is this byte that we need to modify + let bit_index = (index % 8) as u8; // we cannot only OR with value 1 this time, since we have 8 bits + if self.bytes[byte_index] & (1 << bit_index) == 0 { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use crate::data_structures::probabilistic::bloom_filter::{ + BasicBloomFilter, BloomFilter, MultiBinaryBloomFilter, SingleBinaryBloomFilter, + }; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + use std::collections::HashSet; + + #[derive(Debug, Clone)] + struct TestSet { + to_insert: HashSet, + to_test: Vec, + } + + impl Arbitrary for TestSet { + fn arbitrary(g: &mut Gen) -> Self { + let mut qty = usize::arbitrary(g) % 5_000; + if qty < 50 { + qty += 50; // won't be perfectly uniformly distributed + } + let mut to_insert = HashSet::with_capacity(qty); + let mut to_test = Vec::with_capacity(qty); + for _ in 0..(qty) { + to_insert.insert(i32::arbitrary(g)); + to_test.push(i32::arbitrary(g)); + } + TestSet { to_insert, to_test } + } + } + + #[quickcheck] + fn basic_filter_must_not_return_false_negative(TestSet { to_insert, to_test }: TestSet) { + let mut basic_filter = BasicBloomFilter::<10_000>::default(); + for item in &to_insert { + basic_filter.insert(*item); + } + for other in to_test { + if !basic_filter.contains(&other) { + assert!(!to_insert.contains(&other)) + } + } + } + + #[quickcheck] + fn binary_filter_must_not_return_false_negative(TestSet { to_insert, to_test }: TestSet) { + let mut binary_filter = SingleBinaryBloomFilter::default(); + for item in &to_insert { + 
binary_filter.insert(*item); + } + for other in to_test { + if !binary_filter.contains(&other) { + assert!(!to_insert.contains(&other)) + } + } + } + + #[quickcheck] + fn a_basic_filter_of_capacity_128_is_the_same_as_a_binary_filter( + TestSet { to_insert, to_test }: TestSet, + ) { + let mut basic_filter = BasicBloomFilter::<128>::default(); // change 128 to anything else here, and the test won't pass + let mut binary_filter = SingleBinaryBloomFilter::default(); + for item in &to_insert { + basic_filter.insert(*item); + binary_filter.insert(*item); + } + for other in to_test { + // Since we use the same DefaultHasher::new(), and both have size 128, we should have exactly the same results + assert_eq!( + basic_filter.contains(&other), + binary_filter.contains(&other) + ); + } + } + + const FALSE_POSITIVE_MAX: f64 = 0.05; + + #[quickcheck] + fn a_multi_binary_bloom_filter_must_not_return_false_negatives( + TestSet { to_insert, to_test }: TestSet, + ) { + let n = to_insert.len(); + if n == 0 { + // avoid dividing by 0 when adjusting the size + return; + } + // See Wikipedia for those formulae + let mut binary_filter = MultiBinaryBloomFilter::from_estimate(n, FALSE_POSITIVE_MAX); + for item in &to_insert { + binary_filter.insert(*item); + } + let tests = to_test.len(); + let mut false_positives = 0; + for other in to_test { + if !binary_filter.contains(&other) { + assert!(!to_insert.contains(&other)) + } else if !to_insert.contains(&other) { + // false positive + false_positives += 1; + } + } + let fp_rate = false_positives as f64 / tests as f64; + assert!(fp_rate < 1.0); // This isn't really a test, but so that you have the `fp_rate` variable to print out, or evaluate + } +} diff --git a/src/data_structures/probabilistic/count_min_sketch.rs b/src/data_structures/probabilistic/count_min_sketch.rs new file mode 100644 index 00000000000..0aec3bff577 --- /dev/null +++ b/src/data_structures/probabilistic/count_min_sketch.rs @@ -0,0 +1,247 @@ +use 
std::collections::hash_map::RandomState; +use std::fmt::{Debug, Formatter}; +use std::hash::{BuildHasher, Hash}; + +/// A probabilistic data structure holding an approximate count for diverse items efficiently (using constant space) +/// +/// Let's imagine we want to count items from an incoming (unbounded) data stream +/// One way to do this would be to hold a frequency hashmap, counting element hashes +/// This works extremely well, but unfortunately would require a lot of memory if we have a huge diversity of incoming items in the data stream +/// +/// CountMinSketch aims at solving this problem, trading off the exact count for an approximate one, but getting from potentially unbounded space complexity to constant complexity +/// See the implementation below for more details +/// +/// Here is the definition of the different allowed operations on a CountMinSketch: +/// * increment the count of an item +/// * retrieve the count of an item +pub trait CountMinSketch { + type Item; + + fn increment(&mut self, item: Self::Item); + fn increment_by(&mut self, item: Self::Item, count: usize); + fn get_count(&self, item: Self::Item) -> usize; +} + +/// The common implementation of a CountMinSketch +/// Holding a DEPTH x WIDTH matrix of counts +/// +/// The idea behind the implementation is the following: +/// Let's start from our problem statement above. We have a frequency map of counts, and want to go reduce its space complexity +/// The immediate way to do this would be to use a Vector with a fixed size, let this size be `WIDTH` +/// We will be holding the count of each item `item` in the Vector, at index `i = hash(item) % WIDTH` where `hash` is a hash function: `item -> usize` +/// We now have constant space. +/// +/// The problem though is that we'll potentially run into a lot of collisions. 
+/// Taking an extreme example, if `WIDTH = 1`, all items will have the same count, which is the sum of counts of every items +/// We could reduce the amount of collisions by using a bigger `WIDTH` but this wouldn't be way more efficient than the "big" frequency map +/// How do we improve the solution, but still keeping constant space? +/// +/// The idea is to use not just one vector, but multiple (`DEPTH`) ones and attach different `hash` functions to each vector +/// This would lead to the following data structure: +/// <- WIDTH = 5 -> +/// D hash1: [0, 0, 0, 0, 0] +/// E hash2: [0, 0, 0, 0, 0] +/// P hash3: [0, 0, 0, 0, 0] +/// T hash4: [0, 0, 0, 0, 0] +/// H hash5: [0, 0, 0, 0, 0] +/// = hash6: [0, 0, 0, 0, 0] +/// 7 hash7: [0, 0, 0, 0, 0] +/// Every hash function must return a different value for the same item. +/// Let's say we hash "TEST" and: +/// hash1("TEST") = 42 => idx = 2 +/// hash2("TEST") = 26 => idx = 1 +/// hash3("TEST") = 10 => idx = 0 +/// hash4("TEST") = 33 => idx = 3 +/// hash5("TEST") = 54 => idx = 4 +/// hash6("TEST") = 11 => idx = 1 +/// hash7("TEST") = 50 => idx = 0 +/// This would lead our structure to become: +/// <- WIDTH = 5 -> +/// D hash1: [0, 0, 1, 0, 0] +/// E hash2: [0, 1, 0, 0, 0] +/// P hash3: [1, 0, 0, 0, 0] +/// T hash4: [0, 0, 0, 1, 0] +/// H hash5: [0, 0, 0, 0, 1] +/// = hash6: [0, 1, 0, 0, 0] +/// 7 hash7: [1, 0, 0, 0, 0] +/// +/// Now say we hash "OTHER" and: +/// hash1("OTHER") = 23 => idx = 3 +/// hash2("OTHER") = 11 => idx = 1 +/// hash3("OTHER") = 52 => idx = 2 +/// hash4("OTHER") = 25 => idx = 0 +/// hash5("OTHER") = 31 => idx = 1 +/// hash6("OTHER") = 24 => idx = 4 +/// hash7("OTHER") = 30 => idx = 0 +/// Leading our data structure to become: +/// <- WIDTH = 5 -> +/// D hash1: [0, 0, 1, 1, 0] +/// E hash2: [0, 2, 0, 0, 0] +/// P hash3: [1, 0, 1, 0, 0] +/// T hash4: [1, 0, 0, 1, 0] +/// H hash5: [0, 1, 0, 0, 1] +/// = hash6: [0, 1, 0, 0, 1] +/// 7 hash7: [2, 0, 0, 0, 0] +/// +/// We actually can witness some collisions 
(invalid counts of `2` above in some rows). +/// This means that if we have to return the count for "TEST", we'd actually fetch counts from every row and return the minimum value +/// +/// This could potentially be overestimated if we have a huge number of entries and a lot of collisions. +/// But an interesting property is that the count we return for "TEST" cannot be underestimated +pub struct HashCountMinSketch { + phantom: std::marker::PhantomData, // just a marker for Item to be used + counts: [[usize; WIDTH]; DEPTH], + hashers: [RandomState; DEPTH], +} + +impl Debug + for HashCountMinSketch +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Item").field("vecs", &self.counts).finish() + } +} + +impl Default + for HashCountMinSketch +{ + fn default() -> Self { + let hashers = std::array::from_fn(|_| RandomState::new()); + + Self { + phantom: std::marker::PhantomData, + counts: [[0; WIDTH]; DEPTH], + hashers, + } + } +} + +impl CountMinSketch + for HashCountMinSketch +{ + type Item = Item; + + fn increment(&mut self, item: Self::Item) { + self.increment_by(item, 1) + } + + fn increment_by(&mut self, item: Self::Item, count: usize) { + for (row, r) in self.hashers.iter_mut().enumerate() { + let mut h = r.build_hasher(); + item.hash(&mut h); + let hashed = r.hash_one(&item); + let col = (hashed % WIDTH as u64) as usize; + self.counts[row][col] += count; + } + } + + fn get_count(&self, item: Self::Item) -> usize { + self.hashers + .iter() + .enumerate() + .map(|(row, r)| { + let mut h = r.build_hasher(); + item.hash(&mut h); + let hashed = r.hash_one(&item); + let col = (hashed % WIDTH as u64) as usize; + self.counts[row][col] + }) + .min() + .unwrap() + } +} + +#[cfg(test)] +mod tests { + use crate::data_structures::probabilistic::count_min_sketch::{ + CountMinSketch, HashCountMinSketch, + }; + use quickcheck::{Arbitrary, Gen}; + use std::collections::HashSet; + + #[test] + fn hash_functions_should_hash_differently() { + let mut 
sketch: HashCountMinSketch<&str, 50, 50> = HashCountMinSketch::default(); // use a big DEPTH + sketch.increment("something"); + // We want to check that our hash functions actually produce different results, so we'll store the indices where we encounter a count=1 in a set + let mut indices_of_ones: HashSet = HashSet::default(); + for counts in sketch.counts { + let ones = counts + .into_iter() + .enumerate() + .filter_map(|(idx, count)| (count == 1).then_some(idx)) + .collect::>(); + assert_eq!(1, ones.len()); + indices_of_ones.insert(ones[0]); + } + // Given the parameters (WIDTH = 50, DEPTH = 50) it's extremely unlikely that all hash functions hash to the same index + assert!(indices_of_ones.len() > 1); // but we want to avoid a bug where all hash functions would produce the same hash (or hash to the same index) + } + + #[test] + fn inspect_counts() { + let mut sketch: HashCountMinSketch<&str, 5, 7> = HashCountMinSketch::default(); + sketch.increment("test"); + // Inspect internal state: + for counts in sketch.counts { + let zeroes = counts.iter().filter(|count| **count == 0).count(); + assert_eq!(4, zeroes); + let ones = counts.iter().filter(|count| **count == 1).count(); + assert_eq!(1, ones); + } + sketch.increment("test"); + for counts in sketch.counts { + let zeroes = counts.iter().filter(|count| **count == 0).count(); + assert_eq!(4, zeroes); + let twos = counts.iter().filter(|count| **count == 2).count(); + assert_eq!(1, twos); + } + + // This one is actually deterministic + assert_eq!(2, sketch.get_count("test")); + } + + #[derive(Debug, Clone, Eq, PartialEq, Hash)] + struct TestItem { + item: String, + count: usize, + } + + const MAX_STR_LEN: u8 = 30; + const MAX_COUNT: usize = 20; + + impl Arbitrary for TestItem { + fn arbitrary(g: &mut Gen) -> Self { + let str_len = u8::arbitrary(g) % MAX_STR_LEN; + let mut str = String::with_capacity(str_len as usize); + for _ in 0..str_len { + str.push(char::arbitrary(g)); + } + let count = usize::arbitrary(g) % 
MAX_COUNT; + TestItem { item: str, count } + } + } + + #[quickcheck_macros::quickcheck] + fn must_not_understimate_count(test_items: Vec) { + let test_items = test_items.into_iter().collect::>(); // remove duplicated (would lead to weird counts) + let n = test_items.len(); + let mut sketch: HashCountMinSketch = HashCountMinSketch::default(); + let mut exact_count = 0; + for TestItem { item, count } in &test_items { + sketch.increment_by(item.clone(), *count); + } + for TestItem { item, count } in test_items { + let stored_count = sketch.get_count(item); + assert!(stored_count >= count); + if count == stored_count { + exact_count += 1; + } + } + if n > 20 { + // if n is too short, the stat isn't really relevant + let exact_ratio = exact_count as f64 / n as f64; + assert!(exact_ratio > 0.7); // the proof is quite hard, but this should be OK + } + } +} diff --git a/src/data_structures/probabilistic/mod.rs b/src/data_structures/probabilistic/mod.rs new file mode 100644 index 00000000000..de55027f15f --- /dev/null +++ b/src/data_structures/probabilistic/mod.rs @@ -0,0 +1,2 @@ +pub mod bloom_filter; +pub mod count_min_sketch; diff --git a/src/data_structures/queue.rs b/src/data_structures/queue.rs new file mode 100644 index 00000000000..a0299155490 --- /dev/null +++ b/src/data_structures/queue.rs @@ -0,0 +1,91 @@ +//! This module provides a generic `Queue` data structure, implemented using +//! Rust's `LinkedList` from the standard library. The queue follows the FIFO +//! (First-In-First-Out) principle, where elements are added to the back of +//! the queue and removed from the front. 
+ +use std::collections::LinkedList; + +#[derive(Debug)] +pub struct Queue { + elements: LinkedList, +} + +impl Queue { + // Creates a new empty Queue + pub fn new() -> Queue { + Queue { + elements: LinkedList::new(), + } + } + + // Adds an element to the back of the queue + pub fn enqueue(&mut self, value: T) { + self.elements.push_back(value) + } + + // Removes and returns the front element from the queue, or None if empty + pub fn dequeue(&mut self) -> Option { + self.elements.pop_front() + } + + // Returns a reference to the front element of the queue, or None if empty + pub fn peek_front(&self) -> Option<&T> { + self.elements.front() + } + + // Returns a reference to the back element of the queue, or None if empty + pub fn peek_back(&self) -> Option<&T> { + self.elements.back() + } + + // Returns the number of elements in the queue + pub fn len(&self) -> usize { + self.elements.len() + } + + // Checks if the queue is empty + pub fn is_empty(&self) -> bool { + self.elements.is_empty() + } + + // Clears all elements from the queue + pub fn drain(&mut self) { + self.elements.clear(); + } +} + +// Implementing the Default trait for Queue +impl Default for Queue { + fn default() -> Queue { + Queue::new() + } +} + +#[cfg(test)] +mod tests { + use super::Queue; + + #[test] + fn test_queue_functionality() { + let mut queue: Queue = Queue::default(); + + assert!(queue.is_empty()); + queue.enqueue(8); + queue.enqueue(16); + assert!(!queue.is_empty()); + assert_eq!(queue.len(), 2); + + assert_eq!(queue.peek_front(), Some(&8)); + assert_eq!(queue.peek_back(), Some(&16)); + + assert_eq!(queue.dequeue(), Some(8)); + assert_eq!(queue.len(), 1); + assert_eq!(queue.peek_front(), Some(&16)); + assert_eq!(queue.peek_back(), Some(&16)); + + queue.drain(); + assert!(queue.is_empty()); + assert_eq!(queue.len(), 0); + assert_eq!(queue.dequeue(), None); + } +} diff --git a/src/data_structures/range_minimum_query.rs b/src/data_structures/range_minimum_query.rs new file mode 100644 
index 00000000000..8bb74a7a1fe --- /dev/null +++ b/src/data_structures/range_minimum_query.rs @@ -0,0 +1,194 @@ +//! Range Minimum Query (RMQ) Implementation +//! +//! This module provides an efficient implementation of a Range Minimum Query data structure using a +//! sparse table approach. It allows for quick retrieval of the minimum value within a specified subdata +//! of a given data after an initial preprocessing phase. +//! +//! The RMQ is particularly useful in scenarios requiring multiple queries on static data, as it +//! allows querying in constant time after an O(n log(n)) preprocessing time. +//! +//! References: [Wikipedia](https://en.wikipedia.org/wiki/Range_minimum_query) + +use std::cmp::PartialOrd; + +/// Custom error type for invalid range queries. +#[derive(Debug, PartialEq, Eq)] +pub enum RangeError { + /// Indicates that the provided range is invalid (start index is not less than end index). + InvalidRange, + /// Indicates that one or more indices are out of bounds for the data. + IndexOutOfBound, +} + +/// A data structure for efficiently answering range minimum queries on static data. +pub struct RangeMinimumQuery { + /// The original input data on which range queries are performed. + data: Vec, + /// The sparse table for storing preprocessed range minimum information. Each entry + /// contains the index of the minimum element in the range starting at `j` and having a length of `2^i`. + sparse_table: Vec>, +} + +impl RangeMinimumQuery { + /// Creates a new `RangeMinimumQuery` instance with the provided input data. + /// + /// # Arguments + /// + /// * `input` - A slice of elements of type `T` that implement `PartialOrd` and `Copy`. + /// + /// # Returns + /// + /// A `RangeMinimumQuery` instance that can be used to perform range minimum queries. 
+ pub fn new(input: &[T]) -> RangeMinimumQuery { + RangeMinimumQuery { + data: input.to_vec(), + sparse_table: build_sparse_table(input), + } + } + + /// Retrieves the minimum value in the specified range [start, end). + /// + /// # Arguments + /// + /// * `start` - The starting index of the range (inclusive). + /// * `end` - The ending index of the range (exclusive). + /// + /// # Returns + /// + /// * `Ok(T)` - The minimum value found in the specified range. + /// * `Err(RangeError)` - An error indicating the reason for failure, such as an invalid range + /// or indices out of bounds. + pub fn get_range_min(&self, start: usize, end: usize) -> Result { + // Validate range + if start >= end { + return Err(RangeError::InvalidRange); + } + if start >= self.data.len() || end > self.data.len() { + return Err(RangeError::IndexOutOfBound); + } + + // Calculate the log length and the index for the sparse table + let log_len = (end - start).ilog2() as usize; + let idx: usize = end - (1 << log_len); + + // Retrieve the indices of the minimum values from the sparse table + let min_idx_start = self.sparse_table[log_len][start]; + let min_idx_end = self.sparse_table[log_len][idx]; + + // Compare the values at the retrieved indices and return the minimum + if self.data[min_idx_start] < self.data[min_idx_end] { + Ok(self.data[min_idx_start]) + } else { + Ok(self.data[min_idx_end]) + } + } +} + +/// Builds a sparse table for the provided data to support range minimum queries. +/// +/// # Arguments +/// +/// * `data` - A slice of elements of type `T` that implement `PartialOrd`. +/// +/// # Returns +/// +/// A 2D vector representing the sparse table, where each entry contains the index of the minimum +/// element in the range defined by the starting index and the power of two lengths. 
+fn build_sparse_table(data: &[T]) -> Vec> { + let mut sparse_table: Vec> = vec![(0..data.len()).collect()]; + let len = data.len(); + + // Fill the sparse table + for log_len in 1..=len.ilog2() { + let mut row = Vec::new(); + for idx in 0..=len - (1 << log_len) { + let min_idx_start = sparse_table[sparse_table.len() - 1][idx]; + let min_idx_end = sparse_table[sparse_table.len() - 1][idx + (1 << (log_len - 1))]; + if data[min_idx_start] < data[min_idx_end] { + row.push(min_idx_start); + } else { + row.push(min_idx_end); + } + } + sparse_table.push(row); + } + + sparse_table +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_build_sparse_table { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (data, expected) = $inputs; + assert_eq!(build_sparse_table(&data), expected); + } + )* + } + } + + test_build_sparse_table! { + small: ( + [1, 6, 3], + vec![ + vec![0, 1, 2], + vec![0, 2] + ] + ), + medium: ( + [1, 3, 6, 123, 7, 235, 3, -4, 6, 2], + vec![ + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + vec![0, 1, 2, 4, 4, 6, 7, 7, 9], + vec![0, 1, 2, 6, 7, 7, 7], + vec![7, 7, 7] + ] + ), + large: ( + [20, 13, -13, 2, 3634, -2, 56, 3, 67, 8, 23, 0, -23, 1, 5, 85, 3, 24, 5, -10, 3, 4, 20], + vec![ + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], + vec![1, 2, 2, 3, 5, 5, 7, 7, 9, 9, 11, 12, 12, 13, 14, 16, 16, 18, 19, 19, 20, 21], + vec![2, 2, 2, 5, 5, 5, 7, 7, 11, 12, 12, 12, 12, 13, 16, 16, 19, 19, 19, 19], + vec![2, 2, 2, 5, 5, 12, 12, 12, 12, 12, 12, 12, 12, 19, 19, 19], + vec![12, 12, 12, 12, 12, 12, 12, 12] + ] + ), + } + + #[test] + fn simple_query_tests() { + let rmq = RangeMinimumQuery::new(&[1, 3, 6, 123, 7, 235, 3, -4, 6, 2]); + + assert_eq!(rmq.get_range_min(1, 6), Ok(3)); + assert_eq!(rmq.get_range_min(0, 10), Ok(-4)); + assert_eq!(rmq.get_range_min(8, 9), Ok(6)); + assert_eq!(rmq.get_range_min(4, 3), Err(RangeError::InvalidRange)); + assert_eq!(rmq.get_range_min(0, 1000), 
Err(RangeError::IndexOutOfBound)); + assert_eq!( + rmq.get_range_min(1000, 1001), + Err(RangeError::IndexOutOfBound) + ); + } + + #[test] + fn float_query_tests() { + let rmq = RangeMinimumQuery::new(&[0.4, -2.3, 0.0, 234.22, 12.2, -3.0]); + + assert_eq!(rmq.get_range_min(0, 6), Ok(-3.0)); + assert_eq!(rmq.get_range_min(0, 4), Ok(-2.3)); + assert_eq!(rmq.get_range_min(3, 5), Ok(12.2)); + assert_eq!(rmq.get_range_min(2, 3), Ok(0.0)); + assert_eq!(rmq.get_range_min(4, 3), Err(RangeError::InvalidRange)); + assert_eq!(rmq.get_range_min(0, 1000), Err(RangeError::IndexOutOfBound)); + assert_eq!( + rmq.get_range_min(1000, 1001), + Err(RangeError::IndexOutOfBound) + ); + } +} diff --git a/src/data_structures/rb_tree.rs b/src/data_structures/rb_tree.rs new file mode 100644 index 00000000000..3465ad5d4d3 --- /dev/null +++ b/src/data_structures/rb_tree.rs @@ -0,0 +1,652 @@ +use std::boxed::Box; +use std::cmp::{Ord, Ordering}; +use std::iter::Iterator; +use std::ptr::null_mut; + +#[derive(Copy, Clone)] +enum Color { + Red, + Black, +} + +pub struct RBNode { + key: K, + value: V, + color: Color, + parent: *mut RBNode, + left: *mut RBNode, + right: *mut RBNode, +} + +impl RBNode { + fn new(key: K, value: V) -> RBNode { + RBNode { + key, + value, + color: Color::Red, + parent: null_mut(), + left: null_mut(), + right: null_mut(), + } + } +} + +pub struct RBTree { + root: *mut RBNode, +} + +impl Default for RBTree { + fn default() -> Self { + Self::new() + } +} + +impl RBTree { + pub fn new() -> RBTree { + RBTree:: { root: null_mut() } + } + + pub fn find(&self, key: &K) -> Option<&V> { + unsafe { + let mut node = self.root; + while !node.is_null() { + node = match (*node).key.cmp(key) { + Ordering::Less => (*node).right, + Ordering::Equal => return Some(&(*node).value), + Ordering::Greater => (*node).left, + } + } + } + None + } + + pub fn insert(&mut self, key: K, value: V) { + unsafe { + let mut parent = null_mut(); + let mut node = self.root; + while !node.is_null() { + parent 
= node; + node = match (*node).key.cmp(&key) { + Ordering::Less => (*node).right, + Ordering::Equal => { + (*node).value = value; + return; + } + Ordering::Greater => (*node).left, + } + } + node = Box::into_raw(Box::new(RBNode::new(key, value))); + if !parent.is_null() { + if (*node).key < (*parent).key { + (*parent).left = node; + } else { + (*parent).right = node; + } + } else { + self.root = node; + } + (*node).parent = parent; + insert_fixup(self, node); + } + } + + pub fn delete(&mut self, key: &K) { + unsafe { + let mut parent = null_mut(); + let mut node = self.root; + while !node.is_null() { + node = match (*node).key.cmp(key) { + Ordering::Less => { + parent = node; + (*node).right + } + Ordering::Equal => break, + Ordering::Greater => { + parent = node; + (*node).left + } + }; + } + + if node.is_null() { + return; + } + + /* cl and cr denote left and right child of node, respectively. */ + let cl = (*node).left; + let cr = (*node).right; + let mut deleted_color; + + if cl.is_null() { + replace_node(self, parent, node, cr); + if cr.is_null() { + /* + * Case 1 - cl and cr are both NULL + * (n could be either color here) + * + * (n) NULL + * / \ --> + * NULL NULL + */ + + deleted_color = (*node).color; + } else { + /* + * Case 2 - cl is NULL and cr is not NULL + * + * N Cr + * / \ --> / \ + * NULL cr NULL NULL + */ + + (*cr).parent = parent; + (*cr).color = Color::Black; + deleted_color = Color::Red; + } + } else if cr.is_null() { + /* + * Case 3 - cl is not NULL and cr is NULL + * + * N Cl + * / \ --> / \ + * cl NULL NULL NULL + */ + + replace_node(self, parent, node, cl); + (*cl).parent = parent; + (*cl).color = Color::Black; + deleted_color = Color::Red; + } else { + let mut victim = (*node).right; + while !(*victim).left.is_null() { + victim = (*victim).left; + } + if victim == (*node).right { + /* Case 4 - victim is the right child of node + * + * N N n + * / \ / \ / \ + * (cl) cr (cl) Cr Cl Cr + * + * N n + * / \ / \ + * (cl) Cr Cl Cr + * \ \ + * crr 
crr + */ + + replace_node(self, parent, node, victim); + (*victim).parent = parent; + deleted_color = (*victim).color; + (*victim).color = (*node).color; + (*victim).left = cl; + (*cl).parent = victim; + if (*victim).right.is_null() { + parent = victim; + } else { + deleted_color = Color::Red; + (*(*victim).right).color = Color::Black; + } + } else { + /* + * Case 5 - victim is not the right child of node + */ + + /* vp and vr denote parent and right child of victim, respectively. */ + let vp = (*victim).parent; + let vr = (*victim).right; + (*vp).left = vr; + if vr.is_null() { + deleted_color = (*victim).color; + } else { + deleted_color = Color::Red; + (*vr).parent = vp; + (*vr).color = Color::Black; + } + replace_node(self, parent, node, victim); + (*victim).parent = parent; + (*victim).color = (*node).color; + (*victim).left = cl; + (*victim).right = cr; + (*cl).parent = victim; + (*cr).parent = victim; + parent = vp; + } + } + + /* release resource */ + drop(Box::from_raw(node)); + if matches!(deleted_color, Color::Black) { + delete_fixup(self, parent); + } + } + } + + pub fn iter<'a>(&self) -> RBTreeIterator<'a, K, V> { + let mut iterator = RBTreeIterator { stack: Vec::new() }; + let mut node = self.root; + unsafe { + while !node.is_null() { + iterator.stack.push(&*node); + node = (*node).left; + } + } + iterator + } +} + +#[inline] +unsafe fn insert_fixup(tree: &mut RBTree, mut node: *mut RBNode) { + let mut parent: *mut RBNode = (*node).parent; + let mut gparent: *mut RBNode; + let mut tmp: *mut RBNode; + + loop { + /* + * Loop invariant: + * - node is red + */ + + if parent.is_null() { + (*node).color = Color::Black; + break; + } + + if matches!((*parent).color, Color::Black) { + break; + } + + gparent = (*parent).parent; + tmp = (*gparent).right; + if parent != tmp { + /* parent = (*gparent).left */ + if !tmp.is_null() && matches!((*tmp).color, Color::Red) { + /* + * Case 1 - color flips and recurse at g + * + * G g + * / \ / \ + * p u --> P U + * / / + * 
n n + */ + + (*parent).color = Color::Black; + (*tmp).color = Color::Black; + (*gparent).color = Color::Red; + node = gparent; + parent = (*node).parent; + continue; + } + tmp = (*parent).right; + if node == tmp { + /* node = (*parent).right */ + /* + * Case 2 - left rotate at p (then Case 3) + * + * G G + * / \ / \ + * p U --> n U + * \ / + * n p + */ + + left_rotate(tree, parent); + parent = node; + } + /* + * Case 3 - right rotate at g + * + * G P + * / \ / \ + * p U --> n g + * / \ + * n U + */ + + (*parent).color = Color::Black; + (*gparent).color = Color::Red; + right_rotate(tree, gparent); + } else { + /* parent = (*gparent).right */ + tmp = (*gparent).left; + if !tmp.is_null() && matches!((*tmp).color, Color::Red) { + /* + * Case 1 - color flips and recurse at g + * G g + * / \ / \ + * u p --> U P + * \ \ + * n n + */ + + (*parent).color = Color::Black; + (*tmp).color = Color::Black; + (*gparent).color = Color::Red; + node = gparent; + parent = (*node).parent; + continue; + } + tmp = (*parent).left; + if node == tmp { + /* + * Case 2 - right rotate at p (then Case 3) + * + * G G + * / \ / \ + * U p --> U n + * / \ + * n p + */ + + right_rotate(tree, parent); + parent = node; + } + /* + * Case 3 - left rotate at g + * + * G P + * / \ / \ + * U p --> g n + * \ / + * n U + */ + + (*parent).color = Color::Black; + (*gparent).color = Color::Red; + left_rotate(tree, gparent); + } + break; + } +} + +#[inline] +unsafe fn delete_fixup(tree: &mut RBTree, mut parent: *mut RBNode) { + let mut node: *mut RBNode = null_mut(); + let mut sibling: *mut RBNode; + /* sl and sr denote left and right child of sibling, respectively. */ + let mut sl: *mut RBNode; + let mut sr: *mut RBNode; + + loop { + /* + * Loop invariants: + * - node is black (or null on first iteration) + * - node is not the root (so parent is not null) + * - All leaf paths going through parent and node have a + * black node count that is 1 lower than other leaf paths. 
+ */ + sibling = (*parent).right; + if node != sibling { + /* node = (*parent).left */ + if matches!((*sibling).color, Color::Red) { + /* + * Case 1 - left rotate at parent + * + * P S + * / \ / \ + * N s --> p Sr + * / \ / \ + * Sl Sr N Sl + */ + + left_rotate(tree, parent); + (*parent).color = Color::Red; + (*sibling).color = Color::Black; + sibling = (*parent).right; + } + sl = (*sibling).left; + sr = (*sibling).right; + + if !sl.is_null() && matches!((*sl).color, Color::Red) { + /* + * Case 2 - right rotate at sibling and then left rotate at parent + * (p and sr could be either color here) + * + * (p) (p) (sl) + * / \ / \ / \ + * N S --> N sl --> P S + * / \ \ / \ + * sl (sr) S N (sr) + * \ + * (sr) + */ + + (*sl).color = (*parent).color; + (*parent).color = Color::Black; + right_rotate(tree, sibling); + left_rotate(tree, parent); + } else if !sr.is_null() && matches!((*sr).color, Color::Red) { + /* + * Case 3 - left rotate at parent + * (p could be either color here) + * + * (p) S + * / \ / \ + * N S --> (p) (sr) + * / \ / \ + * Sl sr N Sl + */ + + (*sr).color = (*parent).color; + left_rotate(tree, parent); + } else { + /* + * Case 4 - color clip + * (p could be either color here) + * + * (p) (p) + * / \ / \ + * N S --> N s + * / \ / \ + * Sl Sr Sl Sr + */ + + (*sibling).color = Color::Red; + if matches!((*parent).color, Color::Black) { + node = parent; + parent = (*node).parent; + continue; + } + (*parent).color = Color::Black; + } + } else { + /* node = (*parent).right */ + sibling = (*parent).left; + if matches!((*sibling).color, Color::Red) { + /* + * Case 1 - right rotate at parent + */ + + right_rotate(tree, parent); + (*parent).color = Color::Red; + (*sibling).color = Color::Black; + sibling = (*parent).right; + } + sl = (*sibling).left; + sr = (*sibling).right; + + if !sr.is_null() && matches!((*sr).color, Color::Red) { + /* + * Case 2 - left rotate at sibling and then right rotate at parent + */ + + (*sr).color = (*parent).color; + (*parent).color = 
Color::Black; + left_rotate(tree, sibling); + right_rotate(tree, parent); + } else if !sl.is_null() && matches!((*sl).color, Color::Red) { + /* + * Case 3 - right rotate at parent + */ + + (*sl).color = (*parent).color; + right_rotate(tree, parent); + } else { + /* + * Case 4 - color flip + */ + + (*sibling).color = Color::Red; + if matches!((*parent).color, Color::Black) { + node = parent; + parent = (*node).parent; + continue; + } + (*parent).color = Color::Black; + } + } + break; + } +} + +#[inline] +unsafe fn left_rotate(tree: &mut RBTree, x: *mut RBNode) { + /* + * Left rotate at x + * (x could also be the left child of p) + * + * p p + * \ \ + * x --> y + * / \ / \ + * y x + * / \ / \ + * c c + */ + + let p = (*x).parent; + let y = (*x).right; + let c = (*y).left; + + (*y).left = x; + (*x).parent = y; + (*x).right = c; + if !c.is_null() { + (*c).parent = x; + } + if p.is_null() { + tree.root = y; + } else if (*p).left == x { + (*p).left = y; + } else { + (*p).right = y; + } + (*y).parent = p; +} + +#[inline] +unsafe fn right_rotate(tree: &mut RBTree, x: *mut RBNode) { + /* + * Right rotate at x + * (x could also be the left child of p) + * + * p p + * \ \ + * x --> y + * / \ / \ + * y x + * / \ / \ + * c c + */ + + let p = (*x).parent; + let y = (*x).left; + let c = (*y).right; + + (*y).right = x; + (*x).parent = y; + (*x).left = c; + if !c.is_null() { + (*c).parent = x; + } + if p.is_null() { + tree.root = y; + } else if (*p).left == x { + (*p).left = y; + } else { + (*p).right = y; + } + (*y).parent = p; +} + +#[inline] +unsafe fn replace_node( + tree: &mut RBTree, + parent: *mut RBNode, + node: *mut RBNode, + new: *mut RBNode, +) { + if parent.is_null() { + tree.root = new; + } else if (*parent).left == node { + (*parent).left = new; + } else { + (*parent).right = new; + } +} + +pub struct RBTreeIterator<'a, K: Ord, V> { + stack: Vec<&'a RBNode>, +} + +impl<'a, K: Ord, V> Iterator for RBTreeIterator<'a, K, V> { + type Item = &'a RBNode; + fn next(&mut 
self) -> Option { + match self.stack.pop() { + Some(node) => { + let mut next = node.right; + unsafe { + while !next.is_null() { + self.stack.push(&*next); + next = (*next).left; + } + } + Some(node) + } + None => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::RBTree; + + #[test] + fn find() { + let mut tree = RBTree::::new(); + for (k, v) in "hello, world!".chars().enumerate() { + tree.insert(k, v); + } + assert_eq!(*tree.find(&3).unwrap_or(&'*'), 'l'); + assert_eq!(*tree.find(&6).unwrap_or(&'*'), ' '); + assert_eq!(*tree.find(&8).unwrap_or(&'*'), 'o'); + assert_eq!(*tree.find(&12).unwrap_or(&'*'), '!'); + } + + #[test] + fn insert() { + let mut tree = RBTree::::new(); + for (k, v) in "hello, world!".chars().enumerate() { + tree.insert(k, v); + } + let s: String = tree.iter().map(|x| x.value).collect(); + assert_eq!(s, "hello, world!"); + } + + #[test] + fn delete() { + let mut tree = RBTree::::new(); + for (k, v) in "hello, world!".chars().enumerate() { + tree.insert(k, v); + } + tree.delete(&1); + tree.delete(&3); + tree.delete(&5); + tree.delete(&7); + tree.delete(&11); + let s: String = tree.iter().map(|x| x.value).collect(); + assert_eq!(s, "hlo orl!"); + } +} diff --git a/src/data_structures/segment_tree.rs b/src/data_structures/segment_tree.rs new file mode 100644 index 00000000000..f569381967e --- /dev/null +++ b/src/data_structures/segment_tree.rs @@ -0,0 +1,224 @@ +//! A module providing a Segment Tree data structure for efficient range queries +//! and updates. It supports operations like finding the minimum, maximum, +//! and sum of segments in an array. + +use std::fmt::Debug; +use std::ops::Range; + +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. +#[derive(Debug, PartialEq, Eq)] +pub enum SegmentTreeError { + /// Error indicating that an index is out of bounds. + IndexOutOfBounds, + /// Error indicating that a range provided for a query is invalid. 
+ InvalidRange, +} + +/// A structure representing a Segment Tree. This tree can be used to efficiently +/// perform range queries and updates on an array of elements. +pub struct SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// The length of the input array for which the segment tree is built. + size: usize, + /// A vector representing the segment tree. + nodes: Vec, + /// A merging function defined as a closure or callable type. + merge_fn: F, +} + +impl SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// Creates a new `SegmentTree` from the provided slice of elements. + /// + /// # Arguments + /// + /// * `arr`: A slice of elements of type `T` to initialize the segment tree. + /// * `merge`: A merging function that defines how to merge two elements of type `T`. + /// + /// # Returns + /// + /// A new `SegmentTree` instance populated with the given elements. + pub fn from_vec(arr: &[T], merge: F) -> Self { + let size = arr.len(); + let mut buffer: Vec = vec![T::default(); 2 * size]; + + // Populate the leaves of the tree + buffer[size..(2 * size)].clone_from_slice(arr); + for idx in (1..size).rev() { + buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]); + } + + SegmentTree { + size, + nodes: buffer, + merge_fn: merge, + } + } + + /// Queries the segment tree for the result of merging the elements in the given range. + /// + /// # Arguments + /// + /// * `range`: A range specified as `Range`, indicating the start (inclusive) + /// and end (exclusive) indices of the segment to query. + /// + /// # Returns + /// + /// * `Ok(Some(result))` if the query was successful and there are elements in the range, + /// * `Ok(None)` if the range is empty, + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. 
+ pub fn query(&self, range: Range) -> Result, SegmentTreeError> { + if range.start >= self.size || range.end > self.size { + return Err(SegmentTreeError::InvalidRange); + } + + let mut left = range.start + self.size; + let mut right = range.end + self.size; + let mut result = None; + + // Iterate through the segment tree to accumulate results + while left < right { + if left % 2 == 1 { + result = Some(match result { + None => self.nodes[left], + Some(old) => (self.merge_fn)(old, self.nodes[left]), + }); + left += 1; + } + if right % 2 == 1 { + right -= 1; + result = Some(match result { + None => self.nodes[right], + Some(old) => (self.merge_fn)(old, self.nodes[right]), + }); + } + left /= 2; + right /= 2; + } + + Ok(result) + } + + /// Updates the value at the specified index in the segment tree. + /// + /// # Arguments + /// + /// * `idx`: The index (0-based) of the element to update. + /// * `val`: The new value of type `T` to set at the specified index. + /// + /// # Returns + /// + /// * `Ok(())` if the update was successful, + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. 
+ pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> { + if idx >= self.size { + return Err(SegmentTreeError::IndexOutOfBounds); + } + + let mut index = idx + self.size; + if self.nodes[index] == val { + return Ok(()); + } + + self.nodes[index] = val; + while index > 1 { + index /= 2; + self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::cmp::{max, min}; + + #[test] + fn test_min_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.update(5, 10), Ok(())); + assert_eq!(min_seg_tree.update(14, -8), Ok(())); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); + assert_eq!( + min_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(min_seg_tree.query(5..5), Ok(None)); + assert_eq!( + min_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + min_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_max_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut max_seg_tree = SegmentTree::from_vec(&vec, max); + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); + assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); + assert_eq!(max_seg_tree.update(4, 10), Ok(())); + assert_eq!(max_seg_tree.update(14, -8), Ok(())); + 
assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); + assert_eq!( + max_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(max_seg_tree.query(5..5), Ok(None)); + assert_eq!( + max_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + max_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_sum_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); + assert_eq!( + sum_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); + assert_eq!( + sum_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + sum_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } +} diff --git a/src/data_structures/segment_tree_recursive.rs b/src/data_structures/segment_tree_recursive.rs new file mode 100644 index 00000000000..7a64a978563 --- /dev/null +++ b/src/data_structures/segment_tree_recursive.rs @@ -0,0 +1,263 @@ +use std::fmt::Debug; +use std::ops::Range; + +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. +#[derive(Debug, PartialEq, Eq)] +pub enum SegmentTreeError { + /// Error indicating that an index is out of bounds. + IndexOutOfBounds, + /// Error indicating that a range provided for a query is invalid. + InvalidRange, +} + +/// A data structure representing a Segment Tree. 
Which is used for efficient +/// range queries and updates on an array of elements. +pub struct SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// The number of elements in the original input array for which the segment tree is built. + size: usize, + /// A vector representing the nodes of the segment tree. + nodes: Vec, + /// A function that merges two elements of type `T`. + merge_fn: F, +} + +impl SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// Creates a new `SegmentTree` from the provided slice of elements. + /// + /// # Arguments + /// + /// * `arr`: A slice of elements of type `T` that initializes the segment tree. + /// * `merge_fn`: A merging function that specifies how to combine two elements of type `T`. + /// + /// # Returns + /// + /// A new `SegmentTree` instance initialized with the given elements. + pub fn from_vec(arr: &[T], merge_fn: F) -> Self { + let size = arr.len(); + let mut seg_tree = SegmentTree { + size, + nodes: vec![T::default(); 4 * size], + merge_fn, + }; + if size != 0 { + seg_tree.build_recursive(arr, 1, 0..size); + } + seg_tree + } + + /// Recursively builds the segment tree from the provided array. + /// + /// # Parameters + /// + /// * `arr` - The original array of values. + /// * `node_idx` - The index of the current node in the segment tree. + /// * `node_range` - The range of elements in the original array that the current node covers. 
+ fn build_recursive(&mut self, arr: &[T], node_idx: usize, node_range: Range) { + if node_range.end - node_range.start == 1 { + self.nodes[node_idx] = arr[node_range.start]; + } else { + let mid = node_range.start + (node_range.end - node_range.start) / 2; + self.build_recursive(arr, 2 * node_idx, node_range.start..mid); + self.build_recursive(arr, 2 * node_idx + 1, mid..node_range.end); + self.nodes[node_idx] = + (self.merge_fn)(self.nodes[2 * node_idx], self.nodes[2 * node_idx + 1]); + } + } + + /// Queries the segment tree for the result of merging the elements in the specified range. + /// + /// # Arguments + /// + /// * `target_range`: A range specified as `Range`, indicating the start (inclusive) + /// and end (exclusive) indices of the segment to query. + /// + /// # Returns + /// + /// * `Ok(Some(result))` if the query is successful and there are elements in the range, + /// * `Ok(None)` if the range is empty, + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. + pub fn query(&self, target_range: Range) -> Result, SegmentTreeError> { + if target_range.start >= self.size || target_range.end > self.size { + return Err(SegmentTreeError::InvalidRange); + } + Ok(self.query_recursive(1, 0..self.size, &target_range)) + } + + /// Recursively performs a range query to find the merged result of the specified range. + /// + /// # Parameters + /// + /// * `node_idx` - The index of the current node in the segment tree. + /// * `tree_range` - The range of elements covered by the current node. + /// * `target_range` - The range for which the query is being performed. + /// + /// # Returns + /// + /// An `Option` containing the result of the merge operation on the range if within bounds, + /// or `None` if the range is outside the covered range. 
+ fn query_recursive( + &self, + node_idx: usize, + tree_range: Range, + target_range: &Range, + ) -> Option { + if tree_range.start >= target_range.end || tree_range.end <= target_range.start { + return None; + } + if tree_range.start >= target_range.start && tree_range.end <= target_range.end { + return Some(self.nodes[node_idx]); + } + let mid = tree_range.start + (tree_range.end - tree_range.start) / 2; + let left_res = self.query_recursive(node_idx * 2, tree_range.start..mid, target_range); + let right_res = self.query_recursive(node_idx * 2 + 1, mid..tree_range.end, target_range); + match (left_res, right_res) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Some((self.merge_fn)(l, r)), + } + } + + /// Updates the value at the specified index in the segment tree. + /// + /// # Arguments + /// + /// * `target_idx`: The index (0-based) of the element to update. + /// * `val`: The new value of type `T` to set at the specified index. + /// + /// # Returns + /// + /// * `Ok(())` if the update was successful, + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. + pub fn update(&mut self, target_idx: usize, val: T) -> Result<(), SegmentTreeError> { + if target_idx >= self.size { + return Err(SegmentTreeError::IndexOutOfBounds); + } + self.update_recursive(1, 0..self.size, target_idx, val); + Ok(()) + } + + /// Recursively updates the segment tree for a specific index with a new value. + /// + /// # Parameters + /// + /// * `node_idx` - The index of the current node in the segment tree. + /// * `tree_range` - The range of elements covered by the current node. + /// * `target_idx` - The index in the original array to update. + /// * `val` - The new value to set at `target_idx`. 
+ fn update_recursive( + &mut self, + node_idx: usize, + tree_range: Range, + target_idx: usize, + val: T, + ) { + if tree_range.start > target_idx || tree_range.end <= target_idx { + return; + } + if tree_range.end - tree_range.start <= 1 && tree_range.start == target_idx { + self.nodes[node_idx] = val; + return; + } + let mid = tree_range.start + (tree_range.end - tree_range.start) / 2; + self.update_recursive(node_idx * 2, tree_range.start..mid, target_idx, val); + self.update_recursive(node_idx * 2 + 1, mid..tree_range.end, target_idx, val); + self.nodes[node_idx] = + (self.merge_fn)(self.nodes[node_idx * 2], self.nodes[node_idx * 2 + 1]); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::cmp::{max, min}; + + #[test] + fn test_min_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.update(5, 10), Ok(())); + assert_eq!(min_seg_tree.update(14, -8), Ok(())); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); + assert_eq!( + min_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(min_seg_tree.query(5..5), Ok(None)); + assert_eq!( + min_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + min_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_max_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut max_seg_tree = SegmentTree::from_vec(&vec, max); + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); + 
assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); + assert_eq!(max_seg_tree.update(4, 10), Ok(())); + assert_eq!(max_seg_tree.update(14, -8), Ok(())); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); + assert_eq!( + max_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(max_seg_tree.query(5..5), Ok(None)); + assert_eq!( + max_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + max_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_sum_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); + assert_eq!( + sum_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); + assert_eq!( + sum_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + sum_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } +} diff --git a/src/data_structures/stack_using_singly_linked_list.rs b/src/data_structures/stack_using_singly_linked_list.rs new file mode 100644 index 00000000000..f3f6db1d5c9 --- /dev/null +++ b/src/data_structures/stack_using_singly_linked_list.rs @@ -0,0 +1,255 @@ +// The public struct can hide the implementation detail +pub struct Stack { + head: Link, +} + +type Link = Option>>; + +struct Node { + elem: T, + next: Link, +} + +impl Stack { + // Self is an alias for 
Stack + // We implement associated function name new for single-linked-list + pub fn new() -> Self { + // for new function we need to return a new instance + Self { + // we refer to variants of an enum using :: the namespacing operator + head: None, + } // we need to return the variant, so there without the ; + } + + // Here are the primary forms that self can take are: self, &mut self and &self. + // Since push will modify the linked list, we need a mutable reference `&mut`. + // The push method which the signature's first parameter is self + pub fn push(&mut self, elem: T) { + let new_node = Box::new(Node { + elem, + next: self.head.take(), + }); + // don't forget replace the head with new node for stack + self.head = Some(new_node); + } + + /// The pop function removes the head and returns its value. + /// + /// To do so, we'll need to match the `head` of the list, which is of enum type `Option`.\ + /// It has two variants: `Some(T)` and `None`. + /// * `None` - the list is empty: + /// * return an enum `Result` of variant `Err()`, as there is nothing to pop. + /// * `Some(node)` - the list is not empty: + /// * remove the head of the list, + /// * relink the list's head `head` to its following node `next`, + /// * return `Ok(elem)`. + pub fn pop(&mut self) -> Result { + match self.head.take() { + None => Err("Stack is empty"), + Some(node) => { + self.head = node.next; + Ok(node.elem) + } + } + } + + pub fn is_empty(&self) -> bool { + // Returns true if head is of variant `None`. + self.head.is_none() + } + + pub fn peek(&self) -> Option<&T> { + // Converts from &Option to Option<&T>. 
+ match self.head.as_ref() { + None => None, + Some(node) => Some(&node.elem), + } + } + + pub fn peek_mut(&mut self) -> Option<&mut T> { + match self.head.as_mut() { + None => None, + Some(node) => Some(&mut node.elem), + } + } + + pub fn into_iter_for_stack(self) -> IntoIter { + IntoIter(self) + } + pub fn iter(&self) -> Iter<'_, T> { + Iter { + next: self.head.as_deref(), + } + } + // '_ is the "explicitly elided lifetime" syntax of Rust + pub fn iter_mut(&mut self) -> IterMut<'_, T> { + IterMut { + next: self.head.as_deref_mut(), + } + } +} + +impl Default for Stack { + fn default() -> Self { + Self::new() + } +} + +/// The drop method of singly linked list. +/// +/// Here's a question: *Do we need to worry about cleaning up our list?*\ +/// With the help of the ownership mechanism, the type `List` will be cleaned up automatically (dropped) after it goes out of scope.\ +/// The Rust Compiler does so automacally. In other words, the `Drop` trait is implemented automatically.\ +/// +/// The `Drop` trait is implemented for our type `List` with the following order: `List->Link->Box->Node`.\ +/// The `.drop()` method is tail recursive and will clean the element one by one, this recursion will stop at `Box`\ +/// +/// +/// We wouldn't be able to drop the contents contained by the box after deallocating, so we need to manually write the iterative drop. +impl Drop for Stack { + fn drop(&mut self) { + let mut cur_link = self.head.take(); + while let Some(mut boxed_node) = cur_link { + cur_link = boxed_node.next.take(); + // boxed_node goes out of scope and gets dropped here; + // but its Node's `next` field has been set to None + // so no unbound recursion occurs. 
+ } + } +} + +// Rust has nothing like a yield statement, and there are actually 3 different iterator traits to be implemented + +// Collections are iterated in Rust using the Iterator trait, we define a struct implement Iterator +pub struct IntoIter(Stack); + +impl Iterator for IntoIter { + // This is declaring that every implementation of iterator has an associated type called Item + type Item = T; + // the reason iterator yield Option is because the interface coalesces the `has_next` and `get_next` concepts + fn next(&mut self) -> Option { + self.0.pop().ok() + } +} + +pub struct Iter<'a, T> { + next: Option<&'a Node>, +} + +impl<'a, T> Iterator for Iter<'a, T> { + type Item = &'a T; + fn next(&mut self) -> Option { + self.next.map(|node| { + // as_deref: Converts from Option (or &Option) to Option<&T::Target>. + self.next = node.next.as_deref(); + &node.elem + }) + } +} + +pub struct IterMut<'a, T> { + next: Option<&'a mut Node>, +} + +impl<'a, T> Iterator for IterMut<'a, T> { + type Item = &'a mut T; + fn next(&mut self) -> Option { + // we add take() here due to &mut self isn't Copy(& and Option<&> is Copy) + self.next.take().map(|node| { + self.next = node.next.as_deref_mut(); + &mut node.elem + }) + } +} + +#[cfg(test)] +mod test_stack { + + use super::*; + + #[test] + fn basics() { + let mut list = Stack::new(); + assert_eq!(list.pop(), Err("Stack is empty")); + + list.push(1); + list.push(2); + list.push(3); + + assert_eq!(list.pop(), Ok(3)); + assert_eq!(list.pop(), Ok(2)); + + list.push(4); + list.push(5); + + assert!(!list.is_empty()); + + assert_eq!(list.pop(), Ok(5)); + assert_eq!(list.pop(), Ok(4)); + + assert_eq!(list.pop(), Ok(1)); + assert_eq!(list.pop(), Err("Stack is empty")); + + assert!(list.is_empty()); + } + + #[test] + fn peek() { + let mut list = Stack::new(); + assert_eq!(list.peek(), None); + list.push(1); + list.push(2); + list.push(3); + + assert_eq!(list.peek(), Some(&3)); + assert_eq!(list.peek_mut(), Some(&mut 3)); + + match 
list.peek_mut() { + None => (), + Some(value) => *value = 42, + }; + + assert_eq!(list.peek(), Some(&42)); + assert_eq!(list.pop(), Ok(42)); + } + + #[test] + fn into_iter() { + let mut list = Stack::new(); + list.push(1); + list.push(2); + list.push(3); + + let mut iter = list.into_iter_for_stack(); + assert_eq!(iter.next(), Some(3)); + assert_eq!(iter.next(), Some(2)); + assert_eq!(iter.next(), Some(1)); + assert_eq!(iter.next(), None); + } + + #[test] + fn iter() { + let mut list = Stack::new(); + list.push(1); + list.push(2); + list.push(3); + + let mut iter = list.iter(); + assert_eq!(iter.next(), Some(&3)); + assert_eq!(iter.next(), Some(&2)); + assert_eq!(iter.next(), Some(&1)); + } + + #[test] + fn iter_mut() { + let mut list = Stack::new(); + list.push(1); + list.push(2); + list.push(3); + + let mut iter = list.iter_mut(); + assert_eq!(iter.next(), Some(&mut 3)); + assert_eq!(iter.next(), Some(&mut 2)); + assert_eq!(iter.next(), Some(&mut 1)); + } +} diff --git a/src/data_structures/treap.rs b/src/data_structures/treap.rs new file mode 100644 index 00000000000..e78e782a66d --- /dev/null +++ b/src/data_structures/treap.rs @@ -0,0 +1,352 @@ +use std::{ + cmp::Ordering, + iter::FromIterator, + mem, + ops::Not, + time::{SystemTime, UNIX_EPOCH}, +}; + +/// An internal node of an `Treap`. +struct TreapNode { + value: T, + priority: usize, + left: Option>>, + right: Option>>, +} + +/// A set based on a Treap (Randomized Binary Search Tree). +/// +/// A Treap is a self-balancing binary search tree. It matains a priority value for each node, such +/// that for every node, its children will have lower priority than itself. So, by just looking at +/// the priority, it is like a heap, and this is where the name, Treap, comes from, Tree + Heap. +pub struct Treap { + root: Option>>, + length: usize, +} + +/// Refers to the left or right subtree of a `Treap`. 
+#[derive(Clone, Copy)] +enum Side { + Left, + Right, +} + +impl Treap { + pub fn new() -> Treap { + Treap { + root: None, + length: 0, + } + } + + /// Returns `true` if the tree contains a value. + pub fn contains(&self, value: &T) -> bool { + let mut current = &self.root; + while let Some(node) = current { + current = match value.cmp(&node.value) { + Ordering::Equal => return true, + Ordering::Less => &node.left, + Ordering::Greater => &node.right, + } + } + false + } + + /// Adds a value to the tree + /// + /// Returns `true` if the tree did not yet contain the value. + pub fn insert(&mut self, value: T) -> bool { + let inserted = insert(&mut self.root, value); + if inserted { + self.length += 1; + } + inserted + } + + /// Removes a value from the tree. + /// + /// Returns `true` if the tree contained the value. + pub fn remove(&mut self, value: &T) -> bool { + let removed = remove(&mut self.root, value); + if removed { + self.length -= 1; + } + removed + } + + /// Returns the number of values in the tree. + pub fn len(&self) -> usize { + self.length + } + + /// Returns `true` if the tree contains no values. + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Returns an iterator that visits the nodes in the tree in order. + fn node_iter(&self) -> NodeIter { + let mut node_iter = NodeIter { stack: Vec::new() }; + // Initialize stack with path to leftmost child + let mut child = &self.root; + while let Some(node) = child { + node_iter.stack.push(node.as_ref()); + child = &node.left; + } + node_iter + } + + /// Returns an iterator that visits the values in the tree in ascending order. + pub fn iter(&self) -> Iter { + Iter { + node_iter: self.node_iter(), + } + } +} + +/// Generating random number, should use rand::Rng if possible. +fn rand() -> usize { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos() as usize +} + +/// Recursive helper function for `Treap` insertion. 
+fn insert(tree: &mut Option>>, value: T) -> bool { + if let Some(node) = tree { + let inserted = match value.cmp(&node.value) { + Ordering::Equal => false, + Ordering::Less => insert(&mut node.left, value), + Ordering::Greater => insert(&mut node.right, value), + }; + if inserted { + node.rebalance(); + } + inserted + } else { + *tree = Some(Box::new(TreapNode { + value, + priority: rand(), + left: None, + right: None, + })); + true + } +} + +/// Recursive helper function for `Treap` deletion +fn remove(tree: &mut Option>>, value: &T) -> bool { + if let Some(node) = tree { + let removed = match value.cmp(&node.value) { + Ordering::Less => remove(&mut node.left, value), + Ordering::Greater => remove(&mut node.right, value), + Ordering::Equal => { + *tree = match (node.left.take(), node.right.take()) { + (None, None) => None, + (Some(b), None) | (None, Some(b)) => Some(b), + (Some(left), Some(right)) => { + let side = match left.priority.cmp(&right.priority) { + Ordering::Greater => Side::Right, + _ => Side::Left, + }; + node.left = Some(left); + node.right = Some(right); + node.rotate(side); + remove(node.child_mut(side), value); + Some(tree.take().unwrap()) + } + }; + return true; + } + }; + if removed { + node.rebalance(); + } + removed + } else { + false + } +} + +impl TreapNode { + /// Returns a reference to the left or right child. + fn child(&self, side: Side) -> &Option>> { + match side { + Side::Left => &self.left, + Side::Right => &self.right, + } + } + + /// Returns a mutable reference to the left or right child. + fn child_mut(&mut self, side: Side) -> &mut Option>> { + match side { + Side::Left => &mut self.left, + Side::Right => &mut self.right, + } + } + + /// Returns the priority of the left or right subtree. 
+ fn priority(&self, side: Side) -> usize { + self.child(side).as_ref().map_or(0, |n| n.priority) + } + + /// Performs a left or right rotation + fn rotate(&mut self, side: Side) { + if self.child_mut(!side).is_none() { + return; + } + + let mut subtree = self.child_mut(!side).take().unwrap(); + *self.child_mut(!side) = subtree.child_mut(side).take(); + // Swap root and child nodes in memory + mem::swap(self, subtree.as_mut()); + // Set old root (subtree) as child of new root (self) + *self.child_mut(side) = Some(subtree); + } + + /// Performs left or right tree rotations to balance this node. + fn rebalance(&mut self) { + match ( + self.priority, + self.priority(Side::Left), + self.priority(Side::Right), + ) { + (v, p, q) if p >= q && p > v => self.rotate(Side::Right), + (v, p, q) if p < q && q > v => self.rotate(Side::Left), + _ => (), + }; + } + + #[cfg(test)] + fn is_valid(&self) -> bool { + self.priority >= self.priority(Side::Left) && self.priority >= self.priority(Side::Right) + } +} + +impl Default for Treap { + fn default() -> Self { + Self::new() + } +} + +impl Not for Side { + type Output = Side; + + fn not(self) -> Self::Output { + match self { + Side::Left => Side::Right, + Side::Right => Side::Left, + } + } +} + +impl FromIterator for Treap { + fn from_iter>(iter: I) -> Self { + let mut tree = Treap::new(); + for value in iter { + tree.insert(value); + } + tree + } +} + +/// An iterator over the nodes of an `Treap`. +/// +/// This struct is created by the `node_iter` method of `Treap`. 
+struct NodeIter<'a, T: Ord> { + stack: Vec<&'a TreapNode>, +} + +impl<'a, T: Ord> Iterator for NodeIter<'a, T> { + type Item = &'a TreapNode; + + fn next(&mut self) -> Option { + if let Some(node) = self.stack.pop() { + // Push left path of right subtree to stack + let mut child = &node.right; + while let Some(subtree) = child { + self.stack.push(subtree.as_ref()); + child = &subtree.left; + } + Some(node) + } else { + None + } + } +} + +/// An iterator over the items of an `Treap`. +/// +/// This struct is created by the `iter` method of `Treap`. +pub struct Iter<'a, T: Ord> { + node_iter: NodeIter<'a, T>, +} + +impl<'a, T: Ord> Iterator for Iter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option<&'a T> { + match self.node_iter.next() { + Some(node) => Some(&node.value), + None => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::Treap; + + /// Returns `true` if all nodes in the tree are valid. + fn is_valid(tree: &Treap) -> bool { + tree.node_iter().all(|n| n.is_valid()) + } + + #[test] + fn len() { + let tree: Treap<_> = (1..4).collect(); + assert_eq!(tree.len(), 3); + } + + #[test] + fn contains() { + let tree: Treap<_> = (1..4).collect(); + assert!(tree.contains(&1)); + assert!(!tree.contains(&4)); + } + + #[test] + fn insert() { + let mut tree = Treap::new(); + // First insert succeeds + assert!(tree.insert(1)); + // Second insert fails + assert!(!tree.insert(1)); + } + + #[test] + fn remove() { + let mut tree: Treap<_> = (1..8).collect(); + // First remove succeeds + assert!(tree.remove(&4)); + // Second remove fails + assert!(!tree.remove(&4)); + } + + #[test] + fn sorted() { + let tree: Treap<_> = (1..8).rev().collect(); + assert!((1..8).eq(tree.iter().copied())); + } + + #[test] + fn valid() { + let mut tree: Treap<_> = (1..8).collect(); + assert!(is_valid(&tree)); + for x in 1..8 { + tree.remove(&x); + assert!(is_valid(&tree)); + } + } +} diff --git a/src/data_structures/trie.rs b/src/data_structures/trie.rs new file mode 100644 
index 00000000000..ed05b0a509f --- /dev/null +++ b/src/data_structures/trie.rs @@ -0,0 +1,155 @@ +//! This module provides a generic implementation of a Trie (prefix tree). +//! A Trie is a tree-like data structure that is commonly used to store sequences of keys +//! (such as strings, integers, or other iterable types) where each node represents one element +//! of the key, and values can be associated with full sequences. + +use std::collections::HashMap; +use std::hash::Hash; + +/// A single node in the Trie structure, representing a key and an optional value. +#[derive(Debug, Default)] +struct Node { + /// A map of children nodes where each key maps to another `Node`. + children: HashMap>, + /// The value associated with this node, if any. + value: Option, +} + +/// A generic Trie (prefix tree) data structure that allows insertion and lookup +/// based on a sequence of keys. +#[derive(Debug, Default)] +pub struct Trie +where + Key: Default + Eq + Hash, + Type: Default, +{ + /// The root node of the Trie, which does not hold a value itself. + root: Node, +} + +impl Trie +where + Key: Default + Eq + Hash, + Type: Default, +{ + /// Creates a new, empty `Trie`. + /// + /// # Returns + /// A `Trie` instance with an empty root node. + pub fn new() -> Self { + Self { + root: Node::default(), + } + } + + /// Inserts a value into the Trie, associating it with a sequence of keys. + /// + /// # Arguments + /// - `key`: An iterable sequence of keys (e.g., characters in a string or integers in a vector). + /// - `value`: The value to associate with the sequence of keys. + pub fn insert(&mut self, key: impl IntoIterator, value: Type) + where + Key: Eq + Hash, + { + let mut node = &mut self.root; + for c in key { + node = node.children.entry(c).or_default(); + } + node.value = Some(value); + } + + /// Retrieves a reference to the value associated with a sequence of keys, if it exists. 
+ /// + /// # Arguments + /// - `key`: An iterable sequence of keys (e.g., characters in a string or integers in a vector). + /// + /// # Returns + /// An `Option` containing a reference to the value if the sequence of keys exists in the Trie, + /// or `None` if it does not. + pub fn get(&self, key: impl IntoIterator) -> Option<&Type> + where + Key: Eq + Hash, + { + let mut node = &self.root; + for c in key { + node = node.children.get(&c)?; + } + node.value.as_ref() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_insertion_and_retrieval_with_strings() { + let mut trie = Trie::new(); + + trie.insert("foo".chars(), 1); + assert_eq!(trie.get("foo".chars()), Some(&1)); + trie.insert("foobar".chars(), 2); + assert_eq!(trie.get("foobar".chars()), Some(&2)); + assert_eq!(trie.get("foo".chars()), Some(&1)); + trie.insert("bar".chars(), 3); + assert_eq!(trie.get("bar".chars()), Some(&3)); + assert_eq!(trie.get("baz".chars()), None); + assert_eq!(trie.get("foobarbaz".chars()), None); + } + + #[test] + fn test_insertion_and_retrieval_with_integers() { + let mut trie = Trie::new(); + + trie.insert(vec![1, 2, 3], 1); + assert_eq!(trie.get(vec![1, 2, 3]), Some(&1)); + trie.insert(vec![1, 2, 3, 4, 5], 2); + assert_eq!(trie.get(vec![1, 2, 3, 4, 5]), Some(&2)); + assert_eq!(trie.get(vec![1, 2, 3]), Some(&1)); + trie.insert(vec![10, 20, 30], 3); + assert_eq!(trie.get(vec![10, 20, 30]), Some(&3)); + assert_eq!(trie.get(vec![4, 5, 6]), None); + assert_eq!(trie.get(vec![1, 2, 3, 4, 5, 6]), None); + } + + #[test] + fn test_empty_trie() { + let trie: Trie = Trie::new(); + + assert_eq!(trie.get("foo".chars()), None); + assert_eq!(trie.get("".chars()), None); + } + + #[test] + fn test_insert_empty_key() { + let mut trie: Trie = Trie::new(); + + trie.insert("".chars(), 42); + assert_eq!(trie.get("".chars()), Some(&42)); + assert_eq!(trie.get("foo".chars()), None); + } + + #[test] + fn test_overlapping_keys() { + let mut trie = Trie::new(); + + 
trie.insert("car".chars(), 1); + trie.insert("cart".chars(), 2); + trie.insert("carter".chars(), 3); + assert_eq!(trie.get("car".chars()), Some(&1)); + assert_eq!(trie.get("cart".chars()), Some(&2)); + assert_eq!(trie.get("carter".chars()), Some(&3)); + assert_eq!(trie.get("care".chars()), None); + } + + #[test] + fn test_partial_match() { + let mut trie = Trie::new(); + + trie.insert("apple".chars(), 10); + assert_eq!(trie.get("app".chars()), None); + assert_eq!(trie.get("appl".chars()), None); + assert_eq!(trie.get("apple".chars()), Some(&10)); + assert_eq!(trie.get("applepie".chars()), None); + } +} diff --git a/src/data_structures/union_find.rs b/src/data_structures/union_find.rs new file mode 100644 index 00000000000..b7cebd18c06 --- /dev/null +++ b/src/data_structures/union_find.rs @@ -0,0 +1,228 @@ +//! A Union-Find (Disjoint Set) data structure implementation in Rust. +//! +//! The Union-Find data structure keeps track of elements partitioned into +//! disjoint (non-overlapping) sets. +//! It provides near-constant-time operations to add new sets, to find the +//! representative of a set, and to merge sets. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Debug)] +pub struct UnionFind { + payloads: HashMap, // Maps values to their indices in the parent_links array. + parent_links: Vec, // Holds the parent pointers; root elements are their own parents. + sizes: Vec, // Holds the sizes of the sets. + count: usize, // Number of disjoint sets. +} + +impl UnionFind { + /// Creates an empty Union-Find structure with a specified capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self { + parent_links: Vec::with_capacity(capacity), + sizes: Vec::with_capacity(capacity), + payloads: HashMap::with_capacity(capacity), + count: 0, + } + } + + /// Inserts a new item (disjoint set) into the data structure. 
+ pub fn insert(&mut self, item: T) { + let key = self.payloads.len(); + self.parent_links.push(key); + self.sizes.push(1); + self.payloads.insert(item, key); + self.count += 1; + } + + /// Returns the root index of the set containing the given value, or `None` if it doesn't exist. + pub fn find(&mut self, value: &T) -> Option { + self.payloads + .get(value) + .copied() + .map(|key| self.find_by_key(key)) + } + + /// Unites the sets containing the two given values. Returns: + /// - `None` if either value hasn't been inserted, + /// - `Some(true)` if two disjoint sets have been merged, + /// - `Some(false)` if both elements were already in the same set. + pub fn union(&mut self, first_item: &T, sec_item: &T) -> Option { + let (first_root, sec_root) = (self.find(first_item), self.find(sec_item)); + match (first_root, sec_root) { + (Some(first_root), Some(sec_root)) => Some(self.union_by_key(first_root, sec_root)), + _ => None, + } + } + + /// Finds the root of the set containing the element with the given index. + fn find_by_key(&mut self, key: usize) -> usize { + if self.parent_links[key] != key { + self.parent_links[key] = self.find_by_key(self.parent_links[key]); + } + self.parent_links[key] + } + + /// Unites the sets containing the two elements identified by their indices. + fn union_by_key(&mut self, first_key: usize, sec_key: usize) -> bool { + let (first_root, sec_root) = (self.find_by_key(first_key), self.find_by_key(sec_key)); + + if first_root == sec_root { + return false; + } + + match self.sizes[first_root].cmp(&self.sizes[sec_root]) { + Ordering::Less => { + self.parent_links[first_root] = sec_root; + self.sizes[sec_root] += self.sizes[first_root]; + } + _ => { + self.parent_links[sec_root] = first_root; + self.sizes[first_root] += self.sizes[sec_root]; + } + } + + self.count -= 1; + true + } + + /// Checks if two items belong to the same set. 
+ pub fn is_same_set(&mut self, first_item: &T, sec_item: &T) -> bool { + matches!((self.find(first_item), self.find(sec_item)), (Some(first_root), Some(sec_root)) if first_root == sec_root) + } + + /// Returns the number of disjoint sets. + pub fn count(&self) -> usize { + self.count + } +} + +impl Default for UnionFind { + fn default() -> Self { + Self { + parent_links: Vec::default(), + sizes: Vec::default(), + payloads: HashMap::default(), + count: 0, + } + } +} + +impl FromIterator for UnionFind { + /// Creates a new UnionFind data structure from an iterable of disjoint elements. + fn from_iter>(iter: I) -> Self { + let mut uf = UnionFind::default(); + for item in iter { + uf.insert(item); + } + uf + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_union_find() { + let mut uf = (0..10).collect::>(); + assert_eq!(uf.find(&0), Some(0)); + assert_eq!(uf.find(&1), Some(1)); + assert_eq!(uf.find(&2), Some(2)); + assert_eq!(uf.find(&3), Some(3)); + assert_eq!(uf.find(&4), Some(4)); + assert_eq!(uf.find(&5), Some(5)); + assert_eq!(uf.find(&6), Some(6)); + assert_eq!(uf.find(&7), Some(7)); + assert_eq!(uf.find(&8), Some(8)); + assert_eq!(uf.find(&9), Some(9)); + + assert!(!uf.is_same_set(&0, &1)); + assert!(!uf.is_same_set(&2, &9)); + assert_eq!(uf.count(), 10); + + assert_eq!(uf.union(&0, &1), Some(true)); + assert_eq!(uf.union(&1, &2), Some(true)); + assert_eq!(uf.union(&2, &3), Some(true)); + assert_eq!(uf.union(&0, &2), Some(false)); + assert_eq!(uf.union(&4, &5), Some(true)); + assert_eq!(uf.union(&5, &6), Some(true)); + assert_eq!(uf.union(&6, &7), Some(true)); + assert_eq!(uf.union(&7, &8), Some(true)); + assert_eq!(uf.union(&8, &9), Some(true)); + assert_eq!(uf.union(&7, &9), Some(false)); + + assert_ne!(uf.find(&0), uf.find(&9)); + assert_eq!(uf.find(&0), uf.find(&3)); + assert_eq!(uf.find(&4), uf.find(&9)); + assert!(uf.is_same_set(&0, &3)); + assert!(uf.is_same_set(&4, &9)); + assert!(!uf.is_same_set(&0, &9)); + assert_eq!(uf.count(), 
2); + + assert_eq!(Some(true), uf.union(&3, &4)); + assert_eq!(uf.find(&0), uf.find(&9)); + assert_eq!(uf.count(), 1); + assert!(uf.is_same_set(&0, &9)); + + assert_eq!(None, uf.union(&0, &11)); + } + + #[test] + fn test_spanning_tree() { + let mut uf = UnionFind::from_iter(["A", "B", "C", "D", "E", "F", "G"]); + uf.union(&"A", &"B"); + uf.union(&"B", &"C"); + uf.union(&"A", &"D"); + uf.union(&"F", &"G"); + + assert_eq!(None, uf.union(&"A", &"W")); + + assert_eq!(uf.find(&"A"), uf.find(&"B")); + assert_eq!(uf.find(&"A"), uf.find(&"C")); + assert_eq!(uf.find(&"B"), uf.find(&"D")); + assert_ne!(uf.find(&"A"), uf.find(&"E")); + assert_ne!(uf.find(&"A"), uf.find(&"F")); + assert_eq!(uf.find(&"G"), uf.find(&"F")); + assert_ne!(uf.find(&"G"), uf.find(&"E")); + + assert!(uf.is_same_set(&"A", &"B")); + assert!(uf.is_same_set(&"A", &"C")); + assert!(uf.is_same_set(&"B", &"D")); + assert!(!uf.is_same_set(&"B", &"F")); + assert!(!uf.is_same_set(&"E", &"A")); + assert!(!uf.is_same_set(&"E", &"G")); + assert_eq!(uf.count(), 3); + } + + #[test] + fn test_with_capacity() { + let mut uf: UnionFind = UnionFind::with_capacity(5); + uf.insert(0); + uf.insert(1); + uf.insert(2); + uf.insert(3); + uf.insert(4); + + assert_eq!(uf.count(), 5); + + assert_eq!(uf.union(&0, &1), Some(true)); + assert!(uf.is_same_set(&0, &1)); + assert_eq!(uf.count(), 4); + + assert_eq!(uf.union(&2, &3), Some(true)); + assert!(uf.is_same_set(&2, &3)); + assert_eq!(uf.count(), 3); + + assert_eq!(uf.union(&0, &2), Some(true)); + assert!(uf.is_same_set(&0, &1)); + assert!(uf.is_same_set(&2, &3)); + assert!(uf.is_same_set(&0, &3)); + assert_eq!(uf.count(), 2); + + assert_eq!(None, uf.union(&0, &10)); + } +} diff --git a/src/data_structures/veb_tree.rs b/src/data_structures/veb_tree.rs new file mode 100644 index 00000000000..fe5fd7fc06d --- /dev/null +++ b/src/data_structures/veb_tree.rs @@ -0,0 +1,342 @@ +// This struct implements Van Emde Boas tree (VEB tree). 
It stores integers in range [0, U), where +// O is any integer that is a power of 2. It supports operations such as insert, search, +// predecessor, and successor in O(log(log(U))) time. The structure takes O(U) space. +pub struct VebTree { + size: u32, + child_size: u32, // Set to square root of size. Cache here to avoid recomputation. + min: u32, + max: u32, + summary: Option>, + cluster: Vec, +} + +impl VebTree { + /// Create a new, empty VEB tree. The tree will contain number of elements equal to size + /// rounded up to the nearest power of two. + pub fn new(size: u32) -> VebTree { + let rounded_size = size.next_power_of_two(); + let child_size = (size as f64).sqrt().ceil() as u32; + + let mut cluster = Vec::new(); + if rounded_size > 2 { + for _ in 0..rounded_size { + cluster.push(VebTree::new(child_size)); + } + } + + VebTree { + size: rounded_size, + child_size, + min: u32::MAX, + max: u32::MIN, + cluster, + summary: if rounded_size <= 2 { + None + } else { + Some(Box::new(VebTree::new(child_size))) + }, + } + } + + fn high(&self, value: u32) -> u32 { + value / self.child_size + } + + fn low(&self, value: u32) -> u32 { + value % self.child_size + } + + fn index(&self, cluster: u32, offset: u32) -> u32 { + cluster * self.child_size + offset + } + + pub fn min(&self) -> u32 { + self.min + } + + pub fn max(&self) -> u32 { + self.max + } + + pub fn iter(&self) -> VebTreeIter { + VebTreeIter::new(self) + } + + // A VEB tree is empty if the min is greater than the max. + pub fn empty(&self) -> bool { + self.min > self.max + } + + // Returns true if value is in the tree, false otherwise. 
+ pub fn search(&self, value: u32) -> bool { + if self.empty() { + return false; + } else if value == self.min || value == self.max { + return true; + } else if value < self.min || value > self.max { + return false; + } + self.cluster[self.high(value) as usize].search(self.low(value)) + } + + fn insert_empty(&mut self, value: u32) { + assert!(self.empty(), "tree should be empty"); + self.min = value; + self.max = value; + } + + // Inserts value into the tree. + pub fn insert(&mut self, mut value: u32) { + assert!(value < self.size); + + if self.empty() { + self.insert_empty(value); + return; + } + + if value < self.min { + // If the new value is less than the current tree's min, set the min to the new value + // and insert the old min. + (value, self.min) = (self.min, value); + } + + if self.size > 2 { + // Non base case. The checks for min/max will handle trees of size 2. + let high = self.high(value); + let low = self.low(value); + if self.cluster[high as usize].empty() { + // If the cluster tree for the value is empty, we set the min/max of the tree to + // value and record that the cluster tree has an elements in the summary. + self.cluster[high as usize].insert_empty(low); + if let Some(summary) = self.summary.as_mut() { + summary.insert(high); + } + } else { + // If the cluster tree already has a value, the summary does not need to be + // updated. Recursively insert the value into the cluster tree. + self.cluster[high as usize].insert(low); + } + } + + if value > self.max { + self.max = value; + } + } + + // Returns the next greatest value(successor) in the tree after pred. Returns + // `None` if there is no successor. + pub fn succ(&self, pred: u32) -> Option { + if self.empty() { + return None; + } + + if self.size == 2 { + // Base case. If pred is 0, and 1 exists in the tree (max is set to 1), the successor + // is 1. 
+ return if pred == 0 && self.max == 1 { + Some(1) + } else { + None + }; + } + + if pred < self.min { + // If the predecessor is less than the minimum of this tree, the successor is the min. + return Some(self.min); + } + + let low = self.low(pred); + let high = self.high(pred); + + if !self.cluster[high as usize].empty() && low < self.cluster[high as usize].max { + // The successor is within the same cluster as the predecessor + return Some(self.index(high, self.cluster[high as usize].succ(low).unwrap())); + }; + + // If we reach this point, the successor exists in a different cluster. We use the summary + // to efficiently query which cluster the successor lives in. If there is no successor + // cluster, return None. + let succ_cluster = self.summary.as_ref().unwrap().succ(high); + succ_cluster + .map(|succ_cluster| self.index(succ_cluster, self.cluster[succ_cluster as usize].min)) + } + + // Returns the next smallest value(predecessor) in the tree after succ. Returns + // `None` if there is no predecessor. pred() is almost a mirror of succ(). + // Differences are noted in comments. + pub fn pred(&self, succ: u32) -> Option { + if self.empty() { + return None; + } + + // base case. + if self.size == 2 { + return if succ == 1 && self.min == 0 { + Some(0) + } else { + None + }; + } + + if succ > self.max { + return Some(self.max); + } + + let low = self.low(succ); + let high = self.high(succ); + + if !self.cluster[high as usize].empty() && low > self.cluster[high as usize].min { + return Some(self.index(high, self.cluster[high as usize].pred(low).unwrap())); + }; + + // Find the cluster that has the predecessor. The successor will be that cluster's max. + let succ_cluster = self.summary.as_ref().unwrap().pred(high); + match succ_cluster { + Some(succ_cluster) => { + Some(self.index(succ_cluster, self.cluster[succ_cluster as usize].max)) + } + // Special case for pred() that does not exist in succ(). The current tree's min + // does not exist in a cluster. 
So if we cannot find a cluster that could have the + // predecessor, the predecessor could be the min of the current tree. + None => { + if succ > self.min { + Some(self.min) + } else { + None + } + } + } + } +} + +pub struct VebTreeIter<'a> { + tree: &'a VebTree, + curr: Option, +} + +impl<'a> VebTreeIter<'a> { + pub fn new(tree: &'a VebTree) -> VebTreeIter<'a> { + let curr = if tree.empty() { None } else { Some(tree.min) }; + VebTreeIter { tree, curr } + } +} + +impl Iterator for VebTreeIter<'_> { + type Item = u32; + + fn next(&mut self) -> Option { + let curr = self.curr; + curr?; + self.curr = self.tree.succ(curr.unwrap()); + curr + } +} + +#[cfg(test)] +mod test { + use super::VebTree; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + fn test_veb_tree(size: u32, mut elements: Vec, exclude: Vec) { + // Insert elements + let mut tree = VebTree::new(size); + for element in elements.iter() { + tree.insert(*element); + } + + // Test search + for element in elements.iter() { + assert!(tree.search(*element)); + } + for element in exclude { + assert!(!tree.search(element)); + } + + // Test iterator and successor, and predecessor + elements.sort(); + elements.dedup(); + for (i, element) in tree.iter().enumerate() { + assert!(elements[i] == element); + } + for i in 1..elements.len() { + assert!(tree.succ(elements[i - 1]) == Some(elements[i])); + assert!(tree.pred(elements[i]) == Some(elements[i - 1])); + } + } + + #[test] + fn test_empty() { + test_veb_tree(16, Vec::new(), (0..16).collect()); + } + + #[test] + fn test_single() { + test_veb_tree(16, Vec::from([5]), (0..16).filter(|x| *x != 5).collect()); + } + + #[test] + fn test_two() { + test_veb_tree( + 16, + Vec::from([4, 9]), + (0..16).filter(|x| *x != 4 && *x != 9).collect(), + ); + } + + #[test] + fn test_repeat_insert() { + let mut tree = VebTree::new(16); + for _ in 0..5 { + tree.insert(10); + } + assert!(tree.search(10)); + let elements: Vec = (0..16).filter(|x| *x != 10).collect(); + for element in elements { 
+ assert!(!tree.search(element)); + } + } + + #[test] + fn test_linear() { + test_veb_tree(16, (0..10).collect(), (10..16).collect()); + } + + fn test_full(size: u32) { + test_veb_tree(size, (0..size).collect(), Vec::new()); + } + + #[test] + fn test_full_small() { + test_full(8); + test_full(10); + test_full(16); + test_full(20); + test_full(32); + } + + #[test] + fn test_full_256() { + test_full(256); + } + + #[test] + fn test_10_256() { + let mut rng = StdRng::seed_from_u64(0); + let elements: Vec = (0..10).map(|_| rng.random_range(0..255)).collect(); + test_veb_tree(256, elements, Vec::new()); + } + + #[test] + fn test_100_256() { + let mut rng = StdRng::seed_from_u64(0); + let elements: Vec = (0..100).map(|_| rng.random_range(0..255)).collect(); + test_veb_tree(256, elements, Vec::new()); + } + + #[test] + fn test_100_300() { + let mut rng = StdRng::seed_from_u64(0); + let elements: Vec = (0..100).map(|_| rng.random_range(0..255)).collect(); + test_veb_tree(300, elements, Vec::new()); + } +} diff --git a/src/dynamic_programming/coin_change.rs b/src/dynamic_programming/coin_change.rs index 5416c268df7..2bfd573a9c0 100644 --- a/src/dynamic_programming/coin_change.rs +++ b/src/dynamic_programming/coin_change.rs @@ -1,67 +1,94 @@ -/// Coin change via Dynamic Programming +//! This module provides a solution to the coin change problem using dynamic programming. +//! The `coin_change` function calculates the fewest number of coins required to make up +//! a given amount using a specified set of coin denominations. +//! +//! The implementation leverages dynamic programming to build up solutions for smaller +//! amounts and combines them to solve for larger amounts. It ensures optimal substructure +//! and overlapping subproblems are efficiently utilized to achieve the solution. -/// coin_change(coins, amount) returns the fewest number of coins that need to make up that amount. 
-/// If that amount of money cannot be made up by any combination of the coins, return `None`. +//! # Complexity +//! - Time complexity: O(amount * coins.length) +//! - Space complexity: O(amount) + +/// Returns the fewest number of coins needed to make up the given amount using the provided coin denominations. +/// If the amount cannot be made up by any combination of the coins, returns `None`. +/// +/// # Arguments +/// * `coins` - A slice of coin denominations. +/// * `amount` - The total amount of money to be made up. +/// +/// # Returns +/// * `Option` - The minimum number of coins required to make up the amount, or `None` if it's not possible. /// -/// Arguments: -/// * `coins` - coins of different denominations -/// * `amount` - a total amount of money be made up. -/// Complexity -/// - time complexity: O(amount * coins.length), -/// - space complexity: O(amount), +/// # Complexity +/// * Time complexity: O(amount * coins.length) +/// * Space complexity: O(amount) pub fn coin_change(coins: &[usize], amount: usize) -> Option { - let mut dp = vec![std::usize::MAX; amount + 1]; - dp[0] = 0; + let mut min_coins = vec![None; amount + 1]; + min_coins[0] = Some(0); - // Assume dp[i] is the fewest number of coins making up amount i, - // then for every coin in coins, dp[i] = min(dp[i - coin] + 1). 
- for i in 0..=amount { - for j in 0..coins.len() { - if i >= coins[j] && dp[i - coins[j]] != std::usize::MAX { - dp[i] = dp[i].min(dp[i - coins[j]] + 1); - } - } - } + (0..=amount).for_each(|curr_amount| { + coins + .iter() + .filter(|&&coin| curr_amount >= coin) + .for_each(|&coin| { + if let Some(prev_min_coins) = min_coins[curr_amount - coin] { + min_coins[curr_amount] = Some( + min_coins[curr_amount].map_or(prev_min_coins + 1, |curr_min_coins| { + curr_min_coins.min(prev_min_coins + 1) + }), + ); + } + }); + }); - match dp[amount] { - std::usize::MAX => None, - _ => Some(dp[amount]), - } + min_coins[amount] } #[cfg(test)] mod tests { use super::*; - #[test] - fn basic() { - // 11 = 5 * 2 + 1 * 1 - let coins = vec![1, 2, 5]; - assert_eq!(Some(3), coin_change(&coins, 11)); - - // 119 = 11 * 10 + 7 * 1 + 2 * 1 - let coins = vec![2, 3, 5, 7, 11]; - assert_eq!(Some(12), coin_change(&coins, 119)); - } - - #[test] - fn coins_empty() { - let coins = vec![]; - assert_eq!(None, coin_change(&coins, 1)); - } - - #[test] - fn amount_zero() { - let coins = vec![1, 2, 3]; - assert_eq!(Some(0), coin_change(&coins, 0)); + macro_rules! coin_change_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (coins, amount, expected) = $test_case; + assert_eq!(expected, coin_change(&coins, amount)); + } + )* + } } - #[test] - fn fail_change() { - // 3 can't be change by 2. - let coins = vec![2]; - assert_eq!(None, coin_change(&coins, 3)); - let coins = vec![10, 20, 50, 100]; - assert_eq!(None, coin_change(&coins, 5)); + coin_change_tests! 
{ + test_basic_case: (vec![1, 2, 5], 11, Some(3)), + test_multiple_denominations: (vec![2, 3, 5, 7, 11], 119, Some(12)), + test_empty_coins: (vec![], 1, None), + test_zero_amount: (vec![1, 2, 3], 0, Some(0)), + test_no_solution_small_coin: (vec![2], 3, None), + test_no_solution_large_coin: (vec![10, 20, 50, 100], 5, None), + test_single_coin_large_amount: (vec![1], 100, Some(100)), + test_large_amount_multiple_coins: (vec![1, 2, 5], 10000, Some(2000)), + test_no_combination_possible: (vec![3, 7], 5, None), + test_exact_combination: (vec![1, 3, 4], 6, Some(2)), + test_large_denomination_multiple_coins: (vec![10, 50, 100], 1000, Some(10)), + test_small_amount_not_possible: (vec![5, 10], 1, None), + test_non_divisible_amount: (vec![2], 3, None), + test_all_multiples: (vec![1, 2, 4, 8], 15, Some(4)), + test_large_amount_mixed_coins: (vec![1, 5, 10, 25], 999, Some(45)), + test_prime_coins_and_amount: (vec![2, 3, 5, 7], 17, Some(3)), + test_coins_larger_than_amount: (vec![5, 10, 20], 1, None), + test_repeating_denominations: (vec![1, 1, 1, 5], 8, Some(4)), + test_non_standard_denominations: (vec![1, 4, 6, 9], 15, Some(2)), + test_very_large_denominations: (vec![1000, 2000, 5000], 1, None), + test_large_amount_performance: (vec![1, 5, 10, 25, 50, 100, 200, 500], 9999, Some(29)), + test_powers_of_two: (vec![1, 2, 4, 8, 16, 32, 64], 127, Some(7)), + test_fibonacci_sequence: (vec![1, 2, 3, 5, 8, 13, 21, 34], 55, Some(2)), + test_mixed_small_large: (vec![1, 100, 1000, 10000], 11001, Some(3)), + test_impossible_combinations: (vec![2, 4, 6, 8], 7, None), + test_greedy_approach_does_not_work: (vec![1, 12, 20], 24, Some(2)), + test_zero_denominations_no_solution: (vec![0], 1, None), + test_zero_denominations_solution: (vec![0], 0, Some(0)), } } diff --git a/src/dynamic_programming/edit_distance.rs b/src/dynamic_programming/edit_distance.rs deleted file mode 100644 index e862ec7eb9e..00000000000 --- a/src/dynamic_programming/edit_distance.rs +++ /dev/null @@ -1,119 +0,0 @@ -//! 
Compute the edit distance between two strings - -use std::cmp::min; - -/// edit_distance(str_a, str_b) returns the edit distance between the two -/// strings This edit distance is defined as being 1 point per insertion, -/// substitution, or deletion which must be made to make the strings equal. -/// -/// This function iterates over the bytes in the string, so it may not behave -/// entirely as expected for non-ASCII strings. -/// -/// # Complexity -/// -/// - time complexity: O(nm), -/// - space complexity: O(nm), -/// -/// where n and m are lengths of `str_a` and `str_b` -pub fn edit_distance(str_a: &str, str_b: &str) -> u32 { - // distances[i][j] = distance between a[..i] and b[..j] - let mut distances = vec![vec![0; str_b.len() + 1]; str_a.len() + 1]; - // Initialize cases in which one string is empty - for j in 0..=str_b.len() { - distances[0][j] = j as u32; - } - for (i, item) in distances.iter_mut().enumerate() { - item[0] = i as u32; - } - for i in 1..=str_a.len() { - for j in 1..=str_b.len() { - distances[i][j] = min(distances[i - 1][j] + 1, distances[i][j - 1] + 1); - if str_a.as_bytes()[i - 1] == str_b.as_bytes()[j - 1] { - distances[i][j] = min(distances[i][j], distances[i - 1][j - 1]); - } else { - distances[i][j] = min(distances[i][j], distances[i - 1][j - 1] + 1); - } - } - } - distances[str_a.len()][str_b.len()] -} - -/// The space efficient version of the above algorithm. -/// -/// Instead of storing the `m * n` matrix expicitly, only one row (of length `n`) is stored. -/// It keeps overwriting itself based on its previous values with the help of two scalars, -/// gradually reaching the last row. Then, the score is `matrix[n]`. 
-/// -/// # Complexity -/// -/// - time complexity: O(nm), -/// - space complexity: O(n), -/// -/// where n and m are lengths of `str_a` and `str_b` -pub fn edit_distance_se(str_a: &str, str_b: &str) -> u32 { - let (str_a, str_b) = (str_a.as_bytes(), str_b.as_bytes()); - let (m, n) = (str_a.len(), str_b.len()); - let mut distances: Vec = vec![0; n + 1]; // the dynamic programming matrix (only 1 row stored) - let mut s: u32; // distances[i - 1][j - 1] or distances[i - 1][j] - let mut c: u32; // distances[i][j - 1] or distances[i][j] - let mut char_a: u8; // str_a[i - 1] the i-th character in str_a; only needs to be computed once per row - let mut char_b: u8; // str_b[j - 1] the j-th character in str_b - - // 0th row - for (j, v) in distances.iter_mut().enumerate().take(n + 1).skip(1) { - *v = j as u32; - } - // rows 1 to m - for i in 1..=m { - s = (i - 1) as u32; - c = i as u32; - char_a = str_a[i - 1]; - for j in 1..=n { - // c is distances[i][j-1] and s is distances[i-1][j-1] at the beginning of each round of iteration - char_b = str_b[j - 1]; - c = min( - s + if char_a == char_b { 0 } else { 1 }, - min(c + 1, distances[j] + 1), - ); - // c is updated to distances[i][j], and will thus become distances[i][j-1] for the next cell - s = distances[j]; // here distances[j] means distances[i-1][j] becuase it has not been overwritten yet - // s is updated to distances[i-1][j], and will thus become distances[i-1][j-1] for the next cell - distances[j] = c; // now distances[j] is updated to distances[i][j], and will thus become distances[i-1][j] for the next ROW - } - } - - distances[n] -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn equal_strings() { - assert_eq!(0, edit_distance("Hello, world!", "Hello, world!")); - assert_eq!(0, edit_distance_se("Hello, world!", "Hello, world!")); - assert_eq!(0, edit_distance("Test_Case_#1", "Test_Case_#1")); - assert_eq!(0, edit_distance_se("Test_Case_#1", "Test_Case_#1")); - } - - #[test] - fn one_edit_difference() { - 
assert_eq!(1, edit_distance("Hello, world!", "Hell, world!")); - assert_eq!(1, edit_distance("Test_Case_#1", "Test_Case_#2")); - assert_eq!(1, edit_distance("Test_Case_#1", "Test_Case_#10")); - assert_eq!(1, edit_distance_se("Hello, world!", "Hell, world!")); - assert_eq!(1, edit_distance_se("Test_Case_#1", "Test_Case_#2")); - assert_eq!(1, edit_distance_se("Test_Case_#1", "Test_Case_#10")); - } - - #[test] - fn several_differences() { - assert_eq!(2, edit_distance("My Cat", "My Case")); - assert_eq!(7, edit_distance("Hello, world!", "Goodbye, world!")); - assert_eq!(6, edit_distance("Test_Case_#3", "Case #3")); - assert_eq!(2, edit_distance_se("My Cat", "My Case")); - assert_eq!(7, edit_distance_se("Hello, world!", "Goodbye, world!")); - assert_eq!(6, edit_distance_se("Test_Case_#3", "Case #3")); - } -} diff --git a/src/dynamic_programming/egg_dropping.rs b/src/dynamic_programming/egg_dropping.rs index 1a403637ec5..ab24494c014 100644 --- a/src/dynamic_programming/egg_dropping.rs +++ b/src/dynamic_programming/egg_dropping.rs @@ -1,91 +1,84 @@ -/// # Egg Dropping Puzzle +//! This module contains the `egg_drop` function, which determines the minimum number of egg droppings +//! required to find the highest floor from which an egg can be dropped without breaking. It also includes +//! tests for the function using various test cases, including edge cases. -/// `egg_drop(eggs, floors)` returns the least number of egg droppings -/// required to determine the highest floor from which an egg will not -/// break upon dropping +/// Returns the least number of egg droppings required to determine the highest floor from which an egg will not break upon dropping. 
/// -/// Assumptions: n > 0 -pub fn egg_drop(eggs: u32, floors: u32) -> u32 { - assert!(eggs > 0); - - // Explicity handle edge cases (optional) - if eggs == 1 || floors == 0 || floors == 1 { - return floors; - } - - let eggs_index = eggs as usize; - let floors_index = floors as usize; - - // Store solutions to subproblems in 2D Vec, - // where egg_drops[i][j] represents the solution to the egg dropping - // problem with i eggs and j floors - let mut egg_drops: Vec> = vec![vec![0; floors_index + 1]; eggs_index + 1]; - - // Assign solutions for egg_drop(n, 0) = 0, egg_drop(n, 1) = 1 - for egg_drop in egg_drops.iter_mut().skip(1) { - egg_drop[0] = 0; - egg_drop[1] = 1; - } - - // Assign solutions to egg_drop(1, k) = k - for j in 1..=floors_index { - egg_drops[1][j] = j as u32; +/// # Arguments +/// +/// * `eggs` - The number of eggs available. +/// * `floors` - The number of floors in the building. +/// +/// # Returns +/// +/// * `Some(usize)` - The minimum number of drops required if the number of eggs is greater than 0. +/// * `None` - If the number of eggs is 0. 
+pub fn egg_drop(eggs: usize, floors: usize) -> Option { + if eggs == 0 { + return None; } - // Complete solutions vector using optimal substructure property - for i in 2..=eggs_index { - for j in 2..=floors_index { - egg_drops[i][j] = std::u32::MAX; - - for k in 1..=j { - let res = 1 + std::cmp::max(egg_drops[i - 1][k - 1], egg_drops[i][j - k]); - - if res < egg_drops[i][j] { - egg_drops[i][j] = res; - } - } - } + if eggs == 1 || floors == 0 || floors == 1 { + return Some(floors); } - egg_drops[eggs_index][floors_index] + // Create a 2D vector to store solutions to subproblems + let mut egg_drops: Vec> = vec![vec![0; floors + 1]; eggs + 1]; + + // Base cases: 0 floors -> 0 drops, 1 floor -> 1 drop + (1..=eggs).for_each(|i| { + egg_drops[i][1] = 1; + }); + + // Base case: 1 egg -> k drops for k floors + (1..=floors).for_each(|j| { + egg_drops[1][j] = j; + }); + + // Fill the table using the optimal substructure property + (2..=eggs).for_each(|i| { + (2..=floors).for_each(|j| { + egg_drops[i][j] = (1..=j) + .map(|k| 1 + std::cmp::max(egg_drops[i - 1][k - 1], egg_drops[i][j - k])) + .min() + .unwrap(); + }); + }); + + Some(egg_drops[eggs][floors]) } #[cfg(test)] mod tests { - use super::egg_drop; - - #[test] - fn zero_floors() { - assert_eq!(egg_drop(5, 0), 0); - } - - #[test] - fn one_egg() { - assert_eq!(egg_drop(1, 8), 8); - } - - #[test] - fn eggs2_floors2() { - assert_eq!(egg_drop(2, 2), 2); - } - - #[test] - fn eggs3_floors5() { - assert_eq!(egg_drop(3, 5), 3); - } - - #[test] - fn eggs2_floors10() { - assert_eq!(egg_drop(2, 10), 4); - } - - #[test] - fn eggs2_floors36() { - assert_eq!(egg_drop(2, 36), 8); + use super::*; + + macro_rules! egg_drop_tests { + ($($name:ident: $test_cases:expr,)*) => { + $( + #[test] + fn $name() { + let (eggs, floors, expected) = $test_cases; + assert_eq!(egg_drop(eggs, floors), expected); + } + )* + } } - #[test] - fn large_floors() { - assert_eq!(egg_drop(2, 100), 14); + egg_drop_tests! 
{ + test_no_floors: (5, 0, Some(0)), + test_one_egg_multiple_floors: (1, 8, Some(8)), + test_multiple_eggs_one_floor: (5, 1, Some(1)), + test_two_eggs_two_floors: (2, 2, Some(2)), + test_three_eggs_five_floors: (3, 5, Some(3)), + test_two_eggs_ten_floors: (2, 10, Some(4)), + test_two_eggs_thirty_six_floors: (2, 36, Some(8)), + test_many_eggs_one_floor: (100, 1, Some(1)), + test_many_eggs_few_floors: (100, 5, Some(3)), + test_few_eggs_many_floors: (2, 1000, Some(45)), + test_zero_eggs: (0, 10, None::), + test_no_eggs_no_floors: (0, 0, None::), + test_one_egg_no_floors: (1, 0, Some(0)), + test_one_egg_one_floor: (1, 1, Some(1)), + test_maximum_floors_one_egg: (1, usize::MAX, Some(usize::MAX)), } } diff --git a/src/dynamic_programming/fibonacci.rs b/src/dynamic_programming/fibonacci.rs index cdffd219ab0..f1a55ce77f1 100644 --- a/src/dynamic_programming/fibonacci.rs +++ b/src/dynamic_programming/fibonacci.rs @@ -1,4 +1,5 @@ /// Fibonacci via Dynamic Programming +use std::collections::HashMap; /// fibonacci(n) returns the nth fibonacci number /// This function uses the definition of Fibonacci where: @@ -38,9 +39,253 @@ fn _recursive_fibonacci(n: u32, previous: u128, current: u128) -> u128 { } } +/// classical_fibonacci(n) returns the nth fibonacci number +/// This function uses the definition of Fibonacci where: +/// F(0) = 0, F(1) = 1 and F(n+1) = F(n) + F(n-1) for n>0 +/// +/// Warning: This will overflow the 128-bit unsigned integer at n=186 +pub fn classical_fibonacci(n: u32) -> u128 { + match n { + 0 => 0, + 1 => 1, + _ => { + let k = n / 2; + let f1 = classical_fibonacci(k); + let f2 = classical_fibonacci(k - 1); + + match n % 4 { + 0 | 2 => f1 * (f1 + 2 * f2), + 1 => (2 * f1 + f2) * (2 * f1 - f2) + 2, + _ => (2 * f1 + f2) * (2 * f1 - f2) - 2, + } + } + } +} + +/// logarithmic_fibonacci(n) returns the nth fibonacci number +/// This function uses the definition of Fibonacci where: +/// F(0) = 0, F(1) = 1 and F(n+1) = F(n) + F(n-1) for n>0 +/// +/// Warning: This 
will overflow the 128-bit unsigned integer at n=186 +pub fn logarithmic_fibonacci(n: u32) -> u128 { + // if it is the max value before overflow, use n-1 then get the second + // value in the tuple + if n == 186 { + let (_, second) = _logarithmic_fibonacci(185); + second + } else { + let (first, _) = _logarithmic_fibonacci(n); + first + } +} + +fn _logarithmic_fibonacci(n: u32) -> (u128, u128) { + match n { + 0 => (0, 1), + _ => { + let (current, next) = _logarithmic_fibonacci(n / 2); + let c = current * (next * 2 - current); + let d = current * current + next * next; + + match n % 2 { + 0 => (c, d), + _ => (d, c + d), + } + } + } +} + +/// Memoized fibonacci. +pub fn memoized_fibonacci(n: u32) -> u128 { + let mut cache: HashMap = HashMap::new(); + + _memoized_fibonacci(n, &mut cache) +} + +fn _memoized_fibonacci(n: u32, cache: &mut HashMap) -> u128 { + if n == 0 { + return 0; + } + if n == 1 { + return 1; + } + + let f = match cache.get(&n) { + Some(f) => f, + None => { + let f1 = _memoized_fibonacci(n - 1, cache); + let f2 = _memoized_fibonacci(n - 2, cache); + cache.insert(n, f1 + f2); + cache.get(&n).unwrap() + } + }; + + *f +} + +/// matrix_fibonacci(n) returns the nth fibonacci number +/// This function uses the definition of Fibonacci where: +/// F(0) = 0, F(1) = 1 and F(n+1) = F(n) + F(n-1) for n>0 +/// +/// Matrix formula: +/// [F(n + 2)] = [1, 1] * [F(n + 1)] +/// [F(n + 1)] [1, 0] [F(n) ] +/// +/// Warning: This will overflow the 128-bit unsigned integer at n=186 +pub fn matrix_fibonacci(n: u32) -> u128 { + let multiplier: Vec> = vec![vec![1, 1], vec![1, 0]]; + + let multiplier = matrix_power(&multiplier, n); + let initial_fib_matrix: Vec> = vec![vec![1], vec![0]]; + + let res = matrix_multiply(&multiplier, &initial_fib_matrix); + + res[1][0] +} + +fn matrix_power(base: &Vec>, power: u32) -> Vec> { + let identity_matrix: Vec> = vec![vec![1, 0], vec![0, 1]]; + + vec![base; power as usize] + .iter() + .fold(identity_matrix, |acc, x| matrix_multiply(&acc, 
x)) +} + +// Copied from matrix_ops since u128 is required instead of i32 +#[allow(clippy::needless_range_loop)] +fn matrix_multiply(multiplier: &[Vec], multiplicand: &[Vec]) -> Vec> { + // Multiply two matching matrices. The multiplier needs to have the same amount + // of columns as the multiplicand has rows. + let mut result: Vec> = vec![]; + let mut temp; + // Using variable to compare lengths of rows in multiplicand later + let row_right_length = multiplicand[0].len(); + for row_left in 0..multiplier.len() { + if multiplier[row_left].len() != multiplicand.len() { + panic!("Matrix dimensions do not match"); + } + result.push(vec![]); + for column_right in 0..multiplicand[0].len() { + temp = 0; + for row_right in 0..multiplicand.len() { + if row_right_length != multiplicand[row_right].len() { + // If row is longer than a previous row cancel operation with error + panic!("Matrix dimensions do not match"); + } + temp += multiplier[row_left][row_right] * multiplicand[row_right][column_right]; + } + result[row_left].push(temp); + } + } + result +} + +/// Binary lifting fibonacci +/// +/// Following properties of F(n) could be deduced from the matrix formula above: +/// +/// F(2n) = F(n) * (2F(n+1) - F(n)) +/// F(2n+1) = F(n+1)^2 + F(n)^2 +/// +/// Therefore F(n) and F(n+1) can be derived from F(n>>1) and F(n>>1 + 1), which +/// has a smaller constant in both time and space compared to matrix fibonacci. +pub fn binary_lifting_fibonacci(n: u32) -> u128 { + // the state always stores F(k), F(k+1) for some k, initially F(0), F(1) + let mut state = (0u128, 1u128); + + for i in (0..u32::BITS - n.leading_zeros()).rev() { + // compute F(2k), F(2k+1) from F(k), F(k+1) + state = ( + state.0 * (2 * state.1 - state.0), + state.0 * state.0 + state.1 * state.1, + ); + if n & (1 << i) != 0 { + state = (state.1, state.0 + state.1); + } + } + + state.0 +} + +/// nth_fibonacci_number_modulo_m(n, m) returns the nth fibonacci number modulo the specified m +/// i.e. 
F(n) % m +pub fn nth_fibonacci_number_modulo_m(n: i64, m: i64) -> i128 { + let (length, pisano_sequence) = get_pisano_sequence_and_period(m); + + let remainder = n % length as i64; + pisano_sequence[remainder as usize].to_owned() +} + +/// get_pisano_sequence_and_period(m) returns the Pisano Sequence and period for the specified integer m. +/// The pisano period is the period with which the sequence of Fibonacci numbers taken modulo m repeats. +/// The pisano sequence is the numbers in pisano period. +fn get_pisano_sequence_and_period(m: i64) -> (i128, Vec) { + let mut a = 0; + let mut b = 1; + let mut length: i128 = 0; + let mut pisano_sequence: Vec = vec![a, b]; + + // Iterating through all the fib numbers to get the sequence + for _i in 0..=(m * m) { + let c = (a + b) % m as i128; + + // adding number into the sequence + pisano_sequence.push(c); + + a = b; + b = c; + + if a == 0 && b == 1 { + // Remove the last two elements from the sequence + // This is a less elegant way to do it. + pisano_sequence.pop(); + pisano_sequence.pop(); + length = pisano_sequence.len() as i128; + break; + } + } + + (length, pisano_sequence) +} + +/// last_digit_of_the_sum_of_nth_fibonacci_number(n) returns the last digit of the sum of n fibonacci numbers. +/// The function uses the definition of Fibonacci where: +/// F(0) = 0, F(1) = 1 and F(n+1) = F(n) + F(n-1) for n > 2 +/// +/// The sum of the Fibonacci numbers are: +/// F(0) + F(1) + F(2) + ... 
/// last_digit_of_the_sum_of_nth_fibonacci_number(n) returns the last digit of the sum of n fibonacci numbers.
/// The function uses the definition of Fibonacci where:
/// F(0) = 0, F(1) = 1 and F(n) = F(n-1) + F(n-2) for n >= 2
///
/// The sum of the Fibonacci numbers are:
/// F(0) + F(1) + F(2) + ... + F(n)
///
/// The sum equals F(n + 2) - 1, and the last digits of the Fibonacci numbers
/// repeat with period 60, so only F((n + 2) % 60) % 10 has to be computed.
///
/// Any `n < 2` (including negative values) is returned unchanged.
pub fn last_digit_of_the_sum_of_nth_fibonacci_number(n: i64) -> i64 {
    if n < 2 {
        return n;
    }

    // the pisano period of mod 10 is 60; reduce `n` first so that adding 2
    // cannot overflow for values close to `i64::MAX`
    let index = (n % 60 + 2) % 60;

    // Walk the sequence with two running values. `index` may be 0 (for
    // example n = 58), which a table of `index + 1` entries with `fib[1]`
    // pre-set could not handle.
    let mut last_digit = 0;
    let mut next_digit = 1;
    for _ in 0..index {
        let following = (last_digit + next_digit) % 10;
        last_digit = next_digit;
        next_digit = following;
    }

    // (F(n + 2) - 1) mod 10, kept non-negative when the last digit is 0
    (last_digit + 9) % 10
}
127127879743834334146972278486287885163 + ); + } + + #[test] + /// Check that the iterative and recursive fibonacci + /// produce the same value. Both are combinatorial ( F(0) = F(1) = 1 ) + fn test_iterative_and_recursive_equivalence() { + assert_eq!(fibonacci(0), recursive_fibonacci(0)); + assert_eq!(fibonacci(1), recursive_fibonacci(1)); + assert_eq!(fibonacci(2), recursive_fibonacci(2)); + assert_eq!(fibonacci(3), recursive_fibonacci(3)); + assert_eq!(fibonacci(4), recursive_fibonacci(4)); + assert_eq!(fibonacci(5), recursive_fibonacci(5)); + assert_eq!(fibonacci(10), recursive_fibonacci(10)); + assert_eq!(fibonacci(20), recursive_fibonacci(20)); + assert_eq!(fibonacci(100), recursive_fibonacci(100)); + assert_eq!(fibonacci(184), recursive_fibonacci(184)); + } + + #[test] + /// Check that classical and combinatorial fibonacci produce the + /// same value when 'n' differs by 1. + /// classical fibonacci: ( F(0) = 0, F(1) = 1 ) + /// combinatorial fibonacci: ( F(0) = F(1) = 1 ) + fn test_classical_and_combinatorial_are_off_by_one() { + assert_eq!(classical_fibonacci(1), fibonacci(0)); + assert_eq!(classical_fibonacci(2), fibonacci(1)); + assert_eq!(classical_fibonacci(3), fibonacci(2)); + assert_eq!(classical_fibonacci(4), fibonacci(3)); + assert_eq!(classical_fibonacci(5), fibonacci(4)); + assert_eq!(classical_fibonacci(6), fibonacci(5)); + assert_eq!(classical_fibonacci(11), fibonacci(10)); + assert_eq!(classical_fibonacci(20), fibonacci(19)); + assert_eq!(classical_fibonacci(21), fibonacci(20)); + assert_eq!(classical_fibonacci(101), fibonacci(100)); + assert_eq!(classical_fibonacci(185), fibonacci(184)); + } + + #[test] + fn test_memoized_fibonacci() { + assert_eq!(memoized_fibonacci(0), 0); + assert_eq!(memoized_fibonacci(1), 1); + assert_eq!(memoized_fibonacci(2), 1); + assert_eq!(memoized_fibonacci(3), 2); + assert_eq!(memoized_fibonacci(4), 3); + assert_eq!(memoized_fibonacci(5), 5); + assert_eq!(memoized_fibonacci(10), 55); + 
assert_eq!(memoized_fibonacci(20), 6765); + assert_eq!(memoized_fibonacci(21), 10946); + assert_eq!(memoized_fibonacci(100), 354224848179261915075); + assert_eq!( + memoized_fibonacci(184), + 127127879743834334146972278486287885163 + ); + } + + #[test] + fn test_matrix_fibonacci() { + assert_eq!(matrix_fibonacci(0), 0); + assert_eq!(matrix_fibonacci(1), 1); + assert_eq!(matrix_fibonacci(2), 1); + assert_eq!(matrix_fibonacci(3), 2); + assert_eq!(matrix_fibonacci(4), 3); + assert_eq!(matrix_fibonacci(5), 5); + assert_eq!(matrix_fibonacci(10), 55); + assert_eq!(matrix_fibonacci(20), 6765); + assert_eq!(matrix_fibonacci(21), 10946); + assert_eq!(matrix_fibonacci(100), 354224848179261915075); + assert_eq!( + matrix_fibonacci(184), + 127127879743834334146972278486287885163 + ); + } + + #[test] + fn test_binary_lifting_fibonacci() { + assert_eq!(binary_lifting_fibonacci(0), 0); + assert_eq!(binary_lifting_fibonacci(1), 1); + assert_eq!(binary_lifting_fibonacci(2), 1); + assert_eq!(binary_lifting_fibonacci(3), 2); + assert_eq!(binary_lifting_fibonacci(4), 3); + assert_eq!(binary_lifting_fibonacci(5), 5); + assert_eq!(binary_lifting_fibonacci(10), 55); + assert_eq!(binary_lifting_fibonacci(20), 6765); + assert_eq!(binary_lifting_fibonacci(21), 10946); + assert_eq!(binary_lifting_fibonacci(100), 354224848179261915075); + assert_eq!( + binary_lifting_fibonacci(184), + 127127879743834334146972278486287885163 + ); + } + + #[test] + fn test_nth_fibonacci_number_modulo_m() { + assert_eq!(nth_fibonacci_number_modulo_m(5, 10), 5); + assert_eq!(nth_fibonacci_number_modulo_m(10, 7), 6); + assert_eq!(nth_fibonacci_number_modulo_m(20, 100), 65); + assert_eq!(nth_fibonacci_number_modulo_m(1, 5), 1); + assert_eq!(nth_fibonacci_number_modulo_m(0, 15), 0); + assert_eq!(nth_fibonacci_number_modulo_m(50, 1000), 25); + assert_eq!(nth_fibonacci_number_modulo_m(100, 37), 7); + assert_eq!(nth_fibonacci_number_modulo_m(15, 2), 0); + assert_eq!(nth_fibonacci_number_modulo_m(8, 1_000_000), 21); + 
/// Solves the fractional knapsack problem greedily.
///
/// Items are taken in decreasing order of value/weight ratio. Every item that
/// still fits is taken whole; the first item that does not fit is taken
/// fractionally to fill the remaining capacity, and the search stops.
///
/// # Arguments
///
/// * `capacity` - total weight the knapsack can hold
/// * `weights` - weight of each item
/// * `values` - value of each item, paired with `weights` by position
///   (extra entries in the longer vector are ignored)
///
/// # Returns
///
/// The maximum total value that fits into the knapsack.
///
/// # Panics
///
/// Panics with "Encountered NaN" if a value/weight ratio that takes part in a
/// comparison is NaN.
pub fn fractional_knapsack(mut capacity: f64, weights: Vec<f64>, values: Vec<f64>) -> f64 {
    // (weight, value, value/weight ratio) for every item
    let mut items: Vec<(f64, f64, f64)> = weights
        .iter()
        .zip(values.iter())
        .map(|(&weight, &value)| (weight, value, value / weight))
        .collect();

    // sort in decreasing order by value/weight ratio
    items.sort_unstable_by(|a, b| b.2.partial_cmp(&a.2).expect("Encountered NaN"));

    let mut knapsack_value: f64 = 0.0;

    for (weight, value, ratio) in items {
        if weight <= capacity {
            // The whole item fits: add its value directly. Multiplying the
            // weight by the ratio would give NaN for a zero-weight item
            // (0 * inf) and lose precision otherwise.
            capacity -= weight;
            knapsack_value += value;
        } else {
            // Only a fraction fits: fill what is left of the capacity.
            knapsack_value += capacity * ratio;
            break;
        }
    }

    knapsack_value
}
//! A module for checking if one string is a subsequence of another string.
//!
//! A subsequence is formed by deleting some (can be none) of the characters
//! from the original string without disturbing the relative positions of the
//! remaining characters. This module provides a function to determine if
//! a given string is a subsequence of another string.

/// Reports whether the characters of `sub` occur in `main` in the same order.
///
/// # Arguments
///
/// * `sub` - The candidate subsequence.
/// * `main` - The string that is searched.
///
/// # Returns
///
/// `true` when every character of `sub` can be matched, left to right, against
/// characters of `main`; `false` as soon as `main` runs out first. An empty
/// `sub` is a subsequence of every string.
pub fn is_subsequence(sub: &str, main: &str) -> bool {
    // One shared cursor over `main`: each search resumes right after the
    // previous match, which is what keeps the relative order intact.
    let mut remaining = main.chars();
    sub.chars()
        .all(|wanted| remaining.any(|candidate| candidate == wanted))
}
{ + test_empty_subsequence: ("", "ahbgdc", true), + test_empty_strings: ("", "", true), + test_non_empty_sub_empty_main: ("abc", "", false), + test_subsequence_found: ("abc", "ahbgdc", true), + test_subsequence_not_found: ("axc", "ahbgdc", false), + test_longer_sub: ("abcd", "abc", false), + test_single_character_match: ("a", "ahbgdc", true), + test_single_character_not_match: ("x", "ahbgdc", false), + test_subsequence_at_start: ("abc", "abchello", true), + test_subsequence_at_end: ("cde", "abcde", true), + test_same_characters: ("aaa", "aaaaa", true), + test_interspersed_subsequence: ("ace", "abcde", true), + test_different_chars_in_subsequence: ("aceg", "abcdef", false), + test_single_character_in_main_not_match: ("a", "b", false), + test_single_character_in_main_match: ("b", "b", true), + test_subsequence_with_special_chars: ("a1!c", "a1!bcd", true), + test_case_sensitive: ("aBc", "abc", false), + test_subsequence_with_whitespace: ("hello world", "h e l l o w o r l d", true), + } +} diff --git a/src/dynamic_programming/knapsack.rs b/src/dynamic_programming/knapsack.rs index b9edf1bda88..36876a15bd8 100644 --- a/src/dynamic_programming/knapsack.rs +++ b/src/dynamic_programming/knapsack.rs @@ -1,148 +1,363 @@ -//! Solves the knapsack problem -use std::cmp::max; +//! This module provides functionality to solve the knapsack problem using dynamic programming. +//! It includes structures for items and solutions, and functions to compute the optimal solution. 
use std::cmp::Ordering;

/// Represents an item with a weight and a value.
///
/// The fields are public so that code outside this module can build the
/// input for [`knapsack`].
#[derive(Debug, PartialEq, Eq)]
pub struct Item {
    /// Weight of the item, in the same unit as the knapsack capacity.
    pub weight: usize,
    /// Profit gained by packing the item.
    pub value: usize,
}

/// Represents the solution to the knapsack problem.
///
/// The fields are public so that callers outside this module can read the result.
#[derive(Debug, PartialEq, Eq)]
pub struct KnapsackSolution {
    /// The optimal profit obtained.
    pub optimal_profit: usize,
    /// The total weight of items included in the solution.
    pub total_weight: usize,
    /// The 1-based indices of the items included in the solution, in ascending order.
    pub item_indices: Vec<usize>,
}

/// Solves the 0/1 knapsack problem and returns the optimal profit, total weight, and indices of items included.
///
/// # Arguments:
/// * `capacity` - The maximum weight capacity of the knapsack.
/// * `items` - A vector of `Item` structs, each representing an item with weight and value.
///
/// # Returns:
/// A `KnapsackSolution` struct containing:
/// - `optimal_profit` - The maximum profit achievable with the given capacity and items.
/// - `total_weight` - The total weight of items included in the solution.
/// - `item_indices` - 1-based indices of the included items, ascending; each item appears at most once.
///
/// # Note:
/// An empty `items` vector is accepted and yields a solution with zero profit,
/// zero weight and no indices.
///
/// # Complexity:
/// - Time complexity: O(num_items * capacity)
/// - Space complexity: O(num_items * capacity)
///
/// where `num_items` is the number of items and `capacity` is the knapsack capacity.
pub fn knapsack(capacity: usize, items: Vec<Item>) -> KnapsackSolution {
    let num_items = items.len();
    let item_weights: Vec<usize> = items.iter().map(|item| item.weight).collect();
    let item_values: Vec<usize> = items.iter().map(|item| item.value).collect();

    let knapsack_matrix = generate_knapsack_matrix(capacity, &item_weights, &item_values);
    let item_indices =
        retrieve_knapsack_items(&item_weights, &knapsack_matrix, num_items, capacity);

    // Indices are 1-based, hence the `- 1` when looking the weight up.
    let total_weight: usize = item_indices
        .iter()
        .map(|&index| item_weights[index - 1])
        .sum();

    KnapsackSolution {
        optimal_profit: knapsack_matrix[num_items][capacity],
        total_weight,
        item_indices,
    }
}

/// Generates the knapsack matrix with `num_items + 1` rows and `capacity + 1` columns.
///
/// Entry `[i][c]` is the best profit achievable with the first `i` items and
/// capacity `c`. Row 0 and column 0 stay at zero.
///
/// # Arguments:
/// * `capacity` - knapsack capacity
/// * `item_weights` - weights of each item
/// * `item_values` - values of each item
fn generate_knapsack_matrix(
    capacity: usize,
    item_weights: &[usize],
    item_values: &[usize],
) -> Vec<Vec<usize>> {
    let num_items = item_weights.len();
    let mut matrix = vec![vec![0; capacity + 1]; num_items + 1];

    for item_index in 1..=num_items {
        let weight = item_weights[item_index - 1];
        let value = item_values[item_index - 1];

        for current_capacity in 1..=capacity {
            let without_item = matrix[item_index - 1][current_capacity];
            matrix[item_index][current_capacity] = if weight <= current_capacity {
                // Either skip the item or pack it on top of the best
                // solution for the capacity that remains.
                without_item.max(value + matrix[item_index - 1][current_capacity - weight])
            } else {
                without_item
            };
        }
    }

    matrix
}

/// Retrieves the indices of items included in the optimal knapsack solution.
///
/// Walks the matrix from the last considered item back to the first: an item
/// was packed exactly when including it raised the profit for the capacity
/// still available.
///
/// # Arguments:
/// * `item_weights` - weights of each item
/// * `knapsack_matrix` - knapsack matrix with maximum values
/// * `item_index` - number of items to consider (initially the total number of items)
/// * `remaining_capacity` - remaining capacity of the knapsack
///
/// # Returns
/// The 1-based indices of the packed items in ascending order.
fn retrieve_knapsack_items(
    item_weights: &[usize],
    knapsack_matrix: &[Vec<usize>],
    item_index: usize,
    mut remaining_capacity: usize,
) -> Vec<usize> {
    let mut selected = Vec::new();

    for index in (1..=item_index).rev() {
        let current_value = knapsack_matrix[index][remaining_capacity];
        let previous_value = knapsack_matrix[index - 1][remaining_capacity];

        match current_value.cmp(&previous_value) {
            Ordering::Greater => {
                selected.push(index);
                remaining_capacity -= item_weights[index - 1];
            }
            Ordering::Equal | Ordering::Less => {}
        }
    }

    // Items were collected from last to first.
    selected.reverse();
    selected
}
knapsack_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (capacity, items, expected) = $test_case; + assert_eq!(expected, knapsack(capacity, items)); + } + )* + } } - #[test] - fn test_p07() { - assert_eq!( - (1458, 749, vec![1, 3, 5, 7, 8, 9, 14, 15]), - knapsack( - 750, - vec![70, 73, 77, 80, 82, 87, 90, 94, 98, 106, 110, 113, 115, 118, 120], - vec![135, 139, 149, 150, 156, 163, 173, 184, 192, 201, 210, 214, 221, 229, 240] - ) - ); + knapsack_tests! { + test_basic_knapsack_small: ( + 165, + vec![ + Item { weight: 23, value: 92 }, + Item { weight: 31, value: 57 }, + Item { weight: 29, value: 49 }, + Item { weight: 44, value: 68 }, + Item { weight: 53, value: 60 }, + Item { weight: 38, value: 43 }, + Item { weight: 63, value: 67 }, + Item { weight: 85, value: 84 }, + Item { weight: 89, value: 87 }, + Item { weight: 82, value: 72 } + ], + KnapsackSolution { + optimal_profit: 309, + total_weight: 165, + item_indices: vec![1, 2, 3, 4, 6] + } + ), + test_basic_knapsack_tiny: ( + 26, + vec![ + Item { weight: 12, value: 24 }, + Item { weight: 7, value: 13 }, + Item { weight: 11, value: 23 }, + Item { weight: 8, value: 15 }, + Item { weight: 9, value: 16 } + ], + KnapsackSolution { + optimal_profit: 51, + total_weight: 26, + item_indices: vec![2, 3, 4] + } + ), + test_basic_knapsack_medium: ( + 190, + vec![ + Item { weight: 56, value: 50 }, + Item { weight: 59, value: 50 }, + Item { weight: 80, value: 64 }, + Item { weight: 64, value: 46 }, + Item { weight: 75, value: 50 }, + Item { weight: 17, value: 5 } + ], + KnapsackSolution { + optimal_profit: 150, + total_weight: 190, + item_indices: vec![1, 2, 5] + } + ), + test_diverse_weights_values_small: ( + 50, + vec![ + Item { weight: 31, value: 70 }, + Item { weight: 10, value: 20 }, + Item { weight: 20, value: 39 }, + Item { weight: 19, value: 37 }, + Item { weight: 4, value: 7 }, + Item { weight: 3, value: 5 }, + Item { weight: 6, value: 10 } + ], + KnapsackSolution { + 
optimal_profit: 107, + total_weight: 50, + item_indices: vec![1, 4] + } + ), + test_diverse_weights_values_medium: ( + 104, + vec![ + Item { weight: 25, value: 350 }, + Item { weight: 35, value: 400 }, + Item { weight: 45, value: 450 }, + Item { weight: 5, value: 20 }, + Item { weight: 25, value: 70 }, + Item { weight: 3, value: 8 }, + Item { weight: 2, value: 5 }, + Item { weight: 2, value: 5 } + ], + KnapsackSolution { + optimal_profit: 900, + total_weight: 104, + item_indices: vec![1, 3, 4, 5, 7, 8] + } + ), + test_high_value_items: ( + 170, + vec![ + Item { weight: 41, value: 442 }, + Item { weight: 50, value: 525 }, + Item { weight: 49, value: 511 }, + Item { weight: 59, value: 593 }, + Item { weight: 55, value: 546 }, + Item { weight: 57, value: 564 }, + Item { weight: 60, value: 617 } + ], + KnapsackSolution { + optimal_profit: 1735, + total_weight: 169, + item_indices: vec![2, 4, 7] + } + ), + test_large_knapsack: ( + 750, + vec![ + Item { weight: 70, value: 135 }, + Item { weight: 73, value: 139 }, + Item { weight: 77, value: 149 }, + Item { weight: 80, value: 150 }, + Item { weight: 82, value: 156 }, + Item { weight: 87, value: 163 }, + Item { weight: 90, value: 173 }, + Item { weight: 94, value: 184 }, + Item { weight: 98, value: 192 }, + Item { weight: 106, value: 201 }, + Item { weight: 110, value: 210 }, + Item { weight: 113, value: 214 }, + Item { weight: 115, value: 221 }, + Item { weight: 118, value: 229 }, + Item { weight: 120, value: 240 } + ], + KnapsackSolution { + optimal_profit: 1458, + total_weight: 749, + item_indices: vec![1, 3, 5, 7, 8, 9, 14, 15] + } + ), + test_zero_capacity: ( + 0, + vec![ + Item { weight: 1, value: 1 }, + Item { weight: 2, value: 2 }, + Item { weight: 3, value: 3 } + ], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_very_small_capacity: ( + 1, + vec![ + Item { weight: 10, value: 1 }, + Item { weight: 20, value: 2 }, + Item { weight: 30, value: 3 } + ], + 
KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_no_items: ( + 1, + vec![], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_item_too_heavy: ( + 1, + vec![ + Item { weight: 2, value: 100 } + ], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_greedy_algorithm_does_not_work: ( + 10, + vec![ + Item { weight: 10, value: 15 }, + Item { weight: 6, value: 7 }, + Item { weight: 4, value: 9 } + ], + KnapsackSolution { + optimal_profit: 16, + total_weight: 10, + item_indices: vec![2, 3] + } + ), + test_greedy_algorithm_does_not_work_weight_smaller_than_capacity: ( + 10, + vec![ + Item { weight: 10, value: 15 }, + Item { weight: 1, value: 9 }, + Item { weight: 2, value: 7 } + ], + KnapsackSolution { + optimal_profit: 16, + total_weight: 3, + item_indices: vec![2, 3] + } + ), } } diff --git a/src/dynamic_programming/longest_common_subsequence.rs b/src/dynamic_programming/longest_common_subsequence.rs index a92ad50e26e..58f82714f93 100644 --- a/src/dynamic_programming/longest_common_subsequence.rs +++ b/src/dynamic_programming/longest_common_subsequence.rs @@ -1,73 +1,116 @@ -/// Longest common subsequence via Dynamic Programming +//! This module implements the Longest Common Subsequence (LCS) algorithm. +//! The LCS problem is finding the longest subsequence common to two sequences. +//! It differs from the problem of finding common substrings: unlike substrings, subsequences +//! are not required to occupy consecutive positions within the original sequences. +//! This implementation handles Unicode strings efficiently and correctly, ensuring +//! that multi-byte characters are managed properly. -/// longest_common_subsequence(a, b) returns the longest common subsequence -/// between the strings a and b. 
-pub fn longest_common_subsequence(a: &str, b: &str) -> String { - let a: Vec<_> = a.chars().collect(); - let b: Vec<_> = b.chars().collect(); - let (na, nb) = (a.len(), b.len()); +/// Computes the longest common subsequence of two input strings. +/// +/// The longest common subsequence (LCS) of two strings is the longest sequence that can +/// be derived from both strings by deleting some elements without changing the order of +/// the remaining elements. +/// +/// ## Note +/// The function may return different LCSs for the same pair of strings depending on the +/// order of the inputs and the nature of the sequences. This is due to the way the dynamic +/// programming algorithm resolves ties when multiple common subsequences of the same length +/// exist. The order of the input strings can influence the specific path taken through the +/// DP table, resulting in different valid LCS outputs. +/// +/// For example: +/// `longest_common_subsequence("hello, world!", "world, hello!")` returns `"hello!"` +/// but +/// `longest_common_subsequence("world, hello!", "hello, world!")` returns `"world!"` +/// +/// This difference arises because the dynamic programming table is filled differently based +/// on the input order, leading to different tie-breaking decisions and thus different LCS results. 
/// Returns one longest common subsequence of `first_seq` and `second_seq`.
///
/// Both inputs are compared character by character (not byte by byte). When
/// several subsequences of maximal length exist, which one is returned depends
/// on the order of the arguments.
pub fn longest_common_subsequence(first_seq: &str, second_seq: &str) -> String {
    let first: Vec<char> = first_seq.chars().collect();
    let second: Vec<char> = second_seq.chars().collect();

    let table = initialize_lcs_lengths(&first, &second);
    reconstruct_lcs(&first, &second, &table).iter().collect()
}

/// Builds the table of LCS lengths.
///
/// Entry `[i][j]` holds the LCS length of the first `i` characters of
/// `first_seq_chars` and the first `j` characters of `second_seq_chars`.
fn initialize_lcs_lengths(first_seq_chars: &[char], second_seq_chars: &[char]) -> Vec<Vec<usize>> {
    let mut lcs_lengths = vec![vec![0; second_seq_chars.len() + 1]; first_seq_chars.len() + 1];

    for (row, &first_char) in first_seq_chars.iter().enumerate() {
        for (col, &second_char) in second_seq_chars.iter().enumerate() {
            lcs_lengths[row + 1][col + 1] = if first_char == second_char {
                lcs_lengths[row][col] + 1
            } else {
                lcs_lengths[row][col + 1].max(lcs_lengths[row + 1][col])
            };
        }
    }

    lcs_lengths
}

/// Walks the length table backwards from the bottom-right corner and collects
/// the characters of one longest common subsequence.
fn reconstruct_lcs(
    first_seq_chars: &[char],
    second_seq_chars: &[char],
    lcs_lengths: &[Vec<usize>],
) -> Vec<char> {
    let mut row = first_seq_chars.len();
    let mut col = second_seq_chars.len();
    let mut reversed = Vec::new();

    while row != 0 && col != 0 {
        let candidate = first_seq_chars[row - 1];
        if candidate == second_seq_chars[col - 1] {
            reversed.push(candidate);
            row -= 1;
            col -= 1;
            continue;
        }

        // On a tie, move up (drop a character of the first sequence).
        if lcs_lengths[row - 1][col] < lcs_lengths[row][col - 1] {
            col -= 1;
        } else {
            row -= 1;
        }
    }

    reversed.into_iter().rev().collect()
}
{ + empty_case: ("", "", ""), + one_empty: ("", "abcd", ""), + identical_strings: ("abcd", "abcd", "abcd"), + completely_different: ("abcd", "efgh", ""), + single_character: ("a", "a", "a"), + different_length: ("abcd", "abc", "abc"), + special_characters: ("$#%&", "#@!%", "#%"), + long_strings: ("abcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgh", + "bcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgha", + "bcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgh"), + unicode_characters: ("你好,世界", "再见,世界", ",世界"), + spaces_and_punctuation_0: ("hello, world!", "world, hello!", "hello!"), + spaces_and_punctuation_1: ("hello, world!", "world, hello!", "hello!"), // longest_common_subsequence is not symmetric + random_case_1: ("abcdef", "xbcxxxe", "bce"), + random_case_2: ("xyz", "abc", ""), + random_case_3: ("abracadabra", "avadakedavra", "aaadara"), } } diff --git a/src/dynamic_programming/longest_common_substring.rs b/src/dynamic_programming/longest_common_substring.rs new file mode 100644 index 00000000000..52b858e5008 --- /dev/null +++ b/src/dynamic_programming/longest_common_substring.rs @@ -0,0 +1,70 @@ +//! This module provides a function to find the length of the longest common substring +//! between two strings using dynamic programming. + +/// Finds the length of the longest common substring between two strings using dynamic programming. +/// +/// The algorithm uses a 2D dynamic programming table where each cell represents +/// the length of the longest common substring ending at the corresponding indices in +/// the two input strings. The maximum value in the DP table is the result, i.e., the +/// length of the longest common substring. +/// +/// The time complexity is `O(n * m)`, where `n` and `m` are the lengths of the two strings. +/// # Arguments +/// +/// * `s1` - The first input string. +/// * `s2` - The second input string. 
/// Finds the length of the longest common substring of two strings.
///
/// A substring is a contiguous run of characters. The comparison is done on
/// Unicode scalar values (`char`s), not on raw UTF-8 bytes, so two different
/// multi-byte characters that merely share a leading byte are never counted
/// as a match. The returned length is therefore measured in characters.
///
/// The dynamic programming recurrence is: `len[i + 1][j + 1] = len[i][j] + 1`
/// when the `i`-th character of `s1` equals the `j`-th character of `s2`, and
/// `0` otherwise. Only the previous row is needed to fill the current one, so
/// two rolling rows are kept instead of the full table.
///
/// # Arguments
///
/// * `s1` - The first input string.
/// * `s2` - The second input string.
///
/// # Returns
///
/// The number of characters in the longest substring shared by `s1` and `s2`
/// (`0` when they have no character in common or either is empty).
///
/// # Complexity
///
/// * Time: `O(n * m)` where `n` and `m` are the character counts of the inputs.
/// * Space: `O(m)`.
pub fn longest_common_substring(s1: &str, s2: &str) -> usize {
    let second: Vec<char> = s2.chars().collect();

    // Row 0 / column 0 stay at zero and act as the "empty prefix" border.
    let mut previous_row = vec![0_usize; second.len() + 1];
    let mut current_row = vec![0_usize; second.len() + 1];
    let mut max_len = 0_usize;

    for c1 in s1.chars() {
        for (j, &c2) in second.iter().enumerate() {
            // Every cell 1..=m is overwritten here, so stale values left in
            // `current_row` by the swap below never leak into the result.
            current_row[j + 1] = if c1 == c2 { previous_row[j] + 1 } else { 0 };
            max_len = max_len.max(current_row[j + 1]);
        }
        std::mem::swap(&mut previous_row, &mut current_row);
    }

    max_len
}
/// Finds the longest continuous non-decreasing run in a slice.
///
/// A run is a contiguous subslice in which every element is less than **or
/// equal to** the element that follows it; a run is broken only by a strictly
/// smaller element. (Despite the function name, equal neighbours extend the
/// run rather than ending it.)
///
/// # Arguments
///
/// * `arr` - The slice to scan.
///
/// # Returns
///
/// A subslice of `arr` holding the longest such run. When several runs share
/// the maximum length, the first one is returned. Slices with zero or one
/// element are returned unchanged.
pub fn longest_continuous_increasing_subsequence<T: Ord>(arr: &[T]) -> &[T] {
    if arr.len() <= 1 {
        return arr;
    }

    let mut best_start = 0;
    let mut best_len = 1;
    let mut run_start = 0;

    for i in 1..arr.len() {
        // A strictly smaller element starts a new run at `i`.
        if arr[i - 1] > arr[i] {
            run_start = i;
        }

        // Strict `>` keeps the earliest run when lengths tie.
        let run_len = i - run_start + 1;
        if run_len > best_len {
            best_len = run_len;
            best_start = run_start;
        }
    }

    &arr[best_start..best_start + best_len]
}
{ + empty_array: (&[] as &[isize], &[] as &[isize]), + single_element: (&[1], &[1]), + all_increasing: (&[1, 2, 3, 4, 5], &[1, 2, 3, 4, 5]), + all_decreasing: (&[5, 4, 3, 2, 1], &[5]), + with_equal_elements: (&[1, 2, 2, 3, 4, 2], &[1, 2, 2, 3, 4]), + increasing_with_plateau: (&[1, 2, 2, 2, 3, 3, 4], &[1, 2, 2, 2, 3, 3, 4]), + mixed_elements: (&[5, 4, 3, 4, 2, 1], &[3, 4]), + alternating_increase_decrease: (&[1, 2, 1, 2, 1, 2], &[1, 2]), + zigzag: (&[1, 3, 2, 4, 3, 5], &[1, 3]), + single_negative_element: (&[-1], &[-1]), + negative_and_positive_mixed: (&[-2, -1, 0, 1, 2, 3], &[-2, -1, 0, 1, 2, 3]), + increasing_then_decreasing: (&[1, 2, 3, 4, 3, 2, 1], &[1, 2, 3, 4]), + single_increasing_subsequence_later: (&[3, 2, 1, 1, 2, 3, 4], &[1, 1, 2, 3, 4]), + longer_subsequence_at_start: (&[5, 6, 7, 8, 9, 2, 3, 4, 5], &[5, 6, 7, 8, 9]), + longer_subsequence_at_end: (&[2, 3, 4, 10, 5, 6, 7, 8, 9], &[5, 6, 7, 8, 9]), + longest_subsequence_at_start: (&[2, 3, 4, 5, 1, 0], &[2, 3, 4, 5]), + longest_subsequence_at_end: (&[1, 7, 2, 3, 4, 5,], &[2, 3, 4, 5]), + repeated_elements: (&[1, 1, 1, 1, 1], &[1, 1, 1, 1, 1]), + } +} diff --git a/src/dynamic_programming/longest_increasing_subsequence.rs b/src/dynamic_programming/longest_increasing_subsequence.rs new file mode 100644 index 00000000000..ed58500135a --- /dev/null +++ b/src/dynamic_programming/longest_increasing_subsequence.rs @@ -0,0 +1,109 @@ +/// Finds the longest increasing subsequence and returns it. +/// +/// If multiple subsequences with the longest possible subsequence length can be found, the +/// subsequence which appeared first will be returned (see `test_example_1`). +/// +/// Inspired by [this LeetCode problem](https://leetcode.com/problems/longest-increasing-subsequence/). 
/// Finds a longest strictly increasing subsequence and returns it.
///
/// The elements of the result appear in the same relative order as in
/// `input_array` but need not be adjacent there.
///
/// The algorithm keeps, for every achievable length `k + 1`, the smallest tail
/// value seen so far together with the (1-based) position where it occurred.
/// Each new element either extends the longest sequence or replaces the first
/// tail that is not smaller than it (found by binary search). A `previous`
/// table records each element's predecessor so the sequence can be rebuilt.
///
/// # Arguments
///
/// * `input_array` - The sequence to analyse.
///
/// # Returns
///
/// The elements of one longest strictly increasing subsequence. Inputs with
/// zero or one element are returned as-is.
///
/// # Complexity
///
/// * Time: `O(n log n)`.
/// * Space: `O(n)`.
pub fn longest_increasing_subsequence<T: Ord + Clone>(input_array: &[T]) -> Vec<T> {
    let n = input_array.len();
    if n <= 1 {
        return input_array.to_vec();
    }

    // tails[k] = (smallest tail of an increasing subsequence of length k + 1,
    //             1-based index of that tail in `input_array`).
    let mut tails: Vec<(T, usize)> = vec![(input_array[0].clone(), 1)];
    // previous[i] = index of the predecessor of element i; `previous[i] == i`
    // marks the start of a subsequence.
    let mut previous = vec![0_usize; n];

    for (i, value) in input_array.iter().enumerate().skip(1) {
        let (extends_longest, last_index) = {
            let (last_value, last_index) = tails.last().expect("tails is never empty");
            (value > last_value, *last_index)
        };

        if extends_longest {
            previous[i] = last_index - 1;
            tails.push((value.clone(), i + 1));
            continue;
        }

        // Stored indices are >= 1, so `(value, 0)` sorts before every entry
        // with the same value: the search yields the first tail >= `value`.
        let position = tails
            .binary_search(&(value.clone(), 0))
            .unwrap_or_else(|insert_at| insert_at);
        tails[position] = (value.clone(), i + 1);
        previous[i] = match position {
            0 => i,
            other => tails[other - 1].1 - 1,
        };
    }

    // Walk the predecessor chain back from the tail of the longest sequence.
    let (tail_value, tail_index) = tails.last().expect("tails is never empty");
    let mut out: Vec<T> = Vec::with_capacity(tails.len());
    out.push(tail_value.clone());
    let mut current_index = tail_index - 1;
    while previous[current_index] != current_index {
        current_index = previous[current_index];
        out.push(input_array[current_index].clone());
    }

    out.reverse();
    out
}
/// Errors reported by [`matrix_chain_multiply`].
#[derive(Debug, PartialEq)]
pub enum MatrixChainMultiplicationError {
    /// The dimension list was empty.
    EmptyDimensions,
    /// The dimension list had a single entry, which does not describe a matrix.
    InsufficientDimensions,
}

/// Calculates the minimum number of scalar multiplications required to
/// multiply a chain of matrices.
///
/// # Arguments
///
/// * `dimensions` - Consecutive matrix dimensions. `[1, 2, 3, 4]` describes the
///   chain `(1x2) * (2x3) * (3x4)`; `k + 1` entries describe `k` matrices.
///
/// # Returns
///
/// The cost of the cheapest parenthesisation. A single matrix (two entries)
/// costs `0`.
///
/// # Errors
///
/// * [`MatrixChainMultiplicationError::EmptyDimensions`] if `dimensions` is empty.
/// * [`MatrixChainMultiplicationError::InsufficientDimensions`] if it has one entry.
///
/// # Complexity
///
/// `O(n^3)` time and `O(n^2)` space, where `n` is `dimensions.len()`.
pub fn matrix_chain_multiply(
    dimensions: Vec<usize>,
) -> Result<usize, MatrixChainMultiplicationError> {
    let n = dimensions.len();
    match n {
        0 => return Err(MatrixChainMultiplicationError::EmptyDimensions),
        1 => return Err(MatrixChainMultiplicationError::InsufficientDimensions),
        _ => {}
    }

    // min_operations[start][end] = cheapest cost of multiplying the matrices
    // whose boundaries are dimensions[start..=end].
    let mut min_operations = vec![vec![0_usize; n]; n];

    for chain_len in 2..n {
        for start in 0..n - chain_len {
            let end = start + chain_len;
            // `start + 1..end` is never empty because `chain_len >= 2`.
            let best = (start + 1..end)
                .map(|split| {
                    min_operations[start][split]
                        + min_operations[split][end]
                        + dimensions[start] * dimensions[split] * dimensions[end]
                })
                .min()
                .unwrap_or(usize::MAX);
            min_operations[start][end] = best;
        }
    }

    Ok(min_operations[0][n - 1])
}
/// Maximal Square
///
/// Given an `m` x `n` binary matrix filled with 0's and 1's, finds the largest
/// square containing only 1's and returns its area.
///
/// The matrix doubles as the dynamic programming table: every interior cell
/// holding `1` is overwritten with the side length of the largest all-ones
/// square whose bottom-right corner is that cell. **The input is therefore
/// modified by the call.**
///
/// # Arguments
///
/// * `matrix` - Rows of the binary matrix. The column count is taken from the
///   first row, so every row is expected to be at least that long.
///
/// # Returns
///
/// The area (side length squared) of the largest all-ones square, or `0` when
/// the matrix is empty or contains no `1`.
///
/// # Complexity
///
/// * Time: `O(m * n)`.
/// * Extra space: `O(1)` (the computation is done in place).
pub fn maximal_square(matrix: &mut [Vec<i32>]) -> i32 {
    if matrix.is_empty() {
        return 0;
    }

    let rows = matrix.len();
    let cols = matrix[0].len();
    let mut largest_side: i32 = 0;

    for row in 0..rows {
        for col in 0..cols {
            if matrix[row][col] != 1 {
                continue;
            }

            // Cells on the first row/column can only form a 1x1 square, so
            // they keep their value; interior cells grow from their top,
            // left and top-left neighbours.
            if row > 0 && col > 0 {
                let smallest_neighbour = matrix[row - 1][col - 1]
                    .min(matrix[row - 1][col])
                    .min(matrix[row][col - 1]);
                matrix[row][col] = smallest_neighbour + 1;
            }

            largest_side = largest_side.max(matrix[row][col]);
        }
    }

    largest_side * largest_side
}
/// Custom error type for [`maximum_subarray`].
#[derive(Debug, PartialEq)]
pub enum MaximumSubarrayError {
    /// The input slice contained no elements.
    EmptyArray,
}

/// Finds the contiguous subarray (containing at least one number) with the
/// largest sum and returns that sum.
///
/// For each position the function tracks the best sum of a subarray ending
/// there: either the previous best extended by the current element, or the
/// current element on its own.
///
/// # Arguments
///
/// * `array` - A slice of integers.
///
/// # Returns
///
/// * `Ok(isize)` - the largest sum of a contiguous, non-empty subarray.
/// * `Err(MaximumSubarrayError::EmptyArray)` - if `array` is empty.
///
/// # Complexity
///
/// * Time complexity: `O(array.len())`
/// * Space complexity: `O(1)`
pub fn maximum_subarray(array: &[isize]) -> Result<isize, MaximumSubarrayError> {
    let (&first, rest) = array
        .split_first()
        .ok_or(MaximumSubarrayError::EmptyArray)?;

    let mut cur_sum = first;
    let mut max_sum = first;

    for &x in rest {
        // Restart at `x` whenever the running sum would only drag it down.
        cur_sum = (cur_sum + x).max(x);
        max_sum = max_sum.max(cur_sum);
    }

    Ok(max_sum)
}
/// Represents possible errors that can occur when calculating the minimum cost
/// path in a matrix.
#[derive(Debug, PartialEq, Eq)]
pub enum MatrixError {
    /// Error indicating that the matrix is empty or has empty rows.
    EmptyMatrix,
    /// Error indicating that the matrix is not rectangular in shape.
    NonRectangularMatrix,
}

/// Computes the minimum cost path from the top-left to the bottom-right corner
/// of a matrix, where movement is restricted to right and down directions.
///
/// # Arguments
///
/// * `matrix` - A 2D vector of non-negative integers, where each element is the
///   cost of stepping on that cell.
///
/// # Returns
///
/// * `Ok(usize)` - The minimum total cost, including both corner cells.
/// * `Err(MatrixError::EmptyMatrix)` - If there are no rows or the rows are empty.
/// * `Err(MatrixError::NonRectangularMatrix)` - If rows differ in length.
///
/// # Complexity
///
/// * Time complexity: `O(m * n)`, where `m` is the number of rows and `n` is
///   the number of columns.
/// * Space complexity: `O(n)`, as only a single row of cumulative costs is
///   stored at any time.
pub fn minimum_cost_path(matrix: Vec<Vec<usize>>) -> Result<usize, MatrixError> {
    let width = match matrix.first() {
        None => return Err(MatrixError::EmptyMatrix),
        Some(first_row) => first_row.len(),
    };

    // Shape is validated before emptiness so a mix of empty and non-empty
    // rows is reported as non-rectangular.
    if matrix.iter().any(|row| row.len() != width) {
        return Err(MatrixError::NonRectangularMatrix);
    }
    if width == 0 {
        return Err(MatrixError::EmptyMatrix);
    }

    // First row: the only way to reach a cell is from its left neighbour, so
    // the costs are a running sum.
    let mut cost: Vec<usize> = matrix[0]
        .iter()
        .scan(0, |acc, &val| {
            *acc += val;
            Some(*acc)
        })
        .collect();

    for row in matrix.iter().skip(1) {
        // First column: reachable only from above.
        cost[0] += row[0];

        // `cost[col - 1]` already holds this row's value (coming from the
        // left); `cost[col]` still holds the previous row's (coming from above).
        for col in 1..width {
            cost[col] = row[col] + cost[col - 1].min(cost[col]);
        }
    }

    Ok(cost[width - 1])
}
{ + basic: ( + vec![ + vec![2, 1, 4], + vec![2, 1, 3], + vec![3, 2, 1] + ], + Ok(7) + ), + single_element: ( + vec![ + vec![5] + ], + Ok(5) + ), + single_row: ( + vec![ + vec![1, 3, 2, 1, 5] + ], + Ok(12) + ), + single_column: ( + vec![ + vec![1], + vec![3], + vec![2], + vec![1], + vec![5] + ], + Ok(12) + ), + large_matrix: ( + vec![ + vec![1, 3, 1, 5], + vec![2, 1, 4, 2], + vec![3, 2, 1, 3], + vec![4, 3, 2, 1] + ], + Ok(10) + ), + uniform_matrix: ( + vec![ + vec![1, 1, 1], + vec![1, 1, 1], + vec![1, 1, 1] + ], + Ok(5) + ), + increasing_values: ( + vec![ + vec![1, 2, 3], + vec![4, 5, 6], + vec![7, 8, 9] + ], + Ok(21) + ), + high_cost_path: ( + vec![ + vec![1, 100, 1], + vec![1, 100, 1], + vec![1, 1, 1] + ], + Ok(5) + ), + complex_matrix: ( + vec![ + vec![5, 9, 6, 8], + vec![1, 4, 7, 3], + vec![2, 1, 8, 2], + vec![3, 6, 9, 4] + ], + Ok(23) + ), + empty_matrix: ( + vec![], + Err(MatrixError::EmptyMatrix) + ), + empty_row: ( + vec![ + vec![], + vec![], + vec![] + ], + Err(MatrixError::EmptyMatrix) + ), + non_rectangular: ( + vec![ + vec![1, 2, 3], + vec![4, 5], + vec![6, 7, 8] + ], + Err(MatrixError::NonRectangularMatrix) + ), + } +} diff --git a/src/dynamic_programming/mod.rs b/src/dynamic_programming/mod.rs index 8fe40643477..f18c1847479 100644 --- a/src/dynamic_programming/mod.rs +++ b/src/dynamic_programming/mod.rs @@ -1,14 +1,49 @@ mod coin_change; -mod edit_distance; mod egg_dropping; mod fibonacci; +mod fractional_knapsack; +mod is_subsequence; mod knapsack; mod longest_common_subsequence; +mod longest_common_substring; +mod longest_continuous_increasing_subsequence; +mod longest_increasing_subsequence; +mod matrix_chain_multiply; +mod maximal_square; +mod maximum_subarray; +mod minimum_cost_path; +mod optimal_bst; +mod rod_cutting; +mod snail; +mod subset_generation; +mod trapped_rainwater; +mod word_break; pub use self::coin_change::coin_change; -pub use self::edit_distance::{edit_distance, edit_distance_se}; pub use self::egg_dropping::egg_drop; +pub use 
self::fibonacci::binary_lifting_fibonacci; +pub use self::fibonacci::classical_fibonacci; pub use self::fibonacci::fibonacci; +pub use self::fibonacci::last_digit_of_the_sum_of_nth_fibonacci_number; +pub use self::fibonacci::logarithmic_fibonacci; +pub use self::fibonacci::matrix_fibonacci; +pub use self::fibonacci::memoized_fibonacci; +pub use self::fibonacci::nth_fibonacci_number_modulo_m; pub use self::fibonacci::recursive_fibonacci; +pub use self::fractional_knapsack::fractional_knapsack; +pub use self::is_subsequence::is_subsequence; pub use self::knapsack::knapsack; pub use self::longest_common_subsequence::longest_common_subsequence; +pub use self::longest_common_substring::longest_common_substring; +pub use self::longest_continuous_increasing_subsequence::longest_continuous_increasing_subsequence; +pub use self::longest_increasing_subsequence::longest_increasing_subsequence; +pub use self::matrix_chain_multiply::matrix_chain_multiply; +pub use self::maximal_square::maximal_square; +pub use self::maximum_subarray::maximum_subarray; +pub use self::minimum_cost_path::minimum_cost_path; +pub use self::optimal_bst::optimal_search_tree; +pub use self::rod_cutting::rod_cut; +pub use self::snail::snail; +pub use self::subset_generation::list_subset; +pub use self::trapped_rainwater::trapped_rainwater; +pub use self::word_break::word_break; diff --git a/src/dynamic_programming/optimal_bst.rs b/src/dynamic_programming/optimal_bst.rs new file mode 100644 index 00000000000..162351a21c6 --- /dev/null +++ b/src/dynamic_programming/optimal_bst.rs @@ -0,0 +1,93 @@ +// Optimal Binary Search Tree Algorithm in Rust +// Time Complexity: O(n^3) with prefix sum optimization +// Space Complexity: O(n^2) for the dp table and prefix sum array + +/// Constructs an Optimal Binary Search Tree from a list of key frequencies. +/// The goal is to minimize the expected search cost given key access frequencies. 
/// Computes the cost of an optimal binary search tree for the given key
/// access frequencies.
///
/// Keys are assumed to be in sorted order, with `freq[k]` the access frequency
/// of the `k`-th key. The cost of a tree is the sum over all keys of
/// `frequency * depth`, with the root at depth 1.
///
/// # Arguments
/// * `freq` - A slice of integers representing the frequency of key access
///
/// # Returns
/// * The minimum cost of a BST storing all keys, or `0` for an empty slice
///
/// # Complexity
/// * Time: `O(n^3)`; range sums are answered in `O(1)` from a prefix-sum array
/// * Space: `O(n^2)` for the cost table
pub fn optimal_search_tree(freq: &[i32]) -> i32 {
    let n = freq.len();
    if n == 0 {
        return 0;
    }

    // cumulative[k] = freq[0] + ... + freq[k - 1]
    let mut cumulative = Vec::with_capacity(n + 1);
    cumulative.push(0);
    for &f in freq {
        let running_total = cumulative[cumulative.len() - 1];
        cumulative.push(running_total + f);
    }

    // cost[lo][hi] = cost of the best tree built from keys lo..=hi
    let mut cost = vec![vec![0; n]; n];
    for (k, &f) in freq.iter().enumerate() {
        cost[k][k] = f;
    }

    for span in 2..=n {
        for lo in 0..=n - span {
            let hi = lo + span - 1;
            // Every key in lo..=hi sinks one level when placed under a root,
            // which adds the whole range's frequency once.
            let weight = cumulative[hi + 1] - cumulative[lo];

            let cheapest = (lo..=hi)
                .map(|root| {
                    let left = if root > lo { cost[lo][root - 1] } else { 0 };
                    let right = if root < hi { cost[root + 1][hi] } else { 0 };
                    left + right + weight
                })
                .min()
                .unwrap_or(i32::MAX);

            cost[lo][hi] = cheapest;
        }
    }

    cost[0][n - 1]
}
/// Calculates the maximum possible profit from cutting a rod into pieces.
///
/// `prices[k]` is the price of a piece of length `k + 1`; the rod to cut has
/// length `prices.len()`. The rod may be sold whole or cut into any number of
/// pieces, each sold at the price for its length.
///
/// # Returns
///
/// The largest total price achievable, or `0` when `prices` is empty.
///
/// # Complexity
/// - Time complexity: `O(n^2)`
/// - Space complexity: `O(n)`
///
/// where `n` is the number of different rod lengths considered.
pub fn rod_cut(prices: &[usize]) -> usize {
    let rod_length = prices.len();
    if rod_length == 0 {
        return 0;
    }

    // best[len] = maximum profit obtainable from a rod of length `len`
    let mut best = vec![0_usize; rod_length + 1];

    for length in 1..=rod_length {
        // Start from selling this length uncut, then try every first piece.
        let mut top = prices[length - 1];
        for first_piece in 1..=length {
            top = top.max(prices[first_piece - 1] + best[length - first_piece]);
        }
        best[length] = top;
    }

    best[rod_length]
}
{ + test_empty_prices: (&[], 0), + test_example_with_three_prices: (&[5, 8, 2], 15), + test_example_with_four_prices: (&[1, 5, 8, 9], 10), + test_example_with_five_prices: (&[5, 8, 2, 1, 7], 25), + test_all_zeros_except_last: (&[0, 0, 0, 0, 0, 87], 87), + test_descending_prices: (&[7, 6, 5, 4, 3, 2, 1], 49), + test_varied_prices: (&[1, 5, 8, 9, 10, 17, 17, 20], 22), + test_complex_prices: (&[6, 4, 8, 2, 5, 8, 2, 3, 7, 11], 60), + test_increasing_prices: (&[1, 5, 8, 9, 10, 17, 17, 20, 24, 30], 30), + test_large_range_prices: (&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 12), + test_single_length_price: (&[5], 5), + test_zero_length_price: (&[0], 0), + test_repeated_prices: (&[5, 5, 5, 5], 20), + test_no_profit: (&[0, 0, 0, 0], 0), + test_large_input: (&[1; 1000], 1000), + test_all_zero_input: (&[0; 100], 0), + test_very_large_prices: (&[1000000, 2000000, 3000000], 3000000), + test_greedy_does_not_work: (&[2, 5, 7, 8], 10), + } +} diff --git a/src/dynamic_programming/snail.rs b/src/dynamic_programming/snail.rs new file mode 100644 index 00000000000..8b1a8358f0b --- /dev/null +++ b/src/dynamic_programming/snail.rs @@ -0,0 +1,148 @@ +/// ## Spiral Sorting +/// +/// Given an n x m array, return the array elements arranged from outermost elements +/// to the middle element, traveling INWARD FROM TOP-LEFT, CLOCKWISE. 
/// ## Spiral Sorting
///
/// Given an n x m array, returns the array elements arranged from the
/// outermost elements to the middle element, traveling inward from the
/// top-left corner, clockwise.
///
/// The traversal peels the matrix one layer at a time. The not-yet-visited
/// region is the half-open rectangle `rows top..bottom`, `columns left..right`;
/// each pass emits its top row, right column, bottom row and left column and
/// then shrinks the rectangle. All bounds are adjusted only after checking the
/// rectangle is non-empty, so single-row and single-column inputs cannot
/// underflow an index.
///
/// # Arguments
///
/// * `matrix` - Rows of the matrix. The column count is taken from the first
///   row, so every row is expected to be at least that long.
///
/// # Returns
///
/// The elements in clockwise spiral order, or an empty vector when the matrix
/// has no rows or its first row is empty.
pub fn snail<T: Copy>(matrix: &[Vec<T>]) -> Vec<T> {
    // break on empty matrix
    if matrix.is_empty() || matrix[0].is_empty() {
        return vec![];
    }

    let row_count = matrix.len();
    let col_count = matrix[0].len();

    let (mut top, mut bottom) = (0, row_count);
    let (mut left, mut right) = (0, col_count);
    let mut result = Vec::with_capacity(row_count * col_count);

    while top < bottom && left < right {
        // Top row, left to right.
        result.extend((left..right).map(|col| matrix[top][col]));
        top += 1;

        // Right column, top to bottom.
        result.extend((top..bottom).map(|row| matrix[row][right - 1]));
        right -= 1;

        // Bottom row, right to left (only if a row is still unvisited).
        if top < bottom {
            result.extend((left..right).rev().map(|col| matrix[bottom - 1][col]));
            bottom -= 1;
        }

        // Left column, bottom to top (only if a column is still unvisited).
        if left < right {
            result.extend((top..bottom).rev().map(|row| matrix[row][left]));
            left += 1;
        }
    }

    result
}
/// Lists every `r`-element combination of the first `n` elements of `set`.
///
/// This is a recursive include/exclude enumeration: at each step the element
/// at `index` is either written into the next free slot of `data` or skipped.
/// Combinations are produced in the order the elements appear in `set`.
///
/// # Arguments
///
/// * `set` - The source elements.
/// * `n` - How many leading elements of `set` to consider. Values larger than
///   `set.len()` are treated as `set.len()`.
/// * `r` - The size of each combination.
/// * `index` - Position in `set` of the next candidate; pass `0` initially.
/// * `data` - Scratch buffer holding the combination being built. It must have
///   room for at least `r` elements.
/// * `i` - Number of slots of `data` already filled; pass `0` initially.
///
/// # Returns
///
/// All combinations reachable from the current state. The result is empty
/// when `r` exceeds the number of available elements.
pub fn list_subset(
    set: &[i32],
    n: usize,
    r: usize,
    index: usize,
    data: &mut [i32],
    i: usize,
) -> Vec<Vec<i32>> {
    // Current subset is complete: snapshot the first `r` slots.
    if i == r {
        return vec![data.iter().take(r).copied().collect()];
    }

    // No more elements left to place (also guards against `n > set.len()`).
    if index >= n.min(set.len()) {
        return Vec::new();
    }

    // Include the current element, then fill the next slot.
    data[i] = set[index];
    let mut res = list_subset(set, n, r, index + 1, data, i + 1);

    // Exclude the current element: the same slot `i` is reused for the next
    // candidate.
    res.extend(list_subset(set, n, r, index + 1, data, i));

    res
}
/// Computes the total volume of trapped rainwater in a given elevation map.
///
/// Water above a position is bounded by the lower of the highest bar to its
/// left and the highest bar to its right (both inclusive of the position
/// itself), minus the terrain height there.
///
/// # Arguments
///
/// * `elevation_map` - A slice containing the heights of the terrain elevations.
///
/// # Returns
///
/// The total volume of trapped rainwater (`0` for an empty map).
pub fn trapped_rainwater(elevation_map: &[u32]) -> u32 {
    let left_max = calculate_max_values(elevation_map, false);
    let right_max = calculate_max_values(elevation_map, true);

    // Both maxima include the position itself, so `min(left, right) >= height`
    // and the subtraction cannot underflow.
    elevation_map
        .iter()
        .zip(left_max.iter().zip(right_max.iter()))
        .map(|(&height, (&left, &right))| left.min(right) - height)
        .sum()
}

/// Determines the running maximum heights from either direction in the
/// elevation map.
///
/// # Arguments
///
/// * `elevation_map` - A slice representing the heights of the terrain elevations.
/// * `reverse` - The direction of calculation.
///     - `false` for left-to-right.
///     - `true` for right-to-left.
///
/// # Returns
///
/// A vector whose entry `i` is the maximum height encountered up to and
/// including position `i` when scanning in the chosen direction.
fn calculate_max_values(elevation_map: &[u32], reverse: bool) -> Vec<u32> {
    let mut max_values = vec![0_u32; elevation_map.len()];
    let mut current_max = 0_u32;
    for i in create_iter(elevation_map.len(), reverse) {
        current_max = current_max.max(elevation_map[i]);
        max_values[i] = current_max;
    }
    max_values
}

/// Creates an iterator over the indices `0..len`, optionally reversed.
///
/// # Arguments
///
/// * `len` - The number of indices to produce.
/// * `reverse` - The order of iteration.
///     - `false` for forward iteration.
///     - `true` for reverse iteration.
///
/// # Returns
///
/// A boxed iterator, so that both directions share one return type.
fn create_iter(len: usize, reverse: bool) -> Box<dyn Iterator<Item = usize>> {
    if reverse {
        Box::new((0..len).rev())
    } else {
        Box::new(0..len)
    }
}
{ + test_trapped_rainwater_basic: ( + [0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1], + 6 + ), + test_trapped_rainwater_peak_under_water: ( + [3, 0, 2, 0, 4], + 7, + ), + test_bucket: ( + [5, 1, 5], + 4 + ), + test_skewed_bucket: ( + [4, 1, 5], + 3 + ), + test_trapped_rainwater_empty: ( + [], + 0 + ), + test_trapped_rainwater_flat: ( + [0, 0, 0, 0, 0], + 0 + ), + test_trapped_rainwater_no_trapped_water: ( + [1, 1, 2, 4, 0, 0, 0], + 0 + ), + test_trapped_rainwater_single_elevation_map: ( + [5], + 0 + ), + test_trapped_rainwater_two_point_elevation_map: ( + [5, 1], + 0 + ), + test_trapped_rainwater_large_elevation_map_difference: ( + [5, 1, 6, 1, 7, 1, 8], + 15 + ), + } +} diff --git a/src/dynamic_programming/word_break.rs b/src/dynamic_programming/word_break.rs new file mode 100644 index 00000000000..f3153525f6e --- /dev/null +++ b/src/dynamic_programming/word_break.rs @@ -0,0 +1,89 @@ +use crate::data_structures::Trie; + +/// Checks if a string can be segmented into a space-separated sequence +/// of one or more words from the given dictionary. +/// +/// # Arguments +/// * `s` - The input string to be segmented. +/// * `word_dict` - A slice of words forming the dictionary. +/// +/// # Returns +/// * `bool` - `true` if the string can be segmented, `false` otherwise. +pub fn word_break(s: &str, word_dict: &[&str]) -> bool { + let mut trie = Trie::new(); + for &word in word_dict { + trie.insert(word.chars(), true); + } + + // Memoization vector: one extra space to handle out-of-bound end case. + let mut memo = vec![None; s.len() + 1]; + search(&trie, s, 0, &mut memo) +} + +/// Recursively checks if the substring starting from `start` can be segmented +/// using words in the trie and memoizes the results. +/// +/// # Arguments +/// * `trie` - The Trie containing the dictionary words. +/// * `s` - The input string. +/// * `start` - The starting index for the current substring. +/// * `memo` - A vector for memoization to store intermediate results. 
+/// +/// # Returns +/// * `bool` - `true` if the substring can be segmented, `false` otherwise. +fn search(trie: &Trie, s: &str, start: usize, memo: &mut Vec>) -> bool { + if start == s.len() { + return true; + } + + if let Some(res) = memo[start] { + return res; + } + + for end in start + 1..=s.len() { + if trie.get(s[start..end].chars()).is_some() && search(trie, s, end, memo) { + memo[start] = Some(true); + return true; + } + } + + memo[start] = Some(false); + false +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, dict, expected) = $test_case; + assert_eq!(word_break(input, &dict), expected); + } + )* + } + } + + test_cases! { + typical_case_1: ("applepenapple", vec!["apple", "pen"], true), + typical_case_2: ("catsandog", vec!["cats", "dog", "sand", "and", "cat"], false), + typical_case_3: ("cars", vec!["car", "ca", "rs"], true), + edge_case_empty_string: ("", vec!["apple", "pen"], true), + edge_case_empty_dict: ("apple", vec![], false), + edge_case_single_char_in_dict: ("a", vec!["a"], true), + edge_case_single_char_not_in_dict: ("b", vec!["a"], false), + edge_case_all_words_larger_than_input: ("a", vec!["apple", "banana"], false), + edge_case_no_solution_large_string: ("abcdefghijklmnoqrstuv", vec!["a", "bc", "def", "ghij", "klmno", "pqrst"], false), + successful_segmentation_large_string: ("abcdefghijklmnopqrst", vec!["a", "bc", "def", "ghij", "klmno", "pqrst"], true), + long_string_repeated_pattern: (&"ab".repeat(100), vec!["a", "b", "ab"], true), + long_string_no_solution: (&"a".repeat(100), vec!["b"], false), + mixed_size_dict_1: ("pineapplepenapple", vec!["apple", "pen", "applepen", "pine", "pineapple"], true), + mixed_size_dict_2: ("catsandog", vec!["cats", "dog", "sand", "and", "cat"], false), + mixed_size_dict_3: ("abcd", vec!["a", "abc", "b", "cd"], true), + performance_stress_test_large_valid: (&"abc".repeat(1000), vec!["a", "ab", 
/// Reasons why [`present_value`] can refuse its input.
#[derive(PartialEq, Eq, Debug)]
pub enum PresentValueError {
    /// The discount rate was below zero.
    NegetiveDiscount,
    /// No cash flows were supplied.
    EmptyCashFlow,
}

/// In economics and finance, present value (PV), also known as present discounted value,
/// is the value of an expected income stream determined as of the date of valuation.
///
/// The cash flow at position `i` is discounted by `(1 + discount_rate)^i`, so
/// the first cash flow is taken at face value. The sum is rounded to two
/// decimal places.
///
/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Present_value
///
/// # Errors
///
/// * [`PresentValueError::NegetiveDiscount`] if `discount_rate < 0.0`.
/// * [`PresentValueError::EmptyCashFlow`] if `cash_flows` is empty.
pub fn present_value(discount_rate: f64, cash_flows: Vec<f64>) -> Result<f64, PresentValueError> {
    if discount_rate < 0.0 {
        return Err(PresentValueError::NegetiveDiscount);
    }
    if cash_flows.is_empty() {
        return Err(PresentValueError::EmptyCashFlow);
    }

    let growth = 1.0 + discount_rate;
    // Counting periods directly as `i32` avoids a truncating `usize as i32` cast.
    let present_value = (0i32..)
        .zip(cash_flows.iter())
        .map(|(period, &cash_flow)| cash_flow / growth.powi(period))
        .sum::<f64>();

    Ok(round(present_value))
}

/// Rounds `value` to two decimal places.
fn round(value: f64) -> f64 {
    (value * 100.0).round() / 100.0
}
test_round { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $inputs; + assert_eq!(round(input), expected); + } + )* + } + } + + test_present_value! { + general_inputs1:((0.13, vec![10.0, 20.70, -293.0, 297.0]),4.69), + general_inputs2:((0.07, vec![-109129.39, 30923.23, 15098.93, 29734.0, 39.0]),-42739.63), + general_inputs3:((0.07, vec![109129.39, 30923.23, 15098.93, 29734.0, 39.0]), 175519.15), + zero_input:((0.0, vec![109129.39, 30923.23, 15098.93, 29734.0, 39.0]), 184924.55), + + } + + test_present_value_Err! { + negative_discount_rate:((-1.0, vec![10.0, 20.70, -293.0, 297.0]), PresentValueError::NegetiveDiscount), + empty_cash_flow:((1.0, vec![]), PresentValueError::EmptyCashFlow), + + } + test_round! { + test1:(0.55434, 0.55), + test2:(10.453, 10.45), + test3:(1111_f64, 1111_f64), + } +} diff --git a/src/general/convex_hull.rs b/src/general/convex_hull.rs index 272004adee2..e2a073b4e57 100644 --- a/src/general/convex_hull.rs +++ b/src/general/convex_hull.rs @@ -5,9 +5,9 @@ fn sort_by_min_angle(pts: &[(f64, f64)], min: &(f64, f64)) -> Vec<(f64, f64)> { .iter() .map(|x| { ( - ((x.1 - min.1) as f64).atan2((x.0 - min.0) as f64), + (x.1 - min.1).atan2(x.0 - min.0), // angle - ((x.1 - min.1) as f64).hypot((x.0 - min.0) as f64), + (x.1 - min.1).hypot(x.0 - min.0), // distance (we want the closest to be first) *x, ) @@ -70,7 +70,7 @@ mod tests { #[test] fn empty() { - assert_eq!(convex_hull_graham(&vec![]), vec![]); + assert_eq!(convex_hull_graham(&[]), vec![]); } #[test] diff --git a/src/general/fisher_yates_shuffle.rs b/src/general/fisher_yates_shuffle.rs new file mode 100644 index 00000000000..8ba448b2105 --- /dev/null +++ b/src/general/fisher_yates_shuffle.rs @@ -0,0 +1,25 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use crate::math::PCG32; + +const DEFAULT: u64 = 4294967296; + +fn gen_range(range: usize, generator: &mut PCG32) -> usize { + generator.get_u64() as usize % range +} + +pub fn 
fisher_yates_shuffle(array: &mut [i32]) { + let seed = match SystemTime::now().duration_since(UNIX_EPOCH) { + Ok(duration) => duration.as_millis() as u64, + Err(_) => DEFAULT, + }; + + let mut random_generator = PCG32::new_default(seed); + + let len = array.len(); + + for i in 0..(len - 2) { + let r = gen_range(len - i, &mut random_generator); + array.swap(i, i + r); + } +} diff --git a/src/general/genetic.rs b/src/general/genetic.rs new file mode 100644 index 00000000000..43221989a23 --- /dev/null +++ b/src/general/genetic.rs @@ -0,0 +1,461 @@ +use std::cmp::Ordering; +use std::collections::BTreeSet; +use std::fmt::Debug; + +/// The goal is to showcase how Genetic algorithms generically work +/// See: https://en.wikipedia.org/wiki/Genetic_algorithm for concepts + +/// This is the definition of a Chromosome for a genetic algorithm +/// We can picture this as "one contending solution to our problem" +/// It is generic over: +/// * Eval, which could be a float, or any other totally ordered type, so that we can rank solutions to our problem +/// * Rng: a random number generator (could be thread rng, etc.) +pub trait Chromosome { + /// Mutates this Chromosome, changing its genes + fn mutate(&mut self, rng: &mut Rng); + + /// Mixes this chromosome with another one + fn crossover(&self, other: &Self, rng: &mut Rng) -> Self; + + /// How well this chromosome fits the problem we're trying to solve + /// **The smaller the better it fits** (we could use abs(... - expected_value) for instance + fn fitness(&self) -> Eval; +} + +pub trait SelectionStrategy { + fn new(rng: Rng) -> Self; + + /// Selects a portion of the population for reproduction + /// Could be totally random ones or the ones that fit best, etc. 
+ /// This assumes the population is sorted by how it fits the solution (the first the better) + fn select<'a, Eval: Into, C: Chromosome>( + &mut self, + population: &'a [C], + ) -> (&'a C, &'a C); +} + +/// A roulette wheel selection strategy +/// https://en.wikipedia.org/wiki/Fitness_proportionate_selection +pub struct RouletteWheel { + rng: Rng, +} +impl SelectionStrategy for RouletteWheel { + fn new(rng: Rng) -> Self { + Self { rng } + } + + fn select<'a, Eval: Into, C: Chromosome>( + &mut self, + population: &'a [C], + ) -> (&'a C, &'a C) { + // We will assign a probability for every item in the population, based on its proportion towards the sum of all fitness + // This would work well for an increasing fitness function, but not in our case of a fitness function for which "lower is better" + // We thus need to take the reciprocal + let mut parents = Vec::with_capacity(2); + let fitnesses: Vec = population + .iter() + .filter_map(|individual| { + let fitness = individual.fitness().into(); + if individual.fitness().into() == 0.0 { + parents.push(individual); + None + } else { + Some(1.0 / fitness) + } + }) + .collect(); + if parents.len() == 2 { + return (parents[0], parents[1]); + } + let sum: f64 = fitnesses.iter().sum(); + let mut spin = self.rng.random_range(0.0..=sum); + for individual in population { + let fitness: f64 = individual.fitness().into(); + if spin <= fitness { + parents.push(individual); + if parents.len() == 2 { + return (parents[0], parents[1]); + } + } else { + spin -= fitness; + } + } + panic!("Could not select parents"); + } +} + +pub struct Tournament { + rng: Rng, +} +impl SelectionStrategy for Tournament { + fn new(rng: Rng) -> Self { + Self { rng } + } + + fn select<'a, Eval, C: Chromosome>( + &mut self, + population: &'a [C], + ) -> (&'a C, &'a C) { + if K < 2 { + panic!("K must be > 2"); + } + // This strategy is defined as the following: pick K chromosomes randomly, use the 2 that fits the best + // We assume the population is 
sorted + // This means we can draw K random (distinct) numbers between (0..population.len()) and return the chromosomes at the 2 lowest indices + let mut picked_indices = BTreeSet::new(); // will keep indices ordered + while picked_indices.len() < K { + picked_indices.insert(self.rng.random_range(0..population.len())); + } + let mut iter = picked_indices.into_iter(); + ( + &population[iter.next().unwrap()], + &population[iter.next().unwrap()], + ) + } +} + +type Comparator = Box Ordering>; +pub struct GeneticAlgorithm< + Rng: rand::Rng, + Eval: PartialOrd, + C: Chromosome, + Selection: SelectionStrategy, +> { + rng: Rng, // will be used to draw random numbers for initial population, mutations and crossovers + population: Vec, // the set of random solutions (chromosomes) + threshold: Eval, // Any chromosome fitting over this threshold is considered a valid solution + max_generations: usize, // So that we don't loop infinitely + mutation_chance: f64, // what's the probability a chromosome will mutate + crossover_chance: f64, // what's the probability two chromosomes will cross-over and give birth to a new chromosome + compare: Comparator, + selection: Selection, // how we will select parent chromosomes for crossing over, see `SelectionStrategy` +} + +pub struct GenericAlgorithmParams { + max_generations: usize, + mutation_chance: f64, + crossover_chance: f64, +} + +impl< + Rng: rand::Rng, + Eval: Into + PartialOrd + Debug, + C: Chromosome + Clone + Debug, + Selection: SelectionStrategy, + > GeneticAlgorithm +{ + pub fn init( + rng: Rng, + population: Vec, + threshold: Eval, + params: GenericAlgorithmParams, + compare: Comparator, + selection: Selection, + ) -> Self { + let GenericAlgorithmParams { + max_generations, + mutation_chance, + crossover_chance, + } = params; + Self { + rng, + population, + threshold, + max_generations, + mutation_chance, + crossover_chance, + compare, + selection, + } + } + + pub fn solve(&mut self) -> Option { + let mut generations = 1; // 
1st generation is our initial population + while generations <= self.max_generations { + // 1. Sort the population by fitness score, remember: the lower the better (so natural ordering) + self.population + .sort_by(|c1: &C, c2: &C| (self.compare)(&c1.fitness(), &c2.fitness())); + + // 2. Stop condition: we might have found a good solution + if let Some(solution) = self.population.first() { + if solution.fitness() <= self.threshold { + return Some(solution).cloned(); + } + } + + // 3. Apply random mutations to the whole population + for chromosome in self.population.iter_mut() { + if self.rng.random::() <= self.mutation_chance { + chromosome.mutate(&mut self.rng); + } + } + // 4. Select parents that will be mating to create new chromosomes + let mut new_population = Vec::with_capacity(self.population.len() + 1); + while new_population.len() < self.population.len() { + let (p1, p2) = self.selection.select(&self.population); + if self.rng.random::() <= self.crossover_chance { + let child = p1.crossover(p2, &mut self.rng); + new_population.push(child); + } else { + // keep parents + new_population.extend([p1.clone(), p2.clone()]); + } + } + if new_population.len() > self.population.len() { + // We might have added 2 parents + new_population.pop(); + } + self.population = new_population; + // 5. 
Rinse & Repeat until we find a proper solution or we reach the maximum number of generations + generations += 1; + } + None + } +} + +#[cfg(test)] +mod tests { + use crate::general::genetic::{ + Chromosome, GenericAlgorithmParams, GeneticAlgorithm, RouletteWheel, SelectionStrategy, + Tournament, + }; + use rand::rngs::ThreadRng; + use rand::{rng, Rng}; + use std::collections::HashMap; + use std::fmt::{Debug, Formatter}; + use std::ops::RangeInclusive; + + #[test] + #[ignore] // Too long and not deterministic enough to be part of CI, more of an example than a test + fn find_secret() { + let chars = 'a'..='z'; + let secret = "thisistopsecret".to_owned(); + // Note: we'll pick genes (a, b, c) in the range -10, 10 + #[derive(Clone)] + struct TestString { + chars: RangeInclusive, + secret: String, + genes: Vec, + } + impl TestString { + fn new(rng: &mut ThreadRng, secret: String, chars: RangeInclusive) -> Self { + let current = (0..secret.len()) + .map(|_| rng.random_range(chars.clone())) + .collect::>(); + + Self { + chars, + secret, + genes: current, + } + } + } + impl Debug for TestString { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.genes.iter().collect::()) + } + } + impl Chromosome for TestString { + fn mutate(&mut self, rng: &mut ThreadRng) { + // let's assume mutations happen completely randomly, one "gene" at a time (i.e. 
one char at a time) + let gene_idx = rng.random_range(0..self.secret.len()); + let new_char = rng.random_range(self.chars.clone()); + self.genes[gene_idx] = new_char; + } + + fn crossover(&self, other: &Self, rng: &mut ThreadRng) -> Self { + // Let's not assume anything here, simply mixing random genes from both parents + let genes = (0..self.secret.len()) + .map(|idx| { + if rng.random_bool(0.5) { + // pick gene from self + self.genes[idx] + } else { + // pick gene from other parent + other.genes[idx] + } + }) + .collect(); + Self { + chars: self.chars.clone(), + secret: self.secret.clone(), + genes, + } + } + + fn fitness(&self) -> i32 { + // We are just counting how many chars are distinct from secret + self.genes + .iter() + .zip(self.secret.chars()) + .filter(|(char, expected)| expected != *char) + .count() as i32 + } + } + let mut rng = rng(); + let pop_count = 1_000; + let mut population = Vec::with_capacity(pop_count); + for _ in 0..pop_count { + population.push(TestString::new(&mut rng, secret.clone(), chars.clone())); + } + let selection: Tournament<100, ThreadRng> = Tournament::new(rng.clone()); + let params = GenericAlgorithmParams { + max_generations: 100, + mutation_chance: 0.2, + crossover_chance: 0.4, + }; + let mut solver = + GeneticAlgorithm::init(rng, population, 0, params, Box::new(i32::cmp), selection); + let res = solver.solve(); + assert!(res.is_some()); + assert_eq!(res.unwrap().genes, secret.chars().collect::>()) + } + + #[test] + #[ignore] // Too long and not deterministic enough to be part of CI, more of an example than a test + fn solve_mastermind() { + #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] + enum ColoredPeg { + Red, + Yellow, + Green, + Blue, + White, + Black, + } + struct GuessAnswer { + right_pos: i32, // right color at the right pos + wrong_pos: i32, // right color, but at wrong pos + } + #[derive(Clone, Debug)] + struct CodeMaker { + // the player coming up with a secret code + code: [ColoredPeg; 4], + count_by_color: 
HashMap, + } + impl CodeMaker { + fn new(code: [ColoredPeg; 4]) -> Self { + let mut count_by_color = HashMap::with_capacity(4); + for peg in &code { + *count_by_color.entry(*peg).or_insert(0) += 1; + } + Self { + code, + count_by_color, + } + } + fn eval(&self, guess: &[ColoredPeg; 4]) -> GuessAnswer { + let mut right_pos = 0; + let mut wrong_pos = 0; + let mut idx_by_colors = self.count_by_color.clone(); + for (idx, color) in guess.iter().enumerate() { + if self.code[idx] == *color { + right_pos += 1; + let count = idx_by_colors.get_mut(color).unwrap(); + *count -= 1; // don't reuse to say "right color but wrong pos" + if *count == 0 { + idx_by_colors.remove(color); + } + } + } + for (idx, color) in guess.iter().enumerate() { + if self.code[idx] != *color { + // try to use another color + if let Some(count) = idx_by_colors.get_mut(color) { + *count -= 1; + if *count == 0 { + idx_by_colors.remove(color); + } + wrong_pos += 1; + } + } + } + GuessAnswer { + right_pos, + wrong_pos, + } + } + } + + #[derive(Clone)] + struct CodeBreaker { + maker: CodeMaker, // so that we can ask the code maker if our guess is good or not + guess: [ColoredPeg; 4], + } + impl Debug for CodeBreaker { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("{:?}", self.guess).as_str()) + } + } + fn random_color(rng: &mut ThreadRng) -> ColoredPeg { + match rng.random_range(0..=5) { + 0 => ColoredPeg::Red, + 1 => ColoredPeg::Yellow, + 2 => ColoredPeg::Green, + 3 => ColoredPeg::Blue, + 4 => ColoredPeg::White, + _ => ColoredPeg::Black, + } + } + fn random_guess(rng: &mut ThreadRng) -> [ColoredPeg; 4] { + std::array::from_fn(|_| random_color(rng)) + } + impl Chromosome for CodeBreaker { + fn mutate(&mut self, rng: &mut ThreadRng) { + // change one random color + let idx = rng.random_range(0..4); + self.guess[idx] = random_color(rng); + } + + fn crossover(&self, other: &Self, rng: &mut ThreadRng) -> Self { + Self { + maker: self.maker.clone(), + guess: 
/// The code assigned to one symbol.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug, Default)]
pub struct HuffmanValue {
    // For the `value` to overflow, the sum of frequencies should be bigger
    // than u64. So we should be safe here
    /// The encoded value; bit `i` is the branch taken at depth `i` (1 = right)
    pub value: u64,
    /// number of bits used (up to 64)
    pub bits: u32,
}

/// A node of the Huffman tree: either a leaf carrying a `symbol`, or an inner
/// node with both children.
pub struct HuffmanNode<T> {
    pub left: Option<Box<HuffmanNode<T>>>,
    pub right: Option<Box<HuffmanNode<T>>>,
    pub symbol: Option<T>,
    pub frequency: u64,
}

impl<T> PartialEq for HuffmanNode<T> {
    fn eq(&self, other: &Self) -> bool {
        self.frequency == other.frequency
    }
}

impl<T> PartialOrd for HuffmanNode<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<T> Eq for HuffmanNode<T> {}

impl<T> Ord for HuffmanNode<T> {
    /// Reversed on purpose: `BinaryHeap` is a max-heap, and the least
    /// frequent node has to come out first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.frequency.cmp(&other.frequency).reverse()
    }
}

impl<T: Clone + Copy + Ord> HuffmanNode<T> {
    /// Turn the tree into the map that can be used in encoding
    ///
    /// `height` is the depth of `node`, `path` the bits taken to reach it.
    /// Missing children are skipped rather than unwrapped.
    pub fn get_alphabet(
        height: u32,
        path: u64,
        node: &HuffmanNode<T>,
        map: &mut BTreeMap<T, HuffmanValue>,
    ) {
        match node.symbol {
            Some(s) => {
                map.insert(
                    s,
                    HuffmanValue {
                        value: path,
                        bits: height,
                    },
                );
            }
            None => {
                if let Some(left) = node.left.as_deref() {
                    Self::get_alphabet(height + 1, path, left, map);
                }
                if let Some(right) = node.right.as_deref() {
                    Self::get_alphabet(height + 1, path | (1u64 << height), right, map);
                }
            }
        }
    }
}

/// A Huffman code: the symbol -> code table plus the tree used for decoding.
pub struct HuffmanDictionary<T> {
    pub alphabet: BTreeMap<T, HuffmanValue>,
    pub root: HuffmanNode<T>,
}

impl<T: Clone + Copy + Ord> HuffmanDictionary<T> {
    /// The list of alphabet symbols and their respective frequency should
    /// be given as input
    ///
    /// An alphabet with a single symbol gets the one-bit code `0`, so that
    /// encoded data still records how many symbols it holds. An empty
    /// alphabet yields an empty dictionary.
    pub fn new(alphabet: &[(T, u64)]) -> Self {
        let mut queue: BinaryHeap<HuffmanNode<T>> = BinaryHeap::new();
        for (symbol, freq) in alphabet.iter() {
            queue.push(HuffmanNode {
                left: None,
                right: None,
                symbol: Some(*symbol),
                frequency: *freq,
            });
        }
        // Repeatedly merge the two least frequent nodes until one tree remains.
        while queue.len() > 1 {
            if let (Some(left), Some(right)) = (queue.pop(), queue.pop()) {
                let frequency = left.frequency + right.frequency;
                queue.push(HuffmanNode {
                    left: Some(Box::new(left)),
                    right: Some(Box::new(right)),
                    symbol: None,
                    frequency,
                });
            }
        }
        let root = queue.pop().unwrap_or_else(|| HuffmanNode {
            left: None,
            right: None,
            symbol: None,
            frequency: 0,
        });
        let mut alph: BTreeMap<T, HuffmanValue> = BTreeMap::new();
        match root.symbol {
            // A lone leaf sits at depth 0; give it a real, one-bit code.
            Some(only) => {
                alph.insert(only, HuffmanValue { value: 0, bits: 1 });
            }
            None => HuffmanNode::get_alphabet(0, 0, &root, &mut alph),
        }
        HuffmanDictionary {
            alphabet: alph,
            root,
        }
    }

    /// Encodes `data` symbol by symbol.
    ///
    /// # Panics
    ///
    /// Panics if `data` contains a symbol that is not part of the dictionary.
    pub fn encode(&self, data: &[T]) -> HuffmanEncoding {
        let mut result = HuffmanEncoding::new();
        data.iter()
            .for_each(|value| result.add_data(self.alphabet[value]));
        result
    }
}

/// A bit stream, stored least significant bit first in 64-bit words.
pub struct HuffmanEncoding {
    pub num_bits: u64,
    pub data: Vec<u64>,
}

impl Default for HuffmanEncoding {
    fn default() -> Self {
        Self::new()
    }
}

impl HuffmanEncoding {
    /// Creates an empty bit stream.
    pub fn new() -> Self {
        HuffmanEncoding {
            num_bits: 0,
            data: vec![0],
        }
    }

    /// Appends the `data.bits` low bits of `data.value` to the stream.
    #[inline]
    pub fn add_data(&mut self, data: HuffmanValue) {
        let shift = (self.num_bits & 63) as u32;
        let val = data.value;
        *self
            .data
            .last_mut()
            .expect("the bit buffer always holds at least one word") |= val.wrapping_shl(shift);
        if (shift + data.bits) >= 64 {
            // The current word is full: start the next one with whatever did
            // not fit. For `shift == 0` nothing spilled over, and a shift by
            // 64 must yield 0 rather than wrap around to a shift by 0.
            self.data.push(val.checked_shr(64 - shift).unwrap_or(0));
        }
        self.num_bits += data.bits as u64;
    }

    /// Reads bit number `pos` of the stream.
    fn get_bit(&self, pos: u64) -> bool {
        (self.data[(pos >> 6) as usize] & (1 << (pos & 63))) != 0
    }

    /// In case the encoding is invalid, `None` is returned
    pub fn decode<T: Clone + Copy + Ord>(&self, dict: &HuffmanDictionary<T>) -> Option<Vec<T>> {
        // Single-symbol dictionary: every symbol is the one-bit code `0`.
        if let Some(only) = dict.root.symbol {
            if (0..self.num_bits).any(|i| self.get_bit(i)) {
                return None;
            }
            return usize::try_from(self.num_bits)
                .ok()
                .map(|count| vec![only; count]);
        }

        let mut state = &dict.root;
        let mut result: Vec<T> = vec![];
        for i in 0..self.num_bits {
            // A leaf was reached by the previous bit: emit it and restart.
            if let Some(symbol) = state.symbol {
                result.push(symbol);
                state = &dict.root;
            }
            let next = if self.get_bit(i) {
                &state.right
            } else {
                &state.left
            };
            state = next.as_deref()?;
        }
        // The stream must end exactly on a leaf.
        if self.num_bits > 0 {
            result.push(state.symbol?);
        }
        Some(result)
    }
}
/// Finds the maximum subarray sum using Kadane's algorithm
/// (https://en.wikipedia.org/wiki/Maximum_subarray_problem).
///
/// # Arguments
///
/// * `nums` - The integers to scan.
///
/// # Returns
///
/// The largest sum over all non-empty contiguous runs of `nums`, or `0` when
/// `nums` is empty. Running sums saturate at the bounds of `i32` instead of
/// overflowing.
pub fn max_sub_array(nums: Vec<i32>) -> i32 {
    let mut values = nums.iter().copied();
    let first = match values.next() {
        Some(value) => value,
        None => return 0,
    };

    // Best sum of a run ending at the current element / best sum seen anywhere.
    let mut max_current = first;
    let mut max_global = first;

    for item in values {
        // Either extend the previous run or start afresh at `item`.
        max_current = item.max(max_current.saturating_add(item));
        max_global = max_global.max(max_current);
    }
    max_global
}
/// Finds the MEX (minimum excluded non-negative number) of the values
/// provided in `arr`
/// Uses [`BTreeSet`](std::collections::BTreeSet)
/// O(nlog(n)) implementation
pub fn mex_using_set(arr: &[i64]) -> i64 {
    // The MEX of n values is at most n, so 0..=n covers every candidate.
    let mut candidates: BTreeSet<i64> = (0..=arr.len()).map(|i| i as i64).collect();
    for x in arr {
        candidates.remove(x);
    }
    // n removals cannot empty a set of n + 1 candidates; the smallest
    // survivor is the answer.
    candidates
        .into_iter()
        .next()
        .expect("Some unknown error in mex_using_set")
}

/// Finds the MEX of the values provided in `arr`
/// Uses sorting
/// O(nlog(n)) implementation
pub fn mex_using_sort(arr: &[i64]) -> i64 {
    let mut sorted = arr.to_vec();
    sorted.sort_unstable();
    // Walk the values in ascending order: each time the current candidate is
    // met, move on to the next one. Negatives and duplicates never match.
    sorted
        .into_iter()
        .fold(0, |mex, x| if x == mex { mex + 1 } else { mex })
}
fn new() -> Self { + Self { + test_arrays: vec![ + vec![-1, 0, 1, 2, 3], + vec![-100, 0, 1, 2, 3, 5], + vec![-1000000, 0, 1, 2, 5], + vec![2, 0, 1, 2, 4], + vec![1, 2, 3, 0, 4], + vec![0, 1, 5, 2, 4, 3], + vec![0, 1, 2, 3, 4, 5, 6], + vec![0, 1, 2, 3, 4, 5, 6, 7], + vec![0, 1, 2, 3, 4, 5, 6, 7, 8], + ], + outputs: vec![4, 4, 3, 3, 5, 6, 7, 8, 9], + } + } + fn test_function(&self, f: fn(&[i64]) -> i64) { + for (nums, output) in self.test_arrays.iter().zip(self.outputs.iter()) { + assert_eq!(f(nums), *output); + } + } + } + #[test] + fn test_mex_using_set() { + let tests = MexTests::new(); + mex_using_set(&[1, 23, 3]); + tests.test_function(mex_using_set); + } + #[test] + fn test_mex_using_sort() { + let tests = MexTests::new(); + tests.test_function(mex_using_sort); + } +} diff --git a/src/general/mod.rs b/src/general/mod.rs index 00e75e950a6..3572b146f4a 100644 --- a/src/general/mod.rs +++ b/src/general/mod.rs @@ -1,8 +1,25 @@ mod convex_hull; +mod fisher_yates_shuffle; +mod genetic; mod hanoi; +mod huffman_encoding; +mod kadane_algorithm; mod kmeans; +mod mex; +mod permutations; +mod two_sum; pub use self::convex_hull::convex_hull_graham; +pub use self::fisher_yates_shuffle::fisher_yates_shuffle; +pub use self::genetic::GeneticAlgorithm; pub use self::hanoi::hanoi; +pub use self::huffman_encoding::{HuffmanDictionary, HuffmanEncoding}; +pub use self::kadane_algorithm::max_sub_array; pub use self::kmeans::f32::kmeans as kmeans_f32; pub use self::kmeans::f64::kmeans as kmeans_f64; +pub use self::mex::mex_using_set; +pub use self::mex::mex_using_sort; +pub use self::permutations::{ + heap_permute, permute, permute_unique, steinhaus_johnson_trotter_permute, +}; +pub use self::two_sum::two_sum; diff --git a/src/general/permutations/heap.rs b/src/general/permutations/heap.rs new file mode 100644 index 00000000000..b1c3b38d198 --- /dev/null +++ b/src/general/permutations/heap.rs @@ -0,0 +1,66 @@ +use std::fmt::Debug; + +/// Computes all permutations of an array using 
Heap's algorithm +/// Read `recurse_naive` first, since we're building on top of the same intuition +pub fn heap_permute(arr: &[T]) -> Vec> { + if arr.is_empty() { + return vec![vec![]]; + } + let n = arr.len(); + let mut collector = Vec::with_capacity((1..=n).product()); // collects the permuted arrays + let mut arr = arr.to_owned(); // Heap's algorithm needs to mutate the array + heap_recurse(&mut arr, n, &mut collector); + collector +} + +fn heap_recurse(arr: &mut [T], k: usize, collector: &mut Vec>) { + if k == 1 { + // same base-case as in the naive version + collector.push((*arr).to_owned()); + return; + } + // Remember the naive recursion. We did the following: swap(i, last), recurse, swap back(i, last) + // Heap's algorithm has a more clever way of permuting the elements so that we never need to swap back! + for i in 0..k { + // now deal with [a, b] + let swap_idx = if k % 2 == 0 { i } else { 0 }; + arr.swap(swap_idx, k - 1); + heap_recurse(arr, k - 1, collector); + } +} + +#[cfg(test)] +mod tests { + use quickcheck_macros::quickcheck; + + use crate::general::permutations::heap_permute; + use crate::general::permutations::tests::{ + assert_permutations, assert_valid_permutation, NotTooBigVec, + }; + + #[test] + fn test_3_different_values() { + let original = vec![1, 2, 3]; + let res = heap_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[test] + fn test_3_times_the_same_value() { + let original = vec![1, 1, 1]; + let res = heap_permute(&original); + assert_eq!(res.len(), 6); // 3! 
+ for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[quickcheck] + fn test_some_elements(NotTooBigVec { inner: original }: NotTooBigVec) { + let permutations = heap_permute(&original); + assert_permutations(&original, &permutations) + } +} diff --git a/src/general/permutations/mod.rs b/src/general/permutations/mod.rs new file mode 100644 index 00000000000..0893eacb8ec --- /dev/null +++ b/src/general/permutations/mod.rs @@ -0,0 +1,93 @@ +mod heap; +mod naive; +mod steinhaus_johnson_trotter; + +pub use self::heap::heap_permute; +pub use self::naive::{permute, permute_unique}; +pub use self::steinhaus_johnson_trotter::steinhaus_johnson_trotter_permute; + +#[cfg(test)] +mod tests { + use quickcheck::{Arbitrary, Gen}; + use std::collections::HashMap; + + pub fn assert_permutations(original: &[i32], permutations: &[Vec]) { + if original.is_empty() { + assert_eq!(vec![vec![] as Vec], permutations); + return; + } + let n = original.len(); + assert_eq!((1..=n).product::(), permutations.len()); // n! 
+ for permut in permutations { + assert_valid_permutation(original, permut); + } + } + + pub fn assert_valid_permutation(original: &[i32], permuted: &[i32]) { + assert_eq!(original.len(), permuted.len()); + let mut indices = HashMap::with_capacity(original.len()); + for value in original { + *indices.entry(*value).or_insert(0) += 1; + } + for permut_value in permuted { + let count = indices.get_mut(permut_value).unwrap_or_else(|| { + panic!("Value {permut_value} appears too many times in permutation") + }); + *count -= 1; // use this value + if *count == 0 { + indices.remove(permut_value); // so that we can simply check every value has been removed properly + } + } + assert!(indices.is_empty()) + } + + #[test] + fn test_valid_permutations() { + assert_valid_permutation(&[1, 2, 3], &[1, 2, 3]); + assert_valid_permutation(&[1, 2, 3], &[1, 3, 2]); + assert_valid_permutation(&[1, 2, 3], &[2, 1, 3]); + assert_valid_permutation(&[1, 2, 3], &[2, 3, 1]); + assert_valid_permutation(&[1, 2, 3], &[3, 1, 2]); + assert_valid_permutation(&[1, 2, 3], &[3, 2, 1]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_1() { + assert_valid_permutation(&[1, 2, 3], &[4, 2, 3]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_2() { + assert_valid_permutation(&[1, 2, 3], &[1, 4, 3]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_3() { + assert_valid_permutation(&[1, 2, 3], &[1, 2, 4]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_repeat() { + assert_valid_permutation(&[1, 2, 3], &[1, 2, 2]); + } + + /// A Data Structure for testing permutations + /// Holds a Vec with just a few items, so that it's not too long to compute permutations + #[derive(Debug, Clone)] + pub struct NotTooBigVec { + pub(crate) inner: Vec, // opaque type alias so that we can implement Arbitrary + } + + const MAX_SIZE: usize = 8; // 8! 
~= 40k permutations already + impl Arbitrary for NotTooBigVec { + fn arbitrary(g: &mut Gen) -> Self { + let size = usize::arbitrary(g) % MAX_SIZE; + let res = (0..size).map(|_| i32::arbitrary(g)).collect(); + NotTooBigVec { inner: res } + } + } +} diff --git a/src/general/permutations/naive.rs b/src/general/permutations/naive.rs new file mode 100644 index 00000000000..748d763555e --- /dev/null +++ b/src/general/permutations/naive.rs @@ -0,0 +1,135 @@ +use std::collections::HashSet; +use std::fmt::Debug; +use std::hash::Hash; + +/// Here's a basic (naive) implementation for generating permutations +pub fn permute(arr: &[T]) -> Vec> { + if arr.is_empty() { + return vec![vec![]]; + } + let n = arr.len(); + let count = (1..=n).product(); // n! permutations + let mut collector = Vec::with_capacity(count); // collects the permuted arrays + let mut arr = arr.to_owned(); // we'll need to mutate the array + + // the idea is the following: imagine [a, b, c] + // always swap an item with the last item, then generate all permutations from the first k characters + // permute_recurse(arr, k - 1, collector); // leave the last character alone, and permute the first k-1 characters + permute_recurse(&mut arr, n, &mut collector); + collector +} + +fn permute_recurse(arr: &mut Vec, k: usize, collector: &mut Vec>) { + if k == 1 { + collector.push(arr.to_owned()); + return; + } + for i in 0..k { + arr.swap(i, k - 1); // swap i with the last character + permute_recurse(arr, k - 1, collector); // collect the permutations of the rest + arr.swap(i, k - 1); // swap back to original + } +} + +/// A common variation of generating permutations is to generate only unique permutations +/// Of course, we could use the version above together with a Set as collector instead of a Vec. +/// But let's try something different: how can we avoid to generate duplicated permutations in the first place, can we tweak the algorithm above? 
+pub fn permute_unique(arr: &[T]) -> Vec> { + if arr.is_empty() { + return vec![vec![]]; + } + let n = arr.len(); + let count = (1..=n).product(); // n! permutations + let mut collector = Vec::with_capacity(count); // collects the permuted arrays + let mut arr = arr.to_owned(); // Heap's algorithm needs to mutate the array + permute_recurse_unique(&mut arr, n, &mut collector); + collector +} + +fn permute_recurse_unique( + arr: &mut Vec, + k: usize, + collector: &mut Vec>, +) { + // We have the same base-case as previously, whenever we reach the first element in the array, collect the result + if k == 1 { + collector.push(arr.to_owned()); + return; + } + // We'll keep the same idea (swap with last item, and generate all permutations for the first k - 1) + // But we'll have to be careful though: how would we generate duplicates? + // Basically if, when swapping i with k-1, we generate the exact same array as in a previous iteration + // Imagine [a, a, b] + // i = 0: + // Swap (a, b) => [b, a, a], fix 'a' as last, and generate all permutations of [b, a] => [b, a, a], [a, b, a] + // Swap Back to [a, a, b] + // i = 1: + // Swap(a, b) => [b, a, a], we've done that already!! + let mut swapped = HashSet::with_capacity(k); + for i in 0..k { + if swapped.contains(&arr[i]) { + continue; + } + swapped.insert(arr[i]); + arr.swap(i, k - 1); // swap i with the last character + permute_recurse_unique(arr, k - 1, collector); // collect the permutations + arr.swap(i, k - 1); // go back to original + } +} + +#[cfg(test)] +mod tests { + use crate::general::permutations::naive::{permute, permute_unique}; + use crate::general::permutations::tests::{ + assert_permutations, assert_valid_permutation, NotTooBigVec, + }; + use quickcheck_macros::quickcheck; + use std::collections::HashSet; + + #[test] + fn test_3_different_values() { + let original = vec![1, 2, 3]; + let res = permute(&original); + assert_eq!(res.len(), 6); // 3! 
+ for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[test] + fn empty_array() { + let empty: std::vec::Vec = vec![]; + assert_eq!(permute(&empty), vec![vec![]]); + assert_eq!(permute_unique(&empty), vec![vec![]]); + } + + #[test] + fn test_3_times_the_same_value() { + let original = vec![1, 1, 1]; + let res = permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[quickcheck] + fn test_some_elements(NotTooBigVec { inner: original }: NotTooBigVec) { + let permutations = permute(&original); + assert_permutations(&original, &permutations) + } + + #[test] + fn test_unique_values() { + let original = vec![1, 1, 2, 2]; + let unique_permutations = permute_unique(&original); + let every_permutation = permute(&original); + for unique_permutation in &unique_permutations { + assert!(every_permutation.contains(unique_permutation)); + } + assert_eq!( + unique_permutations.len(), + every_permutation.iter().collect::>().len() + ) + } +} diff --git a/src/general/permutations/steinhaus_johnson_trotter.rs b/src/general/permutations/steinhaus_johnson_trotter.rs new file mode 100644 index 00000000000..4784a927ccc --- /dev/null +++ b/src/general/permutations/steinhaus_johnson_trotter.rs @@ -0,0 +1,61 @@ +/// +pub fn steinhaus_johnson_trotter_permute(array: &[T]) -> Vec> { + let len = array.len(); + let mut array = array.to_owned(); + let mut inversion_vector = vec![0; len]; + let mut i = 1; + let mut res = Vec::with_capacity((1..=len).product()); + res.push(array.clone()); + while i < len { + if inversion_vector[i] < i { + if i % 2 == 0 { + array.swap(0, i); + } else { + array.swap(inversion_vector[i], i); + } + res.push(array.to_vec()); + inversion_vector[i] += 1; + i = 1; + } else { + inversion_vector[i] = 0; + i += 1; + } + } + res +} + +#[cfg(test)] +mod tests { + use quickcheck_macros::quickcheck; + + use 
crate::general::permutations::steinhaus_johnson_trotter::steinhaus_johnson_trotter_permute; + use crate::general::permutations::tests::{ + assert_permutations, assert_valid_permutation, NotTooBigVec, + }; + + #[test] + fn test_3_different_values() { + let original = vec![1, 2, 3]; + let res = steinhaus_johnson_trotter_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[test] + fn test_3_times_the_same_value() { + let original = vec![1, 1, 1]; + let res = steinhaus_johnson_trotter_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[quickcheck] + fn test_some_elements(NotTooBigVec { inner: original }: NotTooBigVec) { + let permutations = steinhaus_johnson_trotter_permute(&original); + assert_permutations(&original, &permutations) + } +} diff --git a/src/general/two_sum.rs b/src/general/two_sum.rs new file mode 100644 index 00000000000..3c82ab8f5a8 --- /dev/null +++ b/src/general/two_sum.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; + +/// Given an array of integers nums and an integer target, +/// return indices of the two numbers such that they add up to target. +/// +/// # Parameters +/// +/// - `nums`: A list of numbers to check. +/// - `target`: The target sum. +/// +/// # Returns +/// +/// If the target sum is found in the array, the indices of the augend and +/// addend are returned as a tuple. +/// +/// If the target sum cannot be found in the array, `None` is returned. +/// +pub fn two_sum(nums: Vec, target: i32) -> Option<(usize, usize)> { + // This HashMap is used to look up a corresponding index in the `nums` list. + // Given that we know where we are at in the array, we can look up our + // complementary value using this table and only go through the list once. + // + // We populate this table with distances from the target. 
As we go through + // the list, a distance is computed like so: + // + // `target - current_value` + // + // This distance also tells us about the complementary value we're looking + // for in the array. If we don't find that value, we insert `current_value` + // into the table for future look-ups. As we iterate through the list, + // the number we just inserted might be the perfect distance for another + // number, and we've found a match! + // + let mut distance_table: HashMap = HashMap::new(); + + for (i, current_value) in nums.iter().enumerate() { + match distance_table.get(&(target - current_value)) { + Some(j) => return Some((i, *j)), + _ => distance_table.insert(*current_value, i), + }; + } + + // No match was found! + None +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test() { + let nums = vec![2, 7, 11, 15]; + assert_eq!(two_sum(nums, 9), Some((1, 0))); + + let nums = vec![3, 2, 4]; + assert_eq!(two_sum(nums, 6), Some((2, 1))); + + let nums = vec![3, 3]; + assert_eq!(two_sum(nums, 6), Some((1, 0))); + + let nums = vec![2, 7, 11, 15]; + assert_eq!(two_sum(nums, 16), None); + } +} diff --git a/src/geometry/closest_points.rs b/src/geometry/closest_points.rs new file mode 100644 index 00000000000..e92dc562501 --- /dev/null +++ b/src/geometry/closest_points.rs @@ -0,0 +1,244 @@ +use crate::geometry::Point; +use std::cmp::Ordering; + +fn cmp_x(p1: &Point, p2: &Point) -> Ordering { + let acmp = f64_cmp(&p1.x, &p2.x); + match acmp { + Ordering::Equal => f64_cmp(&p1.y, &p2.y), + _ => acmp, + } +} + +fn f64_cmp(a: &f64, b: &f64) -> Ordering { + a.partial_cmp(b).unwrap() +} + +/// returns the two closest points +/// or None if there are zero or one point +pub fn closest_points(points: &[Point]) -> Option<(Point, Point)> { + let mut points_x: Vec = points.to_vec(); + points_x.sort_by(cmp_x); + let mut points_y = points_x.clone(); + points_y.sort_by(|p1: &Point, p2: &Point| -> Ordering { p1.y.partial_cmp(&p2.y).unwrap() }); + + 
closest_points_aux(&points_x, points_y, 0, points_x.len()) +} + +// We maintain two vectors with the same points, one sort by x coordinates and one sorted by y +// coordinates. +fn closest_points_aux( + points_x: &[Point], + points_y: Vec, + mut start: usize, + mut end: usize, +) -> Option<(Point, Point)> { + let n = end - start; + + if n <= 1 { + return None; + } + + if n <= 3 { + // bruteforce + let mut min = points_x[0].euclidean_distance(&points_x[1]); + let mut pair = (points_x[0].clone(), points_x[1].clone()); + + for i in 1..n { + for j in (i + 1)..n { + let new = points_x[i].euclidean_distance(&points_x[j]); + if new < min { + min = new; + pair = (points_x[i].clone(), points_x[j].clone()); + } + } + } + return Some(pair); + } + + let mid = start + (end - start) / 2; + let mid_x = points_x[mid].x; + + // Separate points into y_left and y_right vectors based on their x-coordinate. Since y is + // already sorted by y_axis, y_left and y_right will also be sorted. + let mut y_left = vec![]; + let mut y_right = vec![]; + for point in &points_y { + if point.x < mid_x { + y_left.push(point.clone()); + } else { + y_right.push(point.clone()); + } + } + + let left = closest_points_aux(points_x, y_left, start, mid); + let right = closest_points_aux(points_x, y_right, mid, end); + + let (mut min_sqr_dist, mut pair) = match (left, right) { + (Some((l1, l2)), Some((r1, r2))) => { + let dl = l1.euclidean_distance(&l2); + let dr = r1.euclidean_distance(&r2); + if dl < dr { + (dl, (l1, l2)) + } else { + (dr, (r1, r2)) + } + } + (Some((a, b)), None) | (None, Some((a, b))) => (a.euclidean_distance(&b), (a, b)), + (None, None) => unreachable!(), + }; + + let dist = min_sqr_dist; + while points_x[start].x < mid_x - dist { + start += 1; + } + while points_x[end - 1].x > mid_x + dist { + end -= 1; + } + + for (i, e) in points_y.iter().enumerate() { + for k in 1..8 { + if i + k >= points_y.len() { + break; + } + + let new = e.euclidean_distance(&points_y[i + k]); + if new < 
min_sqr_dist { + min_sqr_dist = new; + pair = ((*e).clone(), points_y[i + k].clone()); + } + } + } + + Some(pair) +} + +#[cfg(test)] +mod tests { + use super::closest_points; + use super::Point; + + fn eq(p1: Option<(Point, Point)>, p2: Option<(Point, Point)>) -> bool { + match (p1, p2) { + (None, None) => true, + (Some((p1, p2)), Some((p3, p4))) => (p1 == p3 && p2 == p4) || (p1 == p4 && p2 == p3), + _ => false, + } + } + + macro_rules! assert_display { + ($left: expr, $right: expr) => { + assert!( + eq($left, $right), + "assertion failed: `(left == right)`\nleft: `{:?}`,\nright: `{:?}`", + $left, + $right + ) + }; + } + + #[test] + fn zero_points() { + let vals: [Point; 0] = []; + assert_display!(closest_points(&vals), None::<(Point, Point)>); + } + + #[test] + fn one_points() { + let vals = [Point::new(0., 0.)]; + assert_display!(closest_points(&vals), None::<(Point, Point)>); + } + + #[test] + fn two_points() { + let vals = [Point::new(0., 0.), Point::new(1., 1.)]; + assert_display!( + closest_points(&vals), + Some((vals[0].clone(), vals[1].clone())) + ); + } + + #[test] + fn three_points() { + let vals = [Point::new(0., 0.), Point::new(1., 1.), Point::new(3., 3.)]; + assert_display!( + closest_points(&vals), + Some((vals[0].clone(), vals[1].clone())) + ); + } + + #[test] + fn list_1() { + let vals = [ + Point::new(0., 0.), + Point::new(2., 1.), + Point::new(5., 2.), + Point::new(2., 3.), + Point::new(4., 0.), + Point::new(0., 4.), + Point::new(5., 6.), + Point::new(4., 4.), + Point::new(7., 3.), + Point::new(-1., 2.), + Point::new(2., 6.), + ]; + assert_display!( + closest_points(&vals), + Some((Point::new(2., 1.), Point::new(2., 3.))) + ); + } + + #[test] + fn list_2() { + let vals = [ + Point::new(1., 3.), + Point::new(4., 6.), + Point::new(8., 8.), + Point::new(7., 5.), + Point::new(5., 3.), + Point::new(10., 3.), + Point::new(7., 1.), + Point::new(8., 3.), + Point::new(4., 9.), + Point::new(4., 12.), + Point::new(4., 15.), + Point::new(7., 14.), + 
Point::new(8., 12.), + Point::new(6., 10.), + Point::new(4., 14.), + Point::new(2., 7.), + Point::new(3., 8.), + Point::new(5., 8.), + Point::new(6., 7.), + Point::new(8., 10.), + Point::new(6., 12.), + ]; + assert_display!( + closest_points(&vals), + Some((Point::new(4., 14.), Point::new(4., 15.))) + ); + } + + #[test] + fn vertical_points() { + let vals = [ + Point::new(0., 0.), + Point::new(0., 50.), + Point::new(0., -25.), + Point::new(0., 40.), + Point::new(0., 42.), + Point::new(0., 100.), + Point::new(0., 17.), + Point::new(0., 29.), + Point::new(0., -50.), + Point::new(0., 37.), + Point::new(0., 34.), + Point::new(0., 8.), + Point::new(0., 3.), + Point::new(0., 46.), + ]; + assert_display!( + closest_points(&vals), + Some((Point::new(0., 40.), Point::new(0., 42.))) + ); + } +} diff --git a/src/geometry/graham_scan.rs b/src/geometry/graham_scan.rs new file mode 100644 index 00000000000..cdd512e7510 --- /dev/null +++ b/src/geometry/graham_scan.rs @@ -0,0 +1,206 @@ +use crate::geometry::Point; +use std::cmp::Ordering; + +fn point_min(a: &&Point, b: &&Point) -> Ordering { + // Find the bottom-most point. In the case of a tie, find the left-most. + if a.y == b.y { + a.x.partial_cmp(&b.x).unwrap() + } else { + a.y.partial_cmp(&b.y).unwrap() + } +} + +// Returns a Vec of Points that make up the convex hull of `points`. Returns an empty Vec if there +// is no convex hull. +pub fn graham_scan(mut points: Vec) -> Vec { + if points.len() <= 2 { + return vec![]; + } + + let min_point = points.iter().min_by(point_min).unwrap().clone(); + points.retain(|p| p != &min_point); + if points.is_empty() { + // edge case where all the points are the same + return vec![]; + } + + let point_cmp = |a: &Point, b: &Point| -> Ordering { + // Sort points in counter-clockwise direction relative to the min point. We can this by + // checking the orientation of consecutive vectors (min_point, a) and (a, b). 
+ let orientation = min_point.consecutive_orientation(a, b); + if orientation < 0.0 { + Ordering::Greater + } else if orientation > 0.0 { + Ordering::Less + } else { + let a_dist = min_point.euclidean_distance(a); + let b_dist = min_point.euclidean_distance(b); + // When two points have the same relative angle to the min point, we should only + // include the further point in the convex hull. We sort further points into a lower + // index, and in the algorithm, remove all consecutive points with the same relative + // angle. + b_dist.partial_cmp(&a_dist).unwrap() + } + }; + points.sort_by(point_cmp); + let mut convex_hull: Vec = vec![]; + + // We always add the min_point, and the first two points in the sorted vec. + convex_hull.push(min_point.clone()); + convex_hull.push(points[0].clone()); + let mut top = 1; + for point in points.iter().skip(1) { + if min_point.consecutive_orientation(point, &convex_hull[top]) == 0.0 { + // Remove consecutive points with the same angle. We make sure include the furthest + // point in the convex hull in the sort comparator. + continue; + } + loop { + // In this loop, we remove points that we determine are no longer part of the convex + // hull. + if top <= 1 { + break; + } + // If there is a segment(i+1, i+2) turns right relative to segment(i, i+1), point(i+1) + // is not part of the convex hull. 
+ let orientation = + convex_hull[top - 1].consecutive_orientation(&convex_hull[top], point); + if orientation <= 0.0 { + top -= 1; + convex_hull.pop(); + } else { + break; + } + } + convex_hull.push(point.clone()); + top += 1; + } + if convex_hull.len() <= 2 { + return vec![]; + } + convex_hull +} + +#[cfg(test)] +mod tests { + use super::graham_scan; + use super::Point; + + fn test_graham(convex_hull: Vec, others: Vec) { + let mut points = convex_hull.clone(); + points.append(&mut others.clone()); + let graham = graham_scan(points); + for point in convex_hull { + assert!(graham.contains(&point)); + } + for point in others { + assert!(!graham.contains(&point)); + } + } + + #[test] + fn too_few_points() { + test_graham(vec![], vec![]); + test_graham(vec![], vec![Point::new(0.0, 0.0)]); + } + + #[test] + fn duplicate_point() { + let p = Point::new(0.0, 0.0); + test_graham(vec![], vec![p.clone(), p.clone(), p.clone(), p.clone(), p]); + } + + #[test] + fn points_same_line() { + let p1 = Point::new(1.0, 0.0); + let p2 = Point::new(2.0, 0.0); + let p3 = Point::new(3.0, 0.0); + let p4 = Point::new(4.0, 0.0); + let p5 = Point::new(5.0, 0.0); + // let p6 = Point::new(1.0, 1.0); + test_graham(vec![], vec![p1, p2, p3, p4, p5]); + } + + #[test] + fn triangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let points = vec![p1, p2, p3]; + test_graham(points, vec![]); + } + + #[test] + fn rectangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let points = vec![p1, p2, p3, p4]; + test_graham(points, vec![]); + } + + #[test] + fn triangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let p4 = Point::new(1.5, 1.5); + let p5 = Point::new(1.2, 1.3); + let p6 = Point::new(1.8, 1.2); + let p7 = Point::new(1.5, 1.9); + let hull = vec![p1, p2, p3]; + let 
others = vec![p4, p5, p6, p7]; + test_graham(hull, others); + } + + #[test] + fn rectangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.5); + let p6 = Point::new(1.2, 1.3); + let p7 = Point::new(1.8, 1.2); + let p8 = Point::new(1.9, 1.7); + let p9 = Point::new(1.4, 1.9); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8, p9]; + test_graham(hull, others); + } + + #[test] + fn star() { + // A single stroke star shape (kind of). Only the tips(p1-5) are part of the convex hull. The + // other points would create angles >180 degrees if they were part of the polygon. + let p1 = Point::new(-5.0, 6.0); + let p2 = Point::new(-11.0, 0.0); + let p3 = Point::new(-9.0, -8.0); + let p4 = Point::new(4.0, 4.0); + let p5 = Point::new(6.0, -7.0); + let p6 = Point::new(-7.0, -2.0); + let p7 = Point::new(-2.0, -4.0); + let p8 = Point::new(0.0, 1.0); + let p9 = Point::new(1.0, 0.0); + let p10 = Point::new(-6.0, 1.0); + let hull = vec![p1, p2, p3, p4, p5]; + let others = vec![p6, p7, p8, p9, p10]; + test_graham(hull, others); + } + + #[test] + fn rectangle_with_points_on_same_line() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.0); + let p6 = Point::new(1.0, 1.5); + let p7 = Point::new(2.0, 1.5); + let p8 = Point::new(1.5, 2.0); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8]; + test_graham(hull, others); + } +} diff --git a/src/geometry/jarvis_scan.rs b/src/geometry/jarvis_scan.rs new file mode 100644 index 00000000000..c8999c910ae --- /dev/null +++ b/src/geometry/jarvis_scan.rs @@ -0,0 +1,193 @@ +use crate::geometry::Point; +use crate::geometry::Segment; + +// Returns a Vec of Points that make up the convex hull of `points`. Returns an empty Vec if there +// is no convex hull. 
+pub fn jarvis_march(points: Vec) -> Vec { + if points.len() <= 2 { + return vec![]; + } + + let mut convex_hull = vec![]; + let mut left_point = 0; + for i in 1..points.len() { + // Find the initial point, which is the leftmost point. In the case of a tie, we take the + // bottom-most point. This helps prevent adding colinear points on the last segment to the hull. + if points[i].x < points[left_point].x + || (points[i].x == points[left_point].x && points[i].y < points[left_point].y) + { + left_point = i; + } + } + convex_hull.push(points[left_point].clone()); + + let mut p = left_point; + loop { + // Find the next counter-clockwise point. + let mut next_p = (p + 1) % points.len(); + for i in 0..points.len() { + let orientation = points[p].consecutive_orientation(&points[i], &points[next_p]); + if orientation > 0.0 { + next_p = i; + } + } + + if next_p == left_point { + // Completed constructing the hull. Exit the loop. + break; + } + p = next_p; + + let last = convex_hull.len() - 1; + if convex_hull.len() > 1 + && Segment::from_points(points[p].clone(), convex_hull[last - 1].clone()) + .on_segment(&convex_hull[last]) + { + // If the last point lies on the segment with the new point and the second to last + // point, we can remove the last point from the hull. + convex_hull[last] = points[p].clone(); + } else { + convex_hull.push(points[p].clone()); + } + } + + if convex_hull.len() <= 2 { + return vec![]; + } + let last = convex_hull.len() - 1; + if Segment::from_points(convex_hull[0].clone(), convex_hull[last - 1].clone()) + .on_segment(&convex_hull[last]) + { + // Check for the edge case where the last point lies on the segment with the zero'th and + // second the last point. In this case, we remove the last point from the hull. 
+ convex_hull.pop(); + if convex_hull.len() == 2 { + return vec![]; + } + } + convex_hull +} + +#[cfg(test)] +mod tests { + use super::jarvis_march; + use super::Point; + + fn test_jarvis(convex_hull: Vec, others: Vec) { + let mut points = others.clone(); + points.append(&mut convex_hull.clone()); + let jarvis = jarvis_march(points); + for point in convex_hull { + assert!(jarvis.contains(&point)); + } + for point in others { + assert!(!jarvis.contains(&point)); + } + } + + #[test] + fn too_few_points() { + test_jarvis(vec![], vec![]); + test_jarvis(vec![], vec![Point::new(0.0, 0.0)]); + } + + #[test] + fn duplicate_point() { + let p = Point::new(0.0, 0.0); + test_jarvis(vec![], vec![p.clone(), p.clone(), p.clone(), p.clone(), p]); + } + + #[test] + fn points_same_line() { + let p1 = Point::new(1.0, 0.0); + let p2 = Point::new(2.0, 0.0); + let p3 = Point::new(3.0, 0.0); + let p4 = Point::new(4.0, 0.0); + let p5 = Point::new(5.0, 0.0); + // let p6 = Point::new(1.0, 1.0); + test_jarvis(vec![], vec![p1, p2, p3, p4, p5]); + } + + #[test] + fn triangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let points = vec![p1, p2, p3]; + test_jarvis(points, vec![]); + } + + #[test] + fn rectangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let points = vec![p1, p2, p3, p4]; + test_jarvis(points, vec![]); + } + + #[test] + fn triangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let p4 = Point::new(1.5, 1.5); + let p5 = Point::new(1.2, 1.3); + let p6 = Point::new(1.8, 1.2); + let p7 = Point::new(1.5, 1.9); + let hull = vec![p1, p2, p3]; + let others = vec![p4, p5, p6, p7]; + test_jarvis(hull, others); + } + + #[test] + fn rectangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 
2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.5); + let p6 = Point::new(1.2, 1.3); + let p7 = Point::new(1.8, 1.2); + let p8 = Point::new(1.9, 1.7); + let p9 = Point::new(1.4, 1.9); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8, p9]; + test_jarvis(hull, others); + } + + #[test] + fn star() { + // A single stroke star shape (kind of). Only the tips(p1-5) are part of the convex hull. The + // other points would create angles >180 degrees if they were part of the polygon. + let p1 = Point::new(-5.0, 6.0); + let p2 = Point::new(-11.0, 0.0); + let p3 = Point::new(-9.0, -8.0); + let p4 = Point::new(4.0, 4.0); + let p5 = Point::new(6.0, -7.0); + let p6 = Point::new(-7.0, -2.0); + let p7 = Point::new(-2.0, -4.0); + let p8 = Point::new(0.0, 1.0); + let p9 = Point::new(1.0, 0.0); + let p10 = Point::new(-6.0, 1.0); + let hull = vec![p1, p2, p3, p4, p5]; + let others = vec![p6, p7, p8, p9, p10]; + test_jarvis(hull, others); + } + + #[test] + fn rectangle_with_points_on_same_line() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.0); + let p6 = Point::new(1.0, 1.5); + let p7 = Point::new(2.0, 1.5); + let p8 = Point::new(1.5, 2.0); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8]; + test_jarvis(hull, others); + } +} diff --git a/src/geometry/mod.rs b/src/geometry/mod.rs new file mode 100644 index 00000000000..e883cc004bc --- /dev/null +++ b/src/geometry/mod.rs @@ -0,0 +1,15 @@ +mod closest_points; +mod graham_scan; +mod jarvis_scan; +mod point; +mod polygon_points; +mod ramer_douglas_peucker; +mod segment; + +pub use self::closest_points::closest_points; +pub use self::graham_scan::graham_scan; +pub use self::jarvis_scan::jarvis_march; +pub use self::point::Point; +pub use self::polygon_points::lattice_points; +pub use self::ramer_douglas_peucker::ramer_douglas_peucker; +pub use 
self::segment::Segment; diff --git a/src/geometry/point.rs b/src/geometry/point.rs new file mode 100644 index 00000000000..4b4bd8db501 --- /dev/null +++ b/src/geometry/point.rs @@ -0,0 +1,38 @@ +use std::ops::Sub; + +#[derive(Clone, Debug, PartialEq)] +pub struct Point { + pub x: f64, + pub y: f64, +} + +impl Point { + pub fn new(x: f64, y: f64) -> Point { + Point { x, y } + } + + // Returns the orientation of consecutive segments ab and bc. + pub fn consecutive_orientation(&self, b: &Point, c: &Point) -> f64 { + let p1 = b - self; + let p2 = c - self; + p1.cross_prod(&p2) + } + + pub fn cross_prod(&self, other: &Point) -> f64 { + self.x * other.y - self.y * other.x + } + + pub fn euclidean_distance(&self, other: &Point) -> f64 { + ((self.x - other.x).powi(2) + (self.y - other.y).powi(2)).sqrt() + } +} + +impl Sub for &Point { + type Output = Point; + + fn sub(self, other: Self) -> Point { + let x = self.x - other.x; + let y = self.y - other.y; + Point::new(x, y) + } +} diff --git a/src/geometry/polygon_points.rs b/src/geometry/polygon_points.rs new file mode 100644 index 00000000000..7d492681cb7 --- /dev/null +++ b/src/geometry/polygon_points.rs @@ -0,0 +1,84 @@ +type Ll = i64; +type Pll = (Ll, Ll); + +fn cross(x1: Ll, y1: Ll, x2: Ll, y2: Ll) -> Ll { + x1 * y2 - x2 * y1 +} + +pub fn polygon_area(pts: &[Pll]) -> Ll { + let mut ats = 0; + for i in 2..pts.len() { + ats += cross( + pts[i].0 - pts[0].0, + pts[i].1 - pts[0].1, + pts[i - 1].0 - pts[0].0, + pts[i - 1].1 - pts[0].1, + ); + } + Ll::abs(ats / 2) +} + +fn gcd(mut a: Ll, mut b: Ll) -> Ll { + while b != 0 { + let temp = b; + b = a % b; + a = temp; + } + a +} + +fn boundary(pts: &[Pll]) -> Ll { + let mut ats = pts.len() as Ll; + for i in 0..pts.len() { + let deltax = pts[i].0 - pts[(i + 1) % pts.len()].0; + let deltay = pts[i].1 - pts[(i + 1) % pts.len()].1; + ats += Ll::abs(gcd(deltax, deltay)) - 1; + } + ats +} + +pub fn lattice_points(pts: &[Pll]) -> Ll { + let bounds = boundary(pts); + let area = 
polygon_area(pts); + area + 1 - bounds / 2 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_calculate_cross() { + assert_eq!(cross(1, 2, 3, 4), 4 - 3 * 2); + } + + #[test] + fn test_polygon_3_coordinates() { + let pts = vec![(0, 0), (0, 3), (4, 0)]; + assert_eq!(polygon_area(&pts), 6); + } + + #[test] + fn test_polygon_4_coordinates() { + let pts = vec![(0, 0), (0, 2), (2, 2), (2, 0)]; + assert_eq!(polygon_area(&pts), 4); + } + + #[test] + fn test_gcd_multiple_of_common_factor() { + assert_eq!(gcd(14, 28), 14); + } + + #[test] + fn test_boundary() { + let pts = vec![(0, 0), (0, 3), (0, 4), (2, 2)]; + assert_eq!(boundary(&pts), 8); + } + + #[test] + fn test_lattice_points() { + let pts = vec![(1, 1), (5, 1), (5, 4)]; + let result = lattice_points(&pts); + assert_eq!(result, 3); + } +} diff --git a/src/geometry/ramer_douglas_peucker.rs b/src/geometry/ramer_douglas_peucker.rs new file mode 100644 index 00000000000..ca9d53084b7 --- /dev/null +++ b/src/geometry/ramer_douglas_peucker.rs @@ -0,0 +1,115 @@ +use crate::geometry::Point; + +pub fn ramer_douglas_peucker(points: &[Point], epsilon: f64) -> Vec { + if points.len() < 3 { + return points.to_vec(); + } + let mut dmax = 0.0; + let mut index = 0; + let end = points.len() - 1; + + for i in 1..end { + let d = perpendicular_distance(&points[i], &points[0], &points[end]); + if d > dmax { + index = i; + dmax = d; + } + } + + if dmax > epsilon { + let mut results = ramer_douglas_peucker(&points[..=index], epsilon); + results.pop(); + results.extend(ramer_douglas_peucker(&points[index..], epsilon)); + results + } else { + vec![points[0].clone(), points[end].clone()] + } +} + +fn perpendicular_distance(p: &Point, a: &Point, b: &Point) -> f64 { + let num = (b.y - a.y) * p.x - (b.x - a.x) * p.y + b.x * a.y - b.y * a.x; + let den = a.euclidean_distance(b); + num.abs() / den +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! 
test_perpendicular_distance { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (p, a, b, expected) = $test_case; + assert_eq!(perpendicular_distance(&p, &a, &b), expected); + assert_eq!(perpendicular_distance(&p, &b, &a), expected); + } + )* + }; + } + + test_perpendicular_distance! { + basic: (Point::new(4.0, 0.0), Point::new(0.0, 0.0), Point::new(0.0, 3.0), 4.0), + basic_shifted_1: (Point::new(4.0, 1.0), Point::new(0.0, 1.0), Point::new(0.0, 4.0), 4.0), + basic_shifted_2: (Point::new(2.0, 1.0), Point::new(-2.0, 1.0), Point::new(-2.0, 4.0), 4.0), + } + + #[test] + fn test_ramer_douglas_peucker_polygon() { + let a = Point::new(0.0, 0.0); + let b = Point::new(1.0, 0.0); + let c = Point::new(2.0, 0.0); + let d = Point::new(2.0, 1.0); + let e = Point::new(2.0, 2.0); + let f = Point::new(1.0, 2.0); + let g = Point::new(0.0, 2.0); + let h = Point::new(0.0, 1.0); + let polygon = vec![ + a.clone(), + b, + c.clone(), + d, + e.clone(), + f, + g.clone(), + h.clone(), + ]; + let epsilon = 0.7; + let result = ramer_douglas_peucker(&polygon, epsilon); + assert_eq!(result, vec![a, c, e, g, h]); + } + + #[test] + fn test_ramer_douglas_peucker_polygonal_chain() { + let a = Point::new(0., 0.); + let b = Point::new(2., 0.5); + let c = Point::new(3., 3.); + let d = Point::new(6., 3.); + let e = Point::new(8., 4.); + + let points = vec![a.clone(), b, c, d, e.clone()]; + + let epsilon = 3.; // The epsilon is quite large, so the result will be a single line + let result = ramer_douglas_peucker(&points, epsilon); + assert_eq!(result, vec![a, e]); + } + + #[test] + fn test_less_than_three_points() { + let a = Point::new(0., 0.); + let b = Point::new(1., 1.); + + let epsilon = 0.1; + + assert_eq!(ramer_douglas_peucker(&[], epsilon), vec![]); + assert_eq!( + ramer_douglas_peucker(&[a.clone()], epsilon), + vec![a.clone()] + ); + assert_eq!( + ramer_douglas_peucker(&[a.clone(), b.clone()], epsilon), + vec![a, b] + ); + } +} diff --git a/src/geometry/segment.rs 
b/src/geometry/segment.rs new file mode 100644 index 00000000000..e43162e002f --- /dev/null +++ b/src/geometry/segment.rs @@ -0,0 +1,193 @@ +use super::Point; + +const TOLERANCE: f64 = 0.0001; + +pub struct Segment { + pub a: Point, + pub b: Point, +} + +impl Segment { + pub fn new(x1: f64, y1: f64, x2: f64, y2: f64) -> Segment { + Segment { + a: Point::new(x1, y1), + b: Point::new(x2, y2), + } + } + + pub fn from_points(a: Point, b: Point) -> Segment { + Segment { a, b } + } + + pub fn direction(&self, p: &Point) -> f64 { + let a = Point::new(p.x - self.a.x, p.y - self.a.y); + let b = Point::new(self.b.x - self.a.x, self.b.y - self.a.y); + a.cross_prod(&b) + } + + pub fn is_vertical(&self) -> bool { + self.a.x == self.b.x + } + + // returns (slope, y-intercept) + pub fn get_line_equation(&self) -> (f64, f64) { + let slope = (self.a.y - self.b.y) / (self.a.x - self.b.x); + let y_intercept = self.a.y - slope * self.a.x; + (slope, y_intercept) + } + + // Compute the value of y at x. Uses the line equation, and assumes the segment + // has infinite length. 
+ pub fn compute_y_at_x(&self, x: f64) -> f64 { + let (slope, y_intercept) = self.get_line_equation(); + slope * x + y_intercept + } + + pub fn is_colinear(&self, p: &Point) -> bool { + if self.is_vertical() { + p.x == self.a.x + } else { + (self.compute_y_at_x(p.x) - p.y).abs() < TOLERANCE + } + } + + // p must be colinear with the segment + pub fn colinear_point_on_segment(&self, p: &Point) -> bool { + assert!(self.is_colinear(p), "p must be colinear!"); + let (low_x, high_x) = if self.a.x < self.b.x { + (self.a.x, self.b.x) + } else { + (self.b.x, self.a.x) + }; + let (low_y, high_y) = if self.a.y < self.b.y { + (self.a.y, self.b.y) + } else { + (self.b.y, self.a.y) + }; + + p.x >= low_x && p.x <= high_x && p.y >= low_y && p.y <= high_y + } + + pub fn on_segment(&self, p: &Point) -> bool { + if !self.is_colinear(p) { + return false; + } + self.colinear_point_on_segment(p) + } + + pub fn intersects(&self, other: &Segment) -> bool { + let direction1 = self.direction(&other.a); + let direction2 = self.direction(&other.b); + let direction3 = other.direction(&self.a); + let direction4 = other.direction(&self.b); + + // If the segments saddle each others' endpoints, they intersect + if ((direction1 > 0.0 && direction2 < 0.0) || (direction1 < 0.0 && direction2 > 0.0)) + && ((direction3 > 0.0 && direction4 < 0.0) || (direction3 < 0.0 && direction4 > 0.0)) + { + return true; + } + + // Edge cases where an endpoint lies on a segment + (direction1 == 0.0 && self.colinear_point_on_segment(&other.a)) + || (direction2 == 0.0 && self.colinear_point_on_segment(&other.b)) + || (direction3 == 0.0 && other.colinear_point_on_segment(&self.a)) + || (direction4 == 0.0 && other.colinear_point_on_segment(&self.b)) + } +} + +#[cfg(test)] +mod tests { + use super::Point; + use super::Segment; + + #[test] + fn colinear() { + let segment = Segment::new(2.0, 3.0, 6.0, 5.0); + assert_eq!((0.5, 2.0), segment.get_line_equation()); + + assert!(segment.is_colinear(&Point::new(2.0, 3.0))); + 
assert!(segment.is_colinear(&Point::new(6.0, 5.0))); + assert!(segment.is_colinear(&Point::new(0.0, 2.0))); + assert!(segment.is_colinear(&Point::new(-5.0, -0.5))); + assert!(segment.is_colinear(&Point::new(10.0, 7.0))); + + assert!(!segment.is_colinear(&Point::new(0.0, 0.0))); + assert!(!segment.is_colinear(&Point::new(1.9, 3.0))); + assert!(!segment.is_colinear(&Point::new(2.1, 3.0))); + assert!(!segment.is_colinear(&Point::new(2.0, 2.9))); + assert!(!segment.is_colinear(&Point::new(2.0, 3.1))); + assert!(!segment.is_colinear(&Point::new(5.9, 5.0))); + assert!(!segment.is_colinear(&Point::new(6.1, 5.0))); + assert!(!segment.is_colinear(&Point::new(6.0, 4.9))); + assert!(!segment.is_colinear(&Point::new(6.0, 5.1))); + } + + #[test] + fn colinear_vertical() { + let segment = Segment::new(2.0, 3.0, 2.0, 5.0); + assert!(segment.is_colinear(&Point::new(2.0, 1.0))); + assert!(segment.is_colinear(&Point::new(2.0, 3.0))); + assert!(segment.is_colinear(&Point::new(2.0, 4.0))); + assert!(segment.is_colinear(&Point::new(2.0, 5.0))); + assert!(segment.is_colinear(&Point::new(2.0, 6.0))); + + assert!(!segment.is_colinear(&Point::new(1.0, 3.0))); + assert!(!segment.is_colinear(&Point::new(3.0, 3.0))); + } + + fn test_intersect(s1: &Segment, s2: &Segment, result: bool) { + assert_eq!(s1.intersects(s2), result); + assert_eq!(s2.intersects(s1), result); + } + + #[test] + fn intersects() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(-1.0, 9.0, 10.0, -3.0); + let s3 = Segment::new(-0.0, 10.0, 11.0, -2.0); + let s4 = Segment::new(100.0, 200.0, 40.0, 50.0); + test_intersect(&s1, &s2, true); + test_intersect(&s1, &s3, true); + test_intersect(&s2, &s3, false); + test_intersect(&s1, &s4, false); + test_intersect(&s2, &s4, false); + test_intersect(&s3, &s4, false); + } + + #[test] + fn intersects_endpoint_on_segment() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(4.0, 4.0, -11.0, 20.0); + let s3 = Segment::new(4.0, 4.0, 14.0, -19.0); + 
test_intersect(&s1, &s2, true); + test_intersect(&s1, &s3, true); + } + + #[test] + fn intersects_self() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(2.0, 3.0, 6.0, 5.0); + test_intersect(&s1, &s2, true); + } + + #[test] + fn too_short_to_intersect() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(-1.0, 10.0, 3.0, 5.0); + let s3 = Segment::new(5.0, 3.0, 10.0, -11.0); + test_intersect(&s1, &s2, false); + test_intersect(&s1, &s3, false); + test_intersect(&s2, &s3, false); + } + + #[test] + fn parallel_segments() { + let s1 = Segment::new(-5.0, 0.0, 5.0, 0.0); + let s2 = Segment::new(-5.0, 1.0, 5.0, 1.0); + let s3 = Segment::new(-5.0, -1.0, 5.0, -1.0); + test_intersect(&s1, &s2, false); + test_intersect(&s1, &s3, false); + test_intersect(&s2, &s3, false); + } +} diff --git a/src/graph/astar.rs b/src/graph/astar.rs new file mode 100644 index 00000000000..a4244c87b8b --- /dev/null +++ b/src/graph/astar.rs @@ -0,0 +1,261 @@ +use std::{ + collections::{BTreeMap, BinaryHeap}, + ops::Add, +}; + +use num_traits::Zero; + +type Graph = BTreeMap>; + +#[derive(Clone, Debug, Eq, PartialEq)] +struct Candidate { + estimated_weight: E, + real_weight: E, + state: V, +} + +impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + // Note the inverted order; we want nodes with lesser weight to have + // higher priority + Some(self.cmp(other)) + } +} + +impl Ord for Candidate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Note the inverted order; we want nodes with lesser weight to have + // higher priority + other.estimated_weight.cmp(&self.estimated_weight) + } +} + +pub fn astar + Zero>( + graph: &Graph, + start: V, + target: V, + heuristic: impl Fn(V) -> E, +) -> Option<(E, Vec)> { + // traversal front + let mut queue = BinaryHeap::new(); + // maps each node to its predecessor in the final path + let mut previous = BTreeMap::new(); + // weights[v] is the accumulated weight from start to v + let 
mut weights = BTreeMap::new(); + // initialize traversal + weights.insert(start, E::zero()); + queue.push(Candidate { + estimated_weight: heuristic(start), + real_weight: E::zero(), + state: start, + }); + while let Some(Candidate { + real_weight, + state: current, + .. + }) = queue.pop() + { + if current == target { + break; + } + for (&next, &weight) in &graph[¤t] { + let real_weight = real_weight + weight; + if weights + .get(&next) + .is_none_or(|&weight| real_weight < weight) + { + // current allows us to reach next with lower weight (or at all) + // add next to the front + let estimated_weight = real_weight + heuristic(next); + weights.insert(next, real_weight); + queue.push(Candidate { + estimated_weight, + real_weight, + state: next, + }); + previous.insert(next, current); + } + } + } + let weight = if let Some(&weight) = weights.get(&target) { + weight + } else { + // we did not reach target from start + return None; + }; + // build path in reverse + let mut current = target; + let mut path = vec![current]; + while current != start { + let prev = previous + .get(¤t) + .copied() + .expect("We reached the target, but are unable to reconsistute the path"); + current = prev; + path.push(current); + } + path.reverse(); + Some((weight, path)) +} + +#[cfg(test)] +mod tests { + use super::{astar, Graph}; + use num_traits::Zero; + use std::collections::BTreeMap; + + // the null heuristic make A* equivalent to Dijkstra + fn null_heuristic(_v: V) -> E { + E::zero() + } + + fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default(); + } + + #[test] + fn single_vertex() { + let mut graph: Graph = BTreeMap::new(); + graph.insert(0, BTreeMap::new()); + + assert_eq!(astar(&graph, 0, 0, null_heuristic), Some((0, vec![0]))); + assert_eq!(astar(&graph, 0, 1, null_heuristic), None); + } + + #[test] + fn single_edge() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 0, 1, 2); + + 
assert_eq!(astar(&graph, 0, 1, null_heuristic), Some((2, vec![0, 1]))); + assert_eq!(astar(&graph, 1, 0, null_heuristic), None); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'c', 12); + add_edge(&mut graph, 'a', 'd', 60); + add_edge(&mut graph, 'b', 'a', 10); + add_edge(&mut graph, 'c', 'b', 20); + add_edge(&mut graph, 'c', 'd', 32); + add_edge(&mut graph, 'e', 'a', 7); + + // from a + assert_eq!( + astar(&graph, 'a', 'a', null_heuristic), + Some((0, vec!['a'])) + ); + assert_eq!( + astar(&graph, 'a', 'b', null_heuristic), + Some((32, vec!['a', 'c', 'b'])) + ); + assert_eq!( + astar(&graph, 'a', 'c', null_heuristic), + Some((12, vec!['a', 'c'])) + ); + assert_eq!( + astar(&graph, 'a', 'd', null_heuristic), + Some((12 + 32, vec!['a', 'c', 'd'])) + ); + assert_eq!(astar(&graph, 'a', 'e', null_heuristic), None); + + // from b + assert_eq!( + astar(&graph, 'b', 'a', null_heuristic), + Some((10, vec!['b', 'a'])) + ); + assert_eq!( + astar(&graph, 'b', 'b', null_heuristic), + Some((0, vec!['b'])) + ); + assert_eq!( + astar(&graph, 'b', 'c', null_heuristic), + Some((10 + 12, vec!['b', 'a', 'c'])) + ); + assert_eq!( + astar(&graph, 'b', 'd', null_heuristic), + Some((10 + 12 + 32, vec!['b', 'a', 'c', 'd'])) + ); + assert_eq!(astar(&graph, 'b', 'e', null_heuristic), None); + + // from c + assert_eq!( + astar(&graph, 'c', 'a', null_heuristic), + Some((20 + 10, vec!['c', 'b', 'a'])) + ); + assert_eq!( + astar(&graph, 'c', 'b', null_heuristic), + Some((20, vec!['c', 'b'])) + ); + assert_eq!( + astar(&graph, 'c', 'c', null_heuristic), + Some((0, vec!['c'])) + ); + assert_eq!( + astar(&graph, 'c', 'd', null_heuristic), + Some((32, vec!['c', 'd'])) + ); + assert_eq!(astar(&graph, 'c', 'e', null_heuristic), None); + + // from d + assert_eq!(astar(&graph, 'd', 'a', null_heuristic), None); + assert_eq!(astar(&graph, 'd', 'b', null_heuristic), None); + assert_eq!(astar(&graph, 'd', 'c', null_heuristic), None); + assert_eq!( + 
astar(&graph, 'd', 'd', null_heuristic), + Some((0, vec!['d'])) + ); + assert_eq!(astar(&graph, 'd', 'e', null_heuristic), None); + + // from e + assert_eq!( + astar(&graph, 'e', 'a', null_heuristic), + Some((7, vec!['e', 'a'])) + ); + assert_eq!( + astar(&graph, 'e', 'b', null_heuristic), + Some((7 + 12 + 20, vec!['e', 'a', 'c', 'b'])) + ); + assert_eq!( + astar(&graph, 'e', 'c', null_heuristic), + Some((7 + 12, vec!['e', 'a', 'c'])) + ); + assert_eq!( + astar(&graph, 'e', 'd', null_heuristic), + Some((7 + 12 + 32, vec!['e', 'a', 'c', 'd'])) + ); + assert_eq!( + astar(&graph, 'e', 'e', null_heuristic), + Some((0, vec!['e'])) + ); + } + + #[test] + fn test_heuristic() { + // make a grid + let mut graph = BTreeMap::new(); + let rows = 100; + let cols = 100; + for row in 0..rows { + for col in 0..cols { + add_edge(&mut graph, (row, col), (row + 1, col), 1); + add_edge(&mut graph, (row, col), (row, col + 1), 1); + add_edge(&mut graph, (row, col), (row + 1, col + 1), 1); + add_edge(&mut graph, (row + 1, col), (row, col), 1); + add_edge(&mut graph, (row + 1, col + 1), (row, col), 1); + } + } + + // Dijkstra would explore most of the 101 × 101 nodes + // the heuristic should allow exploring only about 200 nodes + let now = std::time::Instant::now(); + let res = astar(&graph, (0, 0), (100, 90), |(i, j)| 100 - i + 90 - j); + assert!(now.elapsed() < std::time::Duration::from_millis(10)); + + let (weight, path) = res.unwrap(); + assert_eq!(weight, 100); + assert_eq!(path.len(), 101); + } +} diff --git a/src/graph/bellman_ford.rs b/src/graph/bellman_ford.rs new file mode 100644 index 00000000000..8faf55860d5 --- /dev/null +++ b/src/graph/bellman_ford.rs @@ -0,0 +1,268 @@ +use std::collections::BTreeMap; +use std::ops::Add; + +use std::ops::Neg; + +type Graph = BTreeMap>; + +// performs the Bellman-Ford algorithm on the given graph from the given start +// the graph is an undirected graph +// +// if there is a negative weighted loop it returns None +// else it returns a map 
that for each reachable vertex associates the distance and the predecessor +// since the start has no predecessor but is reachable, map[start] will be None +pub fn bellman_ford< + V: Ord + Copy, + E: Ord + Copy + Add + Neg + std::ops::Sub, +>( + graph: &Graph, + start: &V, +) -> Option>> { + let mut ans: BTreeMap> = BTreeMap::new(); + + ans.insert(*start, None); + + for _ in 1..(graph.len()) { + for (u, edges) in graph { + let dist_u = match ans.get(u) { + Some(Some((_, d))) => Some(*d), + Some(None) => None, + None => continue, + }; + + for (v, d) in edges { + match ans.get(v) { + Some(Some((_, dist))) + // if this is a longer path, do nothing + if match dist_u { + Some(dist_u) => dist_u + *d >= *dist, + None => d >= dist, + } => {} + Some(None) => { + match dist_u { + // if dist_u + d < 0 there is a negative loop going by start + // else it's just a longer path + Some(dist_u) if dist_u >= -*d => {} + // negative self edge or negative loop + _ => { + if *d > *d + *d { + return None; + } + } + }; + } + // it's a shorter path: either dist_v was infinite or it was longer than dist_u + d + _ => { + ans.insert( + *v, + Some(( + *u, + match dist_u { + Some(dist) => dist + *d, + None => *d, + }, + )), + ); + } + } + } + } + } + + for (u, edges) in graph { + for (v, d) in edges { + match (ans.get(u), ans.get(v)) { + (Some(None), Some(None)) if *d > *d + *d => return None, + (Some(None), Some(Some((_, dv)))) if d < dv => return None, + (Some(Some((_, du))), Some(None)) if *du < -*d => return None, + (Some(Some((_, du))), Some(Some((_, dv)))) if *du + *d < *dv => return None, + (_, _) => {} + } + } + } + + Some(ans) +} + +#[cfg(test)] +mod tests { + use super::{bellman_ford, Graph}; + use std::collections::BTreeMap; + + fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default(); + } + + #[test] + fn single_vertex() { + let mut graph: Graph = BTreeMap::new(); + graph.insert(0, BTreeMap::new()); + + let 
mut dists = BTreeMap::new(); + dists.insert(0, None); + + assert_eq!(bellman_ford(&graph, &0), Some(dists)); + } + + #[test] + fn single_edge() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 0, 1, 2); + + let mut dists_0 = BTreeMap::new(); + dists_0.insert(0, None); + dists_0.insert(1, Some((0, 2))); + + assert_eq!(bellman_ford(&graph, &0), Some(dists_0)); + + let mut dists_1 = BTreeMap::new(); + dists_1.insert(1, None); + + assert_eq!(bellman_ford(&graph, &1), Some(dists_1)); + } + + #[test] + fn tree_1() { + let mut graph = BTreeMap::new(); + let mut dists = BTreeMap::new(); + dists.insert(1, None); + for i in 1..100 { + add_edge(&mut graph, i, i * 2, i * 2); + add_edge(&mut graph, i, i * 2 + 1, i * 2 + 1); + + match dists[&i] { + Some((_, d)) => { + dists.insert(i * 2, Some((i, d + i * 2))); + dists.insert(i * 2 + 1, Some((i, d + i * 2 + 1))); + } + None => { + dists.insert(i * 2, Some((i, i * 2))); + dists.insert(i * 2 + 1, Some((i, i * 2 + 1))); + } + } + } + + assert_eq!(bellman_ford(&graph, &1), Some(dists)); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'c', 12); + add_edge(&mut graph, 'a', 'd', 60); + add_edge(&mut graph, 'b', 'a', 10); + add_edge(&mut graph, 'c', 'b', 20); + add_edge(&mut graph, 'c', 'd', 32); + add_edge(&mut graph, 'e', 'a', 7); + + let mut dists_a = BTreeMap::new(); + dists_a.insert('a', None); + dists_a.insert('c', Some(('a', 12))); + dists_a.insert('d', Some(('c', 44))); + dists_a.insert('b', Some(('c', 32))); + assert_eq!(bellman_ford(&graph, &'a'), Some(dists_a)); + + let mut dists_b = BTreeMap::new(); + dists_b.insert('b', None); + dists_b.insert('a', Some(('b', 10))); + dists_b.insert('c', Some(('a', 22))); + dists_b.insert('d', Some(('c', 54))); + assert_eq!(bellman_ford(&graph, &'b'), Some(dists_b)); + + let mut dists_c = BTreeMap::new(); + dists_c.insert('c', None); + dists_c.insert('b', Some(('c', 20))); + dists_c.insert('d', Some(('c', 32))); + dists_c.insert('a', 
Some(('b', 30))); + assert_eq!(bellman_ford(&graph, &'c'), Some(dists_c)); + + let mut dists_d = BTreeMap::new(); + dists_d.insert('d', None); + assert_eq!(bellman_ford(&graph, &'d'), Some(dists_d)); + + let mut dists_e = BTreeMap::new(); + dists_e.insert('e', None); + dists_e.insert('a', Some(('e', 7))); + dists_e.insert('c', Some(('a', 19))); + dists_e.insert('d', Some(('c', 51))); + dists_e.insert('b', Some(('c', 39))); + assert_eq!(bellman_ford(&graph, &'e'), Some(dists_e)); + } + + #[test] + fn graph_2() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 0, 1, 6); + add_edge(&mut graph, 0, 3, 7); + add_edge(&mut graph, 1, 2, 5); + add_edge(&mut graph, 1, 3, 8); + add_edge(&mut graph, 1, 4, -4); + add_edge(&mut graph, 2, 1, -2); + add_edge(&mut graph, 3, 2, -3); + add_edge(&mut graph, 3, 4, 9); + add_edge(&mut graph, 4, 0, 3); + add_edge(&mut graph, 4, 2, 7); + + let mut dists_0 = BTreeMap::new(); + dists_0.insert(0, None); + dists_0.insert(1, Some((2, 2))); + dists_0.insert(2, Some((3, 4))); + dists_0.insert(3, Some((0, 7))); + dists_0.insert(4, Some((1, -2))); + assert_eq!(bellman_ford(&graph, &0), Some(dists_0)); + + let mut dists_1 = BTreeMap::new(); + dists_1.insert(0, Some((4, -1))); + dists_1.insert(1, None); + dists_1.insert(2, Some((4, 3))); + dists_1.insert(3, Some((0, 6))); + dists_1.insert(4, Some((1, -4))); + assert_eq!(bellman_ford(&graph, &1), Some(dists_1)); + + let mut dists_2 = BTreeMap::new(); + dists_2.insert(0, Some((4, -3))); + dists_2.insert(1, Some((2, -2))); + dists_2.insert(2, None); + dists_2.insert(3, Some((0, 4))); + dists_2.insert(4, Some((1, -6))); + assert_eq!(bellman_ford(&graph, &2), Some(dists_2)); + + let mut dists_3 = BTreeMap::new(); + dists_3.insert(0, Some((4, -6))); + dists_3.insert(1, Some((2, -5))); + dists_3.insert(2, Some((3, -3))); + dists_3.insert(3, None); + dists_3.insert(4, Some((1, -9))); + assert_eq!(bellman_ford(&graph, &3), Some(dists_3)); + + let mut dists_4 = BTreeMap::new(); + dists_4.insert(0, 
Some((4, 3))); + dists_4.insert(1, Some((2, 5))); + dists_4.insert(2, Some((4, 7))); + dists_4.insert(3, Some((0, 10))); + dists_4.insert(4, None); + assert_eq!(bellman_ford(&graph, &4), Some(dists_4)); + } + + #[test] + fn graph_with_negative_loop() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 0, 1, 6); + add_edge(&mut graph, 0, 3, 7); + add_edge(&mut graph, 1, 2, 5); + add_edge(&mut graph, 1, 3, 8); + add_edge(&mut graph, 1, 4, -4); + add_edge(&mut graph, 2, 1, -4); + add_edge(&mut graph, 3, 2, -3); + add_edge(&mut graph, 3, 4, 9); + add_edge(&mut graph, 4, 0, 3); + add_edge(&mut graph, 4, 2, 7); + + assert_eq!(bellman_ford(&graph, &0), None); + assert_eq!(bellman_ford(&graph, &1), None); + assert_eq!(bellman_ford(&graph, &2), None); + assert_eq!(bellman_ford(&graph, &3), None); + assert_eq!(bellman_ford(&graph, &4), None); + } +} diff --git a/src/graph/bipartite_matching.rs b/src/graph/bipartite_matching.rs new file mode 100644 index 00000000000..48c25e8064f --- /dev/null +++ b/src/graph/bipartite_matching.rs @@ -0,0 +1,270 @@ +// Adjacency List +use std::collections::VecDeque; +type Graph = Vec>; + +pub struct BipartiteMatching { + pub adj: Graph, + pub num_vertices_grp1: usize, + pub num_vertices_grp2: usize, + // mt1[i] = v is the matching of i in grp1 to v in grp2 + pub mt1: Vec, + pub mt2: Vec, + pub used: Vec, +} +impl BipartiteMatching { + pub fn new(num_vertices_grp1: usize, num_vertices_grp2: usize) -> Self { + BipartiteMatching { + adj: vec![vec![]; num_vertices_grp1 + 1], + num_vertices_grp1, + num_vertices_grp2, + mt2: vec![-1; num_vertices_grp2 + 1], + mt1: vec![-1; num_vertices_grp1 + 1], + used: vec![false; num_vertices_grp1 + 1], + } + } + #[inline] + // Add an directed edge u->v in the graph + pub fn add_edge(&mut self, u: usize, v: usize) { + self.adj[u].push(v); + } + + fn try_kuhn(&mut self, cur: usize) -> bool { + if self.used[cur] { + return false; + } + self.used[cur] = true; + for i in 0..self.adj[cur].len() { + let to = 
self.adj[cur][i]; + if self.mt2[to] == -1 || self.try_kuhn(self.mt2[to] as usize) { + self.mt2[to] = cur as i32; + return true; + } + } + false + } + // Note: It does not modify self.mt1, it only works on self.mt2 + pub fn kuhn(&mut self) { + self.mt2 = vec![-1; self.num_vertices_grp2 + 1]; + for v in 1..=self.num_vertices_grp1 { + self.used = vec![false; self.num_vertices_grp1 + 1]; + self.try_kuhn(v); + } + } + pub fn print_matching(&self) { + for i in 1..=self.num_vertices_grp2 { + if self.mt2[i] == -1 { + continue; + } + println!("Vertex {} in grp1 matched with {} grp2", self.mt2[i], i) + } + } + fn bfs(&self, dist: &mut [i32]) -> bool { + let mut q = VecDeque::new(); + for (u, d_i) in dist + .iter_mut() + .enumerate() + .skip(1) + .take(self.num_vertices_grp1) + { + if self.mt1[u] == 0 { + // u is not matched + *d_i = 0; + q.push_back(u); + } else { + // else set the vertex distance as infinite because it is matched + // this will be considered the next time + + *d_i = i32::MAX; + } + } + dist[0] = i32::MAX; + while !q.is_empty() { + let u = *q.front().unwrap(); + q.pop_front(); + if dist[u] < dist[0] { + for i in 0..self.adj[u].len() { + let v = self.adj[u][i]; + if dist[self.mt2[v] as usize] == i32::MAX { + dist[self.mt2[v] as usize] = dist[u] + 1; + q.push_back(self.mt2[v] as usize); + } + } + } + } + dist[0] != i32::MAX + } + fn dfs(&mut self, u: i32, dist: &mut Vec) -> bool { + if u == 0 { + return true; + } + for i in 0..self.adj[u as usize].len() { + let v = self.adj[u as usize][i]; + if dist[self.mt2[v] as usize] == dist[u as usize] + 1 && self.dfs(self.mt2[v], dist) { + self.mt2[v] = u; + self.mt1[u as usize] = v as i32; + return true; + } + } + dist[u as usize] = i32::MAX; + false + } + pub fn hopcroft_karp(&mut self) -> i32 { + // NOTE: how to use: https://cses.fi/paste/7558dba8d00436a847eab8/ + self.mt2 = vec![0; self.num_vertices_grp2 + 1]; + self.mt1 = vec![0; self.num_vertices_grp1 + 1]; + let mut dist = vec![i32::MAX; self.num_vertices_grp1 + 
1]; + let mut res = 0; + while self.bfs(&mut dist) { + for u in 1..=self.num_vertices_grp1 { + if self.mt1[u] == 0 && self.dfs(u as i32, &mut dist) { + res += 1; + } + } + } + // for x in self.mt2 change x to -1 if it is 0 + for x in self.mt2.iter_mut() { + if *x == 0 { + *x = -1; + } + } + for x in self.mt1.iter_mut() { + if *x == 0 { + *x = -1; + } + } + res + } +} +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn small_graph_kuhn() { + let n1 = 6; + let n2 = 6; + let mut g = BipartiteMatching::new(n1, n2); + // vertex 1 in grp1 to vertex 1 in grp 2 + // denote the ith grp2 vertex as n1+i + g.add_edge(1, 2); + g.add_edge(1, 3); + // 2 is not connected to any vertex + g.add_edge(3, 4); + g.add_edge(3, 1); + g.add_edge(4, 3); + g.add_edge(5, 3); + g.add_edge(5, 4); + g.add_edge(6, 6); + g.kuhn(); + g.print_matching(); + let answer: Vec = vec![-1, 2, -1, 1, 3, 4, 6]; + for i in 1..g.mt2.len() { + if g.mt2[i] == -1 { + // 5 in group2 has no pair + assert_eq!(i, 5); + continue; + } + // 2 in group1 has no pair + assert!(g.mt2[i] != 2); + assert_eq!(i as i32, answer[g.mt2[i] as usize]); + } + } + #[test] + fn small_graph_hopcroft() { + let n1 = 6; + let n2 = 6; + let mut g = BipartiteMatching::new(n1, n2); + // vertex 1 in grp1 to vertex 1 in grp 2 + // denote the ith grp2 vertex as n1+i + g.add_edge(1, 2); + g.add_edge(1, 3); + // 2 is not connected to any vertex + g.add_edge(3, 4); + g.add_edge(3, 1); + g.add_edge(4, 3); + g.add_edge(5, 3); + g.add_edge(5, 4); + g.add_edge(6, 6); + let x = g.hopcroft_karp(); + assert_eq!(x, 5); + g.print_matching(); + let answer: Vec = vec![-1, 2, -1, 1, 3, 4, 6]; + for i in 1..g.mt2.len() { + if g.mt2[i] == -1 { + // 5 in group2 has no pair + assert_eq!(i, 5); + continue; + } + // 2 in group1 has no pair + assert!(g.mt2[i] != 2); + assert_eq!(i as i32, answer[g.mt2[i] as usize]); + } + } + #[test] + fn super_small_graph_kuhn() { + let n1 = 1; + let n2 = 1; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + 
g.kuhn(); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + } + #[test] + fn super_small_graph_hopcroft() { + let n1 = 1; + let n2 = 1; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + let x = g.hopcroft_karp(); + assert_eq!(x, 1); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + assert_eq!(g.mt1[1], 1); + } + + #[test] + fn only_one_vertex_graph_kuhn() { + let n1 = 10; + let n2 = 10; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + g.add_edge(2, 1); + g.add_edge(3, 1); + g.add_edge(4, 1); + g.add_edge(5, 1); + g.add_edge(6, 1); + g.add_edge(7, 1); + g.add_edge(8, 1); + g.add_edge(9, 1); + g.add_edge(10, 1); + g.kuhn(); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + for i in 2..g.mt2.len() { + assert!(g.mt2[i] == -1); + } + } + #[test] + fn only_one_vertex_graph_hopcroft() { + let n1 = 10; + let n2 = 10; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + g.add_edge(2, 1); + g.add_edge(3, 1); + g.add_edge(4, 1); + g.add_edge(5, 1); + g.add_edge(6, 1); + g.add_edge(7, 1); + g.add_edge(8, 1); + g.add_edge(9, 1); + g.add_edge(10, 1); + let x = g.hopcroft_karp(); + assert_eq!(x, 1); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + for i in 2..g.mt2.len() { + assert!(g.mt2[i] == -1); + } + } +} diff --git a/src/graph/breadth_first_search.rs b/src/graph/breadth_first_search.rs new file mode 100644 index 00000000000..4b4875ab721 --- /dev/null +++ b/src/graph/breadth_first_search.rs @@ -0,0 +1,203 @@ +use std::collections::HashSet; +use std::collections::VecDeque; + +/// Perform a breadth-first search on Graph `graph`. +/// +/// # Parameters +/// +/// - `graph`: The graph to search. +/// - `root`: The starting node of the graph from which to begin searching. +/// - `target`: The target node for the search. +/// +/// # Returns +/// +/// If the target is found, an Optional vector is returned with the history +/// of nodes visited as its contents. 
/// Performs a breadth-first search over `graph`.
///
/// # Parameters
///
/// - `graph`: The graph to search.
/// - `root`: The node at which the search starts.
/// - `target`: The node being searched for.
///
/// # Returns
///
/// `Some(history)` as soon as `target` is dequeued, where `history` holds the
/// values of every node dequeued so far (in visiting order, ending with
/// `target`).
///
/// `None` if the queue runs empty first, i.e. `target` is not reachable from
/// `root`.
pub fn breadth_first_search(graph: &Graph, root: Node, target: Node) -> Option<Vec<u32>> {
    let mut visited: HashSet<Node> = HashSet::new();
    let mut history: Vec<u32> = Vec::new();
    let mut queue = VecDeque::new();

    // Nodes are marked as visited when they are enqueued, so every node
    // enters the queue at most once.
    visited.insert(root);
    queue.push_back(root);

    while let Some(current_node) = queue.pop_front() {
        history.push(current_node.value());

        // Reaching the goal ends the search with the travel history so far.
        if current_node == target {
            return Some(history);
        }

        // `HashSet::insert` returns `true` only for nodes not seen before.
        for neighbor in current_node.neighbors(graph) {
            if visited.insert(neighbor) {
                queue.push_back(neighbor);
            }
        }
    }

    // Every reachable node was visited, yet the target was not found.
    None
}

// Data Structures

/// A graph node, identified by its `u32` value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Node(u32);

/// A directed edge `(source, destination)` between two node values.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Edge(u32, u32);

/// A graph stored as a node list plus an edge list.
#[derive(Clone, Debug)]
pub struct Graph {
    // Kept for completeness; the search itself only reads `edges`.
    #[allow(dead_code)]
    nodes: Vec<Node>,
    edges: Vec<Edge>,
}

impl Graph {
    /// Builds a graph from its nodes and (directed) edges.
    pub fn new(nodes: Vec<Node>, edges: Vec<Edge>) -> Self {
        Graph { nodes, edges }
    }
}

impl From<u32> for Node {
    fn from(item: u32) -> Self {
        Node(item)
    }
}

impl Node {
    /// Returns the value identifying this node.
    pub fn value(&self) -> u32 {
        self.0
    }

    /// Returns the destinations of all edges of `graph` that start at this
    /// node, in the order the edges are stored.
    pub fn neighbors(&self, graph: &Graph) -> Vec<Node> {
        graph
            .edges
            .iter()
            .filter(|edge| edge.0 == self.0)
            .map(|edge| Node(edge.1))
            .collect()
    }
}

impl From<(u32, u32)> for Edge {
    fn from(item: (u32, u32)) -> Self {
        Edge(item.0, item.1)
    }
}
/// Adjacency-list view of a tree: `adj[v]` lists the neighbours of vertex `v`.
type Adj = [Vec<usize>];

/// Marker bit stored in `vert_state[v]` once `v` has been placed in the
/// centroid tree.
const IN_DECOMPOSITION: u64 = 1 << 63;

/// Centroid Decomposition for a tree.
///
/// Given a tree, it can be recursively decomposed into centroids. Then the
/// parent of a centroid `c` is the previous centroid that split its connected
/// component into two or more components. It can be shown that in such a
/// decomposition, for each path `p` with starting and ending vertices `u`, `v`,
/// the lowest common ancestor of `u` and `v` in the centroid tree is a vertex
/// of `p`.
///
/// The input tree should have its vertices numbered from 1 to n, and
/// `graph_enumeration.rs` may help to convert other representations.
pub struct CentroidDecomposition {
    /// The root of the centroid tree, should _not_ be set by the user.
    /// It is filled in by [`CentroidDecomposition::decompose_tree`].
    pub root: usize,
    /// The result. `decomposition[v]` is the parent of `v` in centroid tree.
    /// `decomposition[root]` is 0
    pub decomposition: Vec<usize>,
    /// Used internally to save the big_child of a vertex, and whether it has
    /// been added to the centroid tree.
    vert_state: Vec<u64>,
    /// Used internally to save the subtree size of a vertex
    vert_size: Vec<usize>,
}

impl CentroidDecomposition {
    /// Creates an empty decomposition for a tree with `num_vertices` vertices
    /// numbered `1..=num_vertices` (index 0 is reserved as the "no parent"
    /// sentinel).
    pub fn new(num_vertices: usize) -> Self {
        let len = num_vertices + 1;
        CentroidDecomposition {
            root: 0,
            decomposition: vec![0; len],
            vert_state: vec![0; len],
            vert_size: vec![0; len],
        }
    }

    /// Records `parent` as the centroid-tree parent of `v` and flags `v` as
    /// already decomposed.
    #[inline]
    fn put_in_decomposition(&mut self, v: usize, parent: usize) {
        self.decomposition[v] = parent;
        self.vert_state[v] |= IN_DECOMPOSITION;
    }

    /// Tells whether `v` has already been placed in the centroid tree.
    #[inline]
    fn is_in_decomposition(&self, v: usize) -> bool {
        (self.vert_state[v] & IN_DECOMPOSITION) != 0
    }

    /// Computes the size of the subtree hanging below `v` (coming from
    /// `parent`), ignoring vertices that are already decomposed, and stores
    /// the child with the largest subtree ("big child") in `vert_state[v]`.
    fn dfs_size(&mut self, v: usize, parent: usize, adj: &Adj) -> usize {
        self.vert_size[v] = 1;
        let mut big_child = 0_usize;
        let mut big_child_size = 0_usize;
        for &u in adj[v].iter() {
            if u == parent || self.is_in_decomposition(u) {
                continue;
            }
            let u_size = self.dfs_size(u, v, adj);
            self.vert_size[v] += u_size;
            if u_size > big_child_size {
                big_child = u;
                big_child_size = u_size;
            }
        }
        // Overwrites the whole state word; `v` is not decomposed at this
        // point, so no flag is lost.
        self.vert_state[v] = big_child as u64;
        self.vert_size[v]
    }

    /// Walks down the chain of big children starting at `v` until the big
    /// child's subtree size is `<= size_thr`; that vertex is the centroid.
    /// A leaf has big child 0, whose recorded size is 0, which ends the walk.
    fn dfs_centroid(&self, v: usize, size_thr: usize) -> usize {
        let big_child = self.vert_state[v] as usize;
        if self.vert_size[big_child] <= size_thr {
            v
        } else {
            self.dfs_centroid(big_child, size_thr)
        }
    }

    /// Decomposes the still-undecomposed component containing `v`, attaching
    /// its centroid below `centroid_parent`, and returns that centroid.
    ///
    /// `calculate_vert_size` determines if it is necessary to recalculate
    /// `self.vert_size` for this component before looking for the centroid.
    fn decompose_subtree(
        &mut self,
        v: usize,
        centroid_parent: usize,
        calculate_vert_size: bool,
        adj: &Adj,
    ) -> usize {
        if calculate_vert_size {
            self.dfs_size(v, centroid_parent, adj);
        }
        let v_size = self.vert_size[v];
        let centroid = self.dfs_centroid(v, v_size >> 1);
        self.put_in_decomposition(centroid, centroid_parent);
        for &u in adj[centroid].iter() {
            if self.is_in_decomposition(u) {
                continue;
            }
            // A neighbour whose recorded size exceeds the centroid's was an
            // ancestor in the size DFS, so its recorded size is stale.
            self.decompose_subtree(
                u,
                centroid,
                self.vert_size[u] > self.vert_size[centroid],
                adj,
            );
        }
        centroid
    }

    /// Decomposes the tree described by `adj` (vertices numbered from 1) and
    /// stores the root of the centroid tree in `self.root`.
    ///
    /// Does nothing when the decomposition was created for zero vertices.
    pub fn decompose_tree(&mut self, adj: &Adj) {
        if self.decomposition.len() < 2 {
            return;
        }
        self.root = self.decompose_subtree(1, 0, true, adj);
    }
}
+ let n = 1e6 as usize; + let max_height = 1 + 20; + let len = n + 1; + let mut rng = PCG32::new_default(314159); + let mut tree_prufer_code: Vec = vec![0; n - 2]; + tree_prufer_code.fill_with(|| (rng.get_u32() % (n as u32)) + 1); + let vertex_list: Vec = (1..=(n as u32)).collect(); + let adj = enumerate_graph(&prufer_code::prufer_decode(&tree_prufer_code, &vertex_list)); + let mut cd = CentroidDecomposition::new(n); + cd.decompose_tree(&adj); + let mut heights: Vec = vec![0; len]; + heights[0] = 1; + for i in 1..=n { + let h = calculate_height(i, &mut heights, &mut cd.decomposition); + assert!(h <= max_height); + } + } +} diff --git a/src/graph/decremental_connectivity.rs b/src/graph/decremental_connectivity.rs new file mode 100644 index 00000000000..e5245404866 --- /dev/null +++ b/src/graph/decremental_connectivity.rs @@ -0,0 +1,267 @@ +use std::collections::HashSet; + +/// A data-structure that, given a forest, allows dynamic-connectivity queries. +/// Meaning deletion of an edge (u,v) and checking whether two vertecies are still connected. +/// +/// # Complexity +/// The preprocessing phase runs in O(n) time, where n is the number of vertecies in the forest. +/// Deletion runs in O(log n) and checking for connectivity runs in O(1) time. 
/// A data structure that, given a forest, allows decremental connectivity
/// queries: deleting an edge `(u, v)` and checking whether two vertices are
/// still connected.
///
/// # Complexity
/// The preprocessing phase runs in O(n) time, where n is the number of
/// vertices in the forest. Deletion runs in O(log n) and checking for
/// connectivity runs in O(1) time.
///
/// # Sources
/// Wikipedia's article on dynamic connectivity was used as reference.
pub struct DecrementalConnectivity {
    /// `adjacent[v]` is the neighbour set of `v`; a root lists itself.
    adjacent: Vec<HashSet<usize>>,
    /// `component[v]` is the id of the connected component containing `v`.
    component: Vec<usize>,
    /// The next unused component id.
    count: usize,
    /// `visited[v]` is the id of the last traversal that touched `v`.
    visited: Vec<usize>,
    /// The most recently handed-out traversal id.
    dfs_id: usize,
}

impl DecrementalConnectivity {
    /// Builds the structure from an adjacency list.
    ///
    /// Expects the parent of a root to be itself (roots carry a self-loop).
    ///
    /// # Errors
    /// Returns `Err` if the graph contains a cycle.
    ///
    /// # Panics
    /// Panics if some edge is stored in only one direction.
    pub fn new(adjacent: Vec<HashSet<usize>>) -> Result<Self, String> {
        let n = adjacent.len();
        if !is_forest(&adjacent) {
            return Err("input graph is not a forest!".to_string());
        }
        let mut tmp = DecrementalConnectivity {
            adjacent,
            component: vec![0; n],
            count: 0,
            visited: vec![0; n],
            dfs_id: 1,
        };
        tmp.component = tmp.calc_component();
        Ok(tmp)
    }

    /// Returns `Some(true)` if `u` and `v` are in the same component,
    /// `Some(false)` if they are not, and `None` if either vertex is out of
    /// range.
    pub fn connected(&self, u: usize, v: usize) -> Option<bool> {
        Some(self.component.get(u)? == self.component.get(v)?)
    }

    /// Deletes the edge `(u, v)` and relabels the smaller of the two
    /// resulting components with a fresh component id.
    ///
    /// # Panics
    /// Panics if the edge `(u, v)` does not exist (including when `u` is not
    /// a vertex of the forest).
    pub fn delete(&mut self, u: usize, v: usize) {
        // Every stored neighbour was validated at construction, so `v` is in
        // range whenever it appears in `adjacent[u]`.
        let edge_exists = self
            .adjacent
            .get(u)
            .map_or(false, |neighbours| neighbours.contains(&v))
            && self.component[u] == self.component[v];
        if !edge_exists {
            panic!("delete called on the edge ({u}, {v}) which doesn't exist");
        }

        self.adjacent[u].remove(&v);
        self.adjacent[v].remove(&u);

        // Start the relabelling traversal from the endpoint lying in the
        // smaller component, and fence off the other endpoint.
        let (start, fenced) = if self.is_smaller(u, v) {
            (u, v)
        } else {
            (v, u)
        };
        let mut queue: Vec<usize> = vec![start];
        self.dfs_id += 1;
        self.visited[fenced] = self.dfs_id;

        // `dfs_step` pops exactly the vertex peeked here.
        while let Some(&current) = queue.last() {
            self.dfs_step(&mut queue, self.dfs_id);
            self.component[current] = self.count;
        }
        self.count += 1;
    }

    /// Labels every vertex with the id of its connected component, handing
    /// out ids in order of the smallest vertex index of each component.
    fn calc_component(&mut self) -> Vec<usize> {
        let mut visited: Vec<bool> = vec![false; self.adjacent.len()];
        let mut comp: Vec<usize> = vec![0; self.adjacent.len()];

        for i in 0..self.adjacent.len() {
            if visited[i] {
                continue;
            }
            let mut queue: Vec<usize> = vec![i];
            while let Some(current) = queue.pop() {
                if !visited[current] {
                    queue.extend(self.adjacent[current].iter().copied());
                }
                visited[current] = true;
                comp[current] = self.count;
            }
            self.count += 1;
        }
        comp
    }

    /// Returns `true` if the traversal started at `u` runs out of vertices no
    /// later than the one started at `v`, i.e. the component of `u` is not
    /// larger. Both traversals advance in lock-step, so the work is bounded
    /// by the size of the smaller component.
    fn is_smaller(&mut self, u: usize, v: usize) -> bool {
        let mut u_queue: Vec<usize> = vec![u];
        let u_id = self.dfs_id;
        self.visited[v] = u_id;
        self.dfs_id += 1;

        let mut v_queue: Vec<usize> = vec![v];
        let v_id = self.dfs_id;
        self.visited[u] = v_id;
        self.dfs_id += 1;

        // parallel depth first search
        while !u_queue.is_empty() && !v_queue.is_empty() {
            self.dfs_step(&mut u_queue, u_id);
            self.dfs_step(&mut v_queue, v_id);
        }
        u_queue.is_empty()
    }

    /// Pops one vertex from `queue`, marks it with `dfs_id` and pushes all of
    /// its neighbours that do not carry that mark yet.
    ///
    /// # Panics
    /// Panics if `queue` is empty.
    fn dfs_step(&mut self, queue: &mut Vec<usize>, dfs_id: usize) {
        let u = queue.pop().unwrap();
        self.visited[u] = dfs_id;
        for &v in self.adjacent[u].iter() {
            if self.visited[v] == dfs_id {
                continue;
            }
            queue.push(v);
        }
    }
}

// checks whether the given graph is a forest
// also checks for all adjacent vertices a,b if adjacent[a].contains(b) && adjacent[b].contains(a)
fn is_forest(adjacent: &[HashSet<usize>]) -> bool {
    let mut visited = vec![false; adjacent.len()];
    for node in 0..adjacent.len() {
        if visited[node] {
            continue;
        }
        // A traversal root is its own parent, which tolerates its self-loop.
        if has_cycle(adjacent, &mut visited, node, node) {
            return false;
        }
    }
    true
}

// depth-first cycle check below `node`; reaching an already visited vertex
// other than `parent` means a cycle was closed
fn has_cycle(
    adjacent: &[HashSet<usize>],
    visited: &mut [bool],
    node: usize,
    parent: usize,
) -> bool {
    visited[node] = true;
    for &neighbour in adjacent[node].iter() {
        if !adjacent[neighbour].contains(&node) {
            panic!("the given graph does not strictly contain bidirectional edges\n {node} -> {neighbour} exists, but the other direction does not");
        }
        if !visited[neighbour] {
            if has_cycle(adjacent, visited, neighbour, node) {
                return true;
            }
        } else if neighbour != parent {
            return true;
        }
    }
    false
}
HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + let dec_con = super::DecrementalConnectivity::new(adjacent.clone()).unwrap(); + assert_eq!(dec_con.component, vec![0, 0, 0, 0, 0, 0, 0, 1, 1]); + + // add a cycle to the tree + adjacent[2].insert(4); + adjacent[4].insert(2); + assert!(super::DecrementalConnectivity::new(adjacent.clone()).is_err()); + } + #[test] + #[should_panic(expected = "2 -> 4 exists")] + fn non_bidirectional_test() { + let adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6, 4]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + + // should panic now since our graph is not bidirectional + super::DecrementalConnectivity::new(adjacent).unwrap(); + } + + #[test] + #[should_panic(expected = "delete called on the edge (2, 4)")] + fn delete_panic_test() { + let adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + let mut dec_con = super::DecrementalConnectivity::new(adjacent).unwrap(); + dec_con.delete(2, 4); + } + + #[test] + fn query_test() { + let adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + let mut dec_con1 = super::DecrementalConnectivity::new(adjacent.clone()).unwrap(); + assert!(dec_con1.connected(3, 4).unwrap()); + assert!(dec_con1.connected(5, 0).unwrap()); + assert!(!dec_con1.connected(2, 7).unwrap()); + assert!(dec_con1.connected(0, 9).is_none()); + dec_con1.delete(0, 2); + assert!(dec_con1.connected(3, 4).unwrap()); + assert!(!dec_con1.connected(5, 0).unwrap()); + 
use std::collections::HashSet;
use std::collections::VecDeque;

/// Performs a depth-first search over `graph`, starting at `root` and
/// looking for `objective`.
///
/// Returns `Some(history)` as soon as `objective` is taken off the stack,
/// where `history` lists the values of the vertices visited so far (ending
/// with `objective`), or `None` if `objective` is not reachable from `root`.
pub fn depth_first_search(graph: &Graph, root: Vertex, objective: Vertex) -> Option<Vec<u32>> {
    let mut visited: HashSet<Vertex> = HashSet::new();
    let mut history: Vec<u32> = Vec::new();

    // The root must be marked as visited up front; otherwise an edge leading
    // back to it would schedule it a second time and duplicate it in the
    // history.
    visited.insert(root);
    let mut stack: Vec<Vertex> = vec![root];

    while let Some(current_vertex) = stack.pop() {
        // Record the vertex in the history of visited vertices.
        history.push(current_vertex.value());

        // Stop as soon as the objective is reached.
        if current_vertex == objective {
            return Some(history);
        }

        // Push the neighbours in reverse so that the first neighbour ends up
        // on top of the stack and is explored first.
        for neighbor in current_vertex.neighbors(graph).into_iter().rev() {
            // `insert` returns `true` only for vertices not scheduled before.
            if visited.insert(neighbor) {
                stack.push(neighbor);
            }
        }
    }

    // Every reachable vertex was visited and the objective was not found.
    None
}

// Data Structures

/// A graph vertex, identified by its `u32` value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Vertex(u32);

/// A directed edge `(source, destination)` between two vertex values.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct Edge(u32, u32);

/// A graph stored as a vertex list plus an edge list.
#[derive(Clone, Debug)]
pub struct Graph {
    // Kept for completeness; the search itself only reads `edges`.
    #[allow(dead_code)]
    vertices: Vec<Vertex>,
    edges: Vec<Edge>,
}

impl Graph {
    /// Builds a graph from its vertices and (directed) edges.
    pub fn new(vertices: Vec<Vertex>, edges: Vec<Edge>) -> Self {
        Graph { vertices, edges }
    }
}

impl From<u32> for Vertex {
    fn from(item: u32) -> Self {
        Vertex(item)
    }
}

impl Vertex {
    /// Returns the value identifying this vertex.
    pub fn value(&self) -> u32 {
        self.0
    }

    /// Returns the destinations of all edges of `graph` that start at this
    /// vertex, in the order the edges are stored.
    pub fn neighbors(&self, graph: &Graph) -> VecDeque<Vertex> {
        graph
            .edges
            .iter()
            .filter(|edge| edge.0 == self.0)
            .map(|edge| Vertex(edge.1))
            .collect()
    }
}

impl From<(u32, u32)> for Edge {
    fn from(item: (u32, u32)) -> Self {
        Edge(item.0, item.1)
    }
}
objective = 6; + + let correct_path = vec![0, 1, 3, 2, 4, 5, 7, 6]; + + let graph = Graph::new( + vertices.into_iter().map(|v| v.into()).collect(), + edges.into_iter().map(|e| e.into()).collect(), + ); + + assert_eq!( + depth_first_search(&graph, root.into(), objective.into()), + Some(correct_path) + ); + } + + #[test] + fn find_3_sucess() { + let vertices = vec![0, 1, 2, 3, 4, 5, 6, 7]; + let edges = vec![ + (0, 1), + (1, 3), + (3, 2), + (2, 1), + (3, 4), + (4, 5), + (5, 7), + (7, 6), + (6, 4), + ]; + + let root = 0; + let objective = 4; + + let correct_path = vec![0, 1, 3, 2, 4]; + + let graph = Graph::new( + vertices.into_iter().map(|v| v.into()).collect(), + edges.into_iter().map(|e| e.into()).collect(), + ); + + assert_eq!( + depth_first_search(&graph, root.into(), objective.into()), + Some(correct_path) + ); + } +} diff --git a/src/graph/depth_first_search_tic_tac_toe.rs b/src/graph/depth_first_search_tic_tac_toe.rs new file mode 100644 index 00000000000..788991c3823 --- /dev/null +++ b/src/graph/depth_first_search_tic_tac_toe.rs @@ -0,0 +1,403 @@ +/* +Tic-Tac-Toe Depth First Search Rust Demo +Copyright 2021 David V. Makray + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
#[allow(unused_imports)]
use std::io;

//Interactive Tic-Tac-Toe play needs the "rand = "0.8.3" crate.
//#[cfg(not(test))]
//extern crate rand;
//#[cfg(not(test))]
//use rand::Rng;

/// A cell of the 3x3 board: `x` is the column (0..=2), `y` the row (0..=2).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
struct Position {
    x: u8,
    y: u8,
}

/// The content of a cell and, when used as a score, the side expected to
/// win (`Blank` meaning "nobody", i.e. a draw).
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum Players {
    Blank,
    PlayerX,
    PlayerO,
}

/// One candidate move together with the outcome it leads to.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
struct SinglePlayAction {
    position: Position,
    side: Players,
}

/// All equally good moves found for a board, with their common outcome.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct PlayActions {
    positions: Vec<Position>,
    side: Players,
}

/// Parses a user coordinate such as `a1` or `C3` (column letter `a`-`c`,
/// row digit `1`-`3`, surrounding whitespace ignored) into a board position.
///
/// Returns `None` for anything else. Matching on the allowed byte ranges
/// avoids the `u8` underflow that `b'0' - b'1'` would cause for input `a0`.
#[allow(dead_code)]
fn parse_coordinate(input: &str) -> Option<Position> {
    let normalized = input.trim().to_ascii_lowercase();
    match normalized.as_bytes() {
        [column @ b'a'..=b'c', row @ b'1'..=b'3'] => Some(Position {
            x: column - b'a',
            y: row - b'1',
        }),
        _ => None,
    }
}

/// Interactive game loop: the user plays X, the minimax search plays O.
#[allow(dead_code)]
#[cfg(not(test))]
fn main() {
    let mut board = vec![vec![Players::Blank; 3]; 3];

    while !available_positions(&board).is_empty()
        && !win_check(Players::PlayerX, &board)
        && !win_check(Players::PlayerO, &board)
    {
        display_board(&board);
        println!("Type in coordinate for X mark to be played. ie. a1 etc.");
        let mut input = String::new();
        let bytes_read = io::stdin()
            .read_line(&mut input)
            .expect("Failed to read line");
        if bytes_read == 0 {
            // End of input: nothing more can be read, so stop instead of
            // prompting forever.
            return;
        }

        // Unparsable input simply shows the board and the prompt again.
        let move_pos = match parse_coordinate(&input) {
            Some(position) => position,
            None => continue,
        };

        //Take the validated user input coordinate and use it.
        if !available_positions(&board).contains(&move_pos) {
            println!("Not a valid empty coordinate.");
            continue;
        }
        board[move_pos.y as usize][move_pos.x as usize] = Players::PlayerX;

        if win_check(Players::PlayerX, &board) {
            display_board(&board);
            println!("Player X Wins!");
            return;
        }

        //Find the best game plays from the current board state
        match minimax(Players::PlayerO, &board) {
            Some(best_plays) => {
                //Interactive Tic-Tac-Toe play needs the "rand = "0.8.3" crate.
                //#[cfg(not(test))]
                //let random_selection = rand::rng().gen_range(0..best_plays.positions.len());
                let random_selection = 0;

                let response_pos = best_plays.positions[random_selection];
                board[response_pos.y as usize][response_pos.x as usize] = Players::PlayerO;
                if win_check(Players::PlayerO, &board) {
                    display_board(&board);
                    println!("Player O Wins!");
                    return;
                }
            }

            None => {
                display_board(&board);
                println!("Draw game.");
                return;
            }
        }
    }
}

/// Prints the board with row numbers on the left and column letters below.
#[allow(dead_code)]
fn display_board(board: &[Vec<Players>]) {
    println!();
    for (y, board_row) in board.iter().enumerate() {
        print!("{} ", (y + 1));
        for board_cell in board_row {
            match board_cell {
                Players::PlayerX => print!("X "),
                Players::PlayerO => print!("O "),
                Players::Blank => print!("_ "),
            }
        }
        println!();
    }
    println!("  a b c");
}

/// Lists every blank cell of `board` in row-major order.
fn available_positions(board: &[Vec<Players>]) -> Vec<Position> {
    board
        .iter()
        .enumerate()
        .flat_map(|(y, board_row)| {
            board_row
                .iter()
                .enumerate()
                .filter(|(_, board_cell)| **board_cell == Players::Blank)
                .map(move |(x, _)| Position {
                    x: x as u8,
                    y: y as u8,
                })
        })
        .collect()
}

/// Tells whether `player` owns a complete row, column or diagonal.
///
/// Always `false` for `Players::Blank`. Expects a board of at least 3x3
/// cells and panics on a smaller one.
fn win_check(player: Players, board: &[Vec<Players>]) -> bool {
    if player == Players::Blank {
        return false;
    }

    let owns = |row: usize, column: usize| board[row][column] == player;

    //Check for a win on the diagonals.
    if (owns(0, 0) && owns(1, 1) && owns(2, 2)) || (owns(2, 0) && owns(1, 1) && owns(0, 2)) {
        return true;
    }

    //Check for a win on the horizontals and the verticals.
    (0..3).any(|i| {
        (owns(i, 0) && owns(i, 1) && owns(i, 2)) || (owns(0, i) && owns(1, i) && owns(2, i))
    })
}

/// Minimize the actions of the opponent while maximizing the game state of
/// the current player.
///
/// Returns the best moves for `side` on `board` together with the outcome
/// they lead to, or `None` when the game is already decided (one side has
/// won, or no blank cell is left).
///
/// # Panics
/// Panics if `side` is `Players::Blank`.
pub fn minimax(side: Players, board: &[Vec<Players>]) -> Option<PlayActions> {
    //Check that board is in a valid state.
    if win_check(Players::PlayerX, board) || win_check(Players::PlayerO, board) {
        return None;
    }

    let opposite = match side {
        Players::PlayerX => Players::PlayerO,
        Players::PlayerO => Players::PlayerX,
        Players::Blank => panic!("Minimax can't operate when a player isn't specified."),
    };

    let positions = available_positions(board);
    if positions.is_empty() {
        return None;
    }

    //Play position
    let mut best_move: Option<PlayActions> = None;

    for pos in positions {
        let mut board_next = board.to_owned();
        board_next[pos.y as usize][pos.x as usize] = side;

        //Check for a win condition before recursion to determine if this node is terminal.
        let outcome = if win_check(Players::PlayerX, &board_next) {
            Players::PlayerX
        } else if win_check(Players::PlayerO, &board_next) {
            Players::PlayerO
        } else {
            // Not terminal: the outcome is whatever the opponent's best
            // reply leads to; no reply at all means a draw.
            minimax(opposite, &board_next).map_or(Players::Blank, |reply| reply.side)
        };

        append_playaction(
            side,
            &mut best_move,
            SinglePlayAction {
                position: pos,
                side: outcome,
            },
        );
    }
    best_move
}

/// Promote only better or collate equally scored game plays.
///
/// Seen from `current_side`, an outcome is a win (the side itself), a draw
/// (`Blank`) or a loss (the opponent). A strictly better `appendee` replaces
/// the collected plays, an equally good one is appended and a worse one is
/// ignored.
///
/// # Panics
/// Panics if `current_side` is `Players::Blank` while plays were already
/// collected.
fn append_playaction(
    current_side: Players,
    opt_play_actions: &mut Option<PlayActions>,
    appendee: SinglePlayAction,
) {
    let play_actions = match opt_play_actions.as_mut() {
        Some(existing) => existing,
        None => {
            // The first candidate is the best one seen so far by definition.
            *opt_play_actions = Some(PlayActions {
                positions: vec![appendee.position],
                side: appendee.side,
            });
            return;
        }
    };

    if current_side == Players::Blank {
        panic!("Unreachable state.");
    }

    // 2 = win, 1 = draw, 0 = loss, from the point of view of `current_side`.
    let score = |outcome: Players| -> u8 {
        if outcome == current_side {
            2
        } else if outcome == Players::Blank {
            1
        } else {
            0
        }
    };

    match score(appendee.side).cmp(&score(play_actions.side)) {
        //A better score replaces everything collected so far.
        std::cmp::Ordering::Greater => {
            play_actions.side = appendee.side;
            play_actions.positions.clear();
            play_actions.positions.push(appendee.position);
        }
        //No change hence append only
        std::cmp::Ordering::Equal => play_actions.positions.push(appendee.position),
        //Ignoring lower scored plays
        std::cmp::Ordering::Less => {}
    }
}
&board); + assert_eq!(responses, None); + } + + #[test] + fn block_win_move() { + let mut board = vec![vec![Players::Blank; 3]; 3]; + board[0][0] = Players::PlayerX; + board[0][1] = Players::PlayerX; + board[1][2] = Players::PlayerO; + board[2][2] = Players::PlayerO; + let responses = minimax(Players::PlayerX, &board); + assert_eq!( + responses, + Some(PlayActions { + positions: vec![Position { x: 2, y: 0 }], + side: Players::PlayerX + }) + ); + } + + #[test] + fn block_move() { + let mut board = vec![vec![Players::Blank; 3]; 3]; + board[0][1] = Players::PlayerX; + board[0][2] = Players::PlayerO; + board[2][0] = Players::PlayerO; + let responses = minimax(Players::PlayerX, &board); + assert_eq!( + responses, + Some(PlayActions { + positions: vec![Position { x: 1, y: 1 }], + side: Players::Blank + }) + ); + } + + #[test] + fn expected_loss() { + let mut board = vec![vec![Players::Blank; 3]; 3]; + board[0][0] = Players::PlayerX; + board[0][2] = Players::PlayerO; + board[1][0] = Players::PlayerX; + board[2][0] = Players::PlayerO; + board[2][2] = Players::PlayerO; + let responses = minimax(Players::PlayerX, &board); + assert_eq!( + responses, + Some(PlayActions { + positions: vec![ + Position { x: 1, y: 0 }, + Position { x: 1, y: 1 }, + Position { x: 2, y: 1 }, + Position { x: 1, y: 2 } + ], + side: Players::PlayerO + }) + ); + } +} diff --git a/src/graph/detect_cycle.rs b/src/graph/detect_cycle.rs new file mode 100644 index 00000000000..0243b44eede --- /dev/null +++ b/src/graph/detect_cycle.rs @@ -0,0 +1,294 @@ +use std::collections::{HashMap, HashSet, VecDeque}; + +use crate::data_structures::{graph::Graph, DirectedGraph, UndirectedGraph}; + +pub trait DetectCycle { + fn detect_cycle_dfs(&self) -> bool; + fn detect_cycle_bfs(&self) -> bool; +} + +// Helper function to detect cycle in an undirected graph using DFS graph traversal +fn undirected_graph_detect_cycle_dfs<'a>( + graph: &'a UndirectedGraph, + visited_node: &mut HashSet<&'a String>, + parent: Option<&'a 
String>, + u: &'a String, +) -> bool { + visited_node.insert(u); + for (v, _) in graph.adjacency_table().get(u).unwrap() { + if matches!(parent, Some(parent) if v == parent) { + continue; + } + if visited_node.contains(v) + || undirected_graph_detect_cycle_dfs(graph, visited_node, Some(u), v) + { + return true; + } + } + false +} + +// Helper function to detect cycle in an undirected graph using BFS graph traversal +fn undirected_graph_detect_cycle_bfs<'a>( + graph: &'a UndirectedGraph, + visited_node: &mut HashSet<&'a String>, + u: &'a String, +) -> bool { + visited_node.insert(u); + + // Initialize the queue for BFS, storing (current node, parent node) tuples + let mut queue = VecDeque::<(&String, Option<&String>)>::new(); + queue.push_back((u, None)); + + while let Some((u, parent)) = queue.pop_front() { + for (v, _) in graph.adjacency_table().get(u).unwrap() { + if matches!(parent, Some(parent) if v == parent) { + continue; + } + if visited_node.contains(v) { + return true; + } + visited_node.insert(v); + queue.push_back((v, Some(u))); + } + } + false +} + +impl DetectCycle for UndirectedGraph { + fn detect_cycle_dfs(&self) -> bool { + let mut visited_node = HashSet::<&String>::new(); + let adj = self.adjacency_table(); + for u in adj.keys() { + if !visited_node.contains(u) + && undirected_graph_detect_cycle_dfs(self, &mut visited_node, None, u) + { + return true; + } + } + false + } + + fn detect_cycle_bfs(&self) -> bool { + let mut visited_node = HashSet::<&String>::new(); + let adj = self.adjacency_table(); + for u in adj.keys() { + if !visited_node.contains(u) + && undirected_graph_detect_cycle_bfs(self, &mut visited_node, u) + { + return true; + } + } + false + } +} + +// Helper function to detect cycle in a directed graph using DFS graph traversal +fn directed_graph_detect_cycle_dfs<'a>( + graph: &'a DirectedGraph, + visited_node: &mut HashSet<&'a String>, + in_stack_visited_node: &mut HashSet<&'a String>, + u: &'a String, +) -> bool { + 
visited_node.insert(u); + in_stack_visited_node.insert(u); + for (v, _) in graph.adjacency_table().get(u).unwrap() { + if visited_node.contains(v) && in_stack_visited_node.contains(v) { + return true; + } + if !visited_node.contains(v) + && directed_graph_detect_cycle_dfs(graph, visited_node, in_stack_visited_node, v) + { + return true; + } + } + in_stack_visited_node.remove(u); + false +} + +impl DetectCycle for DirectedGraph { + fn detect_cycle_dfs(&self) -> bool { + let mut visited_node = HashSet::<&String>::new(); + let mut in_stack_visited_node = HashSet::<&String>::new(); + let adj = self.adjacency_table(); + for u in adj.keys() { + if !visited_node.contains(u) + && directed_graph_detect_cycle_dfs( + self, + &mut visited_node, + &mut in_stack_visited_node, + u, + ) + { + return true; + } + } + false + } + + // detect cycle in a the graph using Kahn's algorithm + // https://www.geeksforgeeks.org/detect-cycle-in-a-directed-graph-using-bfs/ + fn detect_cycle_bfs(&self) -> bool { + // Set 0 in-degree for each vertex + let mut in_degree: HashMap<&String, usize> = + self.adjacency_table().keys().map(|k| (k, 0)).collect(); + + // Calculate in-degree for each vertex + for u in self.adjacency_table().keys() { + for (v, _) in self.adjacency_table().get(u).unwrap() { + *in_degree.get_mut(v).unwrap() += 1; + } + } + // Initialize queue with vertex having 0 in-degree + let mut queue: VecDeque<&String> = in_degree + .iter() + .filter(|(_, °ree)| degree == 0) + .map(|(&k, _)| k) + .collect(); + + let mut count = 0; + while let Some(u) = queue.pop_front() { + count += 1; + for (v, _) in self.adjacency_table().get(u).unwrap() { + in_degree.entry(v).and_modify(|d| { + *d -= 1; + if *d == 0 { + queue.push_back(v); + } + }); + } + } + + // If count of processed vertices is not equal to the number of vertices, + // the graph has a cycle + count != self.adjacency_table().len() + } +} + +#[cfg(test)] +mod test { + use super::DetectCycle; + use crate::data_structures::{graph::Graph, 
DirectedGraph, UndirectedGraph}; + fn get_undirected_single_node_with_loop() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "a", 1)); + res + } + fn get_directed_single_node_with_loop() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "a", 1)); + res + } + fn get_undirected_two_nodes_connected() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res + } + fn get_directed_two_nodes_connected() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("b", "a", 1)); + res + } + fn get_directed_two_nodes() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res + } + fn get_undirected_triangle() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "a", 1)); + res + } + fn get_directed_triangle() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "a", 1)); + res + } + fn get_undirected_triangle_with_tail() -> UndirectedGraph { + let mut res = get_undirected_triangle(); + res.add_edge(("c", "d", 1)); + res.add_edge(("d", "e", 1)); + res.add_edge(("e", "f", 1)); + res.add_edge(("g", "h", 1)); + res + } + fn get_directed_triangle_with_tail() -> DirectedGraph { + let mut res = get_directed_triangle(); + res.add_edge(("c", "d", 1)); + res.add_edge(("d", "e", 1)); + res.add_edge(("e", "f", 1)); + res.add_edge(("g", "h", 1)); + res + } + fn get_undirected_graph_with_cycle() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("a", "c", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("b", "d", 1)); + res.add_edge(("c", "d", 1)); + res + } + fn get_undirected_graph_without_cycle() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + 
res.add_edge(("a", "b", 1)); + res.add_edge(("a", "c", 1)); + res.add_edge(("b", "d", 1)); + res.add_edge(("c", "e", 1)); + res + } + fn get_directed_graph_with_cycle() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("b", "a", 1)); + res.add_edge(("c", "a", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "d", 1)); + res.add_edge(("d", "b", 1)); + res + } + fn get_directed_graph_without_cycle() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("b", "a", 1)); + res.add_edge(("c", "a", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "d", 1)); + res.add_edge(("b", "d", 1)); + res + } + macro_rules! test_detect_cycle { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (graph, has_cycle) = $test_case; + println!("detect_cycle_dfs: {}", graph.detect_cycle_dfs()); + println!("detect_cycle_bfs: {}", graph.detect_cycle_bfs()); + assert_eq!(graph.detect_cycle_dfs(), has_cycle); + assert_eq!(graph.detect_cycle_bfs(), has_cycle); + } + )* + }; + } + test_detect_cycle! 
/// Adjacency map of a weighted directed graph: `graph[u][v]` is the weight of edge `u -> v`.
type Graph<V, E> = BTreeMap<V, BTreeMap<V, E>>;

/// Performs Dijkstra's algorithm on the given graph from the given start.
/// The graph is a positively-weighted directed graph.
///
/// Returns a map that, for each reachable vertex, associates its predecessor on a
/// shortest path and the total distance from `start`, as `Some((predecessor, distance))`.
/// Since the start has no predecessor but is reachable, `map[start]` is always `None`;
/// a self-loop on `start` is ignored so that it cannot overwrite that entry.
///
/// A vertex that has no entry of its own in `graph` is treated as having no outgoing
/// edges.
///
/// Time: O(E * logV). For each vertex, we traverse each edge, resulting in O(E). For each
/// edge, we insert a new shortest path for a vertex into the tree, resulting in
/// O(E * logV).
/// Space: O(V). The tree holds up to V vertices.
pub fn dijkstra<V: Ord + Copy, E: Ord + Copy + Add<Output = E>>(
    graph: &Graph<V, E>,
    start: V,
) -> BTreeMap<V, Option<(V, E)>> {
    let mut ans = BTreeMap::new();
    // Ordered set of (tentative distance, vertex); its first element is the next vertex
    // to settle.
    let mut prio = BTreeSet::new();

    // start is the special case that doesn't have a predecessor
    ans.insert(start, None);

    if let Some(edges) = graph.get(&start) {
        for (&next, &weight) in edges {
            // A self-loop can never shorten the (empty) path from start to itself.
            if next == start {
                continue;
            }
            ans.insert(next, Some((start, weight)));
            prio.insert((weight, next));
        }
    }

    while let Some((path_weight, vertex)) = prio.pop_first() {
        if let Some(edges) = graph.get(&vertex) {
            for (&next, &weight) in edges {
                let new_weight = path_weight + weight;
                match ans.get(&next) {
                    // the known distance is already at least as good: nothing to do
                    Some(Some((_, dist_next))) if new_weight >= *dist_next => {}
                    // `next` is the start vertex, whose entry must stay `None`
                    Some(None) => {}
                    // the new path is shorter: `next` was unknown or farther away
                    _ => {
                        if let Some(Some((_, prev_weight))) =
                            ans.insert(next, Some((vertex, new_weight)))
                        {
                            // drop the outdated queue entry before adding the better one
                            prio.remove(&(prev_weight, next));
                        }
                        prio.insert((new_weight, next));
                    }
                }
            }
        }
    }

    ans
}
dists.insert(1, None); + for i in 1..100 { + add_edge(&mut graph, i, i * 2, i * 2); + add_edge(&mut graph, i, i * 2 + 1, i * 2 + 1); + + match dists[&i] { + Some((_, d)) => { + dists.insert(i * 2, Some((i, d + i * 2))); + dists.insert(i * 2 + 1, Some((i, d + i * 2 + 1))); + } + None => { + dists.insert(i * 2, Some((i, i * 2))); + dists.insert(i * 2 + 1, Some((i, i * 2 + 1))); + } + } + } + + assert_eq!(dijkstra(&graph, 1), dists); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'c', 12); + add_edge(&mut graph, 'a', 'd', 60); + add_edge(&mut graph, 'b', 'a', 10); + add_edge(&mut graph, 'c', 'b', 20); + add_edge(&mut graph, 'c', 'd', 32); + add_edge(&mut graph, 'e', 'a', 7); + + let mut dists_a = BTreeMap::new(); + dists_a.insert('a', None); + dists_a.insert('c', Some(('a', 12))); + dists_a.insert('d', Some(('c', 44))); + dists_a.insert('b', Some(('c', 32))); + assert_eq!(dijkstra(&graph, 'a'), dists_a); + + let mut dists_b = BTreeMap::new(); + dists_b.insert('b', None); + dists_b.insert('a', Some(('b', 10))); + dists_b.insert('c', Some(('a', 22))); + dists_b.insert('d', Some(('c', 54))); + assert_eq!(dijkstra(&graph, 'b'), dists_b); + + let mut dists_c = BTreeMap::new(); + dists_c.insert('c', None); + dists_c.insert('b', Some(('c', 20))); + dists_c.insert('d', Some(('c', 32))); + dists_c.insert('a', Some(('b', 30))); + assert_eq!(dijkstra(&graph, 'c'), dists_c); + + let mut dists_d = BTreeMap::new(); + dists_d.insert('d', None); + assert_eq!(dijkstra(&graph, 'd'), dists_d); + + let mut dists_e = BTreeMap::new(); + dists_e.insert('e', None); + dists_e.insert('a', Some(('e', 7))); + dists_e.insert('c', Some(('a', 19))); + dists_e.insert('d', Some(('c', 51))); + dists_e.insert('b', Some(('c', 39))); + assert_eq!(dijkstra(&graph, 'e'), dists_e); + } +} diff --git a/src/graph/dinic_maxflow.rs b/src/graph/dinic_maxflow.rs new file mode 100644 index 00000000000..87ff7a7953a --- /dev/null +++ b/src/graph/dinic_maxflow.rs @@ 
// We assume that graph vertices are numbered from 1 to n.

/// Adjacency lists: `adj[v]` holds the indices (into the edge list) of the edges leaving `v`.
type Graph = Vec<Vec<usize>>;

/// A directed edge of the residual network.
///
/// We assume that `T::default()` gives "zero" flow and `T` supports negative values.
pub struct FlowEdge<T> {
    /// Vertex the edge points to.
    pub sink: usize,
    /// Maximum flow the edge can carry.
    pub capacity: T,
    /// Flow currently routed through the edge.
    pub flow: T,
}

/// An edge of the original network together with the flow assigned to it.
pub struct FlowResultEdge<T> {
    pub source: usize,
    pub sink: usize,
    pub flow: T,
}

impl<T: Clone + Copy + Add + AddAssign + Sub<Output = T> + SubAssign + Ord + Neg + Default>
    FlowEdge<T>
{
    /// Creates an edge towards `sink` with the given capacity and zero flow.
    pub fn new(sink: usize, capacity: T) -> Self {
        FlowEdge {
            sink,
            capacity,
            flow: T::default(),
        }
    }
}

/// Maximum-flow solver implementing Dinic's algorithm.
pub struct DinicMaxFlow<T> {
    /// BFS Level of each vertex. starts from 1
    level: Vec<usize>,

    /// The index of the last visited edge connected to each vertex
    pub last_edge: Vec<usize>,

    /// Holds whether the solution has already been calculated
    network_solved: bool,

    pub source: usize,
    pub sink: usize,

    /// Number of edges added to the residual network
    pub num_edges: usize,
    pub num_vertices: usize,

    pub adj: Graph,

    /// The list of flow edges
    pub edges: Vec<FlowEdge<T>>,
}

impl<T: Clone + Copy + Add + AddAssign + Sub<Output = T> + SubAssign + Neg + Ord + Default>
    DinicMaxFlow<T>
{
    /// Creates an empty network with vertices `1..=num_vertices` (index 0 is allocated
    /// but unused) and the given source and sink.
    pub fn new(source: usize, sink: usize, num_vertices: usize) -> Self {
        DinicMaxFlow {
            level: vec![0; num_vertices + 1],
            last_edge: vec![0; num_vertices + 1],
            network_solved: false,
            source,
            sink,
            num_edges: 0,
            num_vertices,
            adj: vec![vec![]; num_vertices + 1],
            edges: vec![],
        }
    }

    /// Adds a directed edge `source -> sink` with the given capacity.
    ///
    /// The edge and its zero-capacity reverse edge are stored at consecutive indices
    /// `2k` and `2k + 1`, so `e ^ 1` is always the reverse of edge `e`.
    ///
    /// # Panics
    ///
    /// Panics if `source` or `sink` is greater than `num_vertices`.
    #[inline]
    pub fn add_edge(&mut self, source: usize, sink: usize, capacity: T) {
        self.edges.push(FlowEdge::new(sink, capacity));
        // Add the reverse edge with zero capacity
        self.edges.push(FlowEdge::new(source, T::default()));
        self.adj[source].push(self.num_edges);
        self.adj[sink].push(self.num_edges + 1);
        self.num_edges += 2;
    }

    /// Assigns BFS levels along edges with spare capacity, starting from the source.
    ///
    /// Expects `level` to be zeroed except for `level[source] == 1`. Returns `true` if
    /// the sink received a level, i.e. is still reachable in the residual network.
    fn bfs(&mut self) -> bool {
        let mut q: VecDeque<usize> = VecDeque::new();
        q.push_back(self.source);

        while let Some(v) = q.pop_front() {
            for &e in self.adj[v].iter() {
                // Saturated edges are not part of the residual network.
                if self.edges[e].capacity <= self.edges[e].flow {
                    continue;
                }
                let u = self.edges[e].sink;
                // A non-zero level means `u` was already discovered.
                if self.level[u] != 0 {
                    continue;
                }
                self.level[u] = self.level[v] + 1;
                q.push_back(u);
            }
        }

        self.level[self.sink] != 0
    }

    /// Tries to push up to `pushed` units of flow from `v` to the sink along edges that
    /// go exactly one BFS level deeper. Returns the amount actually pushed (zero if no
    /// such path is left).
    fn dfs(&mut self, v: usize, pushed: T) -> T {
        // We have pushed nothing, or we are at the sink
        if v == self.sink || pushed == T::default() {
            return pushed;
        }
        // Resume at the edge we stopped at last time; earlier edges are exhausted for
        // the current level graph.
        for e_pos in self.last_edge[v]..self.adj[v].len() {
            let e = self.adj[v][e_pos];
            let u = self.edges[e].sink;
            if (self.level[v] + 1) != self.level[u] || self.edges[e].capacity <= self.edges[e].flow
            {
                continue;
            }
            let down_flow = self.dfs(
                u,
                std::cmp::min(pushed, self.edges[e].capacity - self.edges[e].flow),
            );
            if down_flow == T::default() {
                continue;
            }
            self.last_edge[v] = e_pos;
            self.edges[e].flow += down_flow;
            // The paired reverse edge gains the same amount of residual capacity.
            self.edges[e ^ 1].flow -= down_flow;
            return down_flow;
        }
        self.last_edge[v] = self.adj[v].len();
        T::default()
    }

    /// Computes the maximum flow from `source` to `sink`.
    ///
    /// `infinite_flow` is the amount offered at the source for every augmenting path; it
    /// should be at least as large as any possible flow (e.g. `i32::MAX`).
    ///
    /// The flow is stored in the edges, so calling this again on an already solved
    /// network finds no further augmenting path and returns zero.
    pub fn find_maxflow(&mut self, infinite_flow: T) -> T {
        self.network_solved = true;
        let mut total_flow: T = T::default();
        loop {
            self.level.fill(0);
            self.level[self.source] = 1;
            // There is no longer a path from source to sink in the residual
            // network
            if !self.bfs() {
                break;
            }
            self.last_edge.fill(0);
            let mut next_flow = self.dfs(self.source, infinite_flow);
            while next_flow != T::default() {
                total_flow += next_flow;
                next_flow = self.dfs(self.source, infinite_flow);
            }
        }
        total_flow
    }

    /// Returns every edge that carries positive flow, solving the network first if
    /// [`find_maxflow`](Self::find_maxflow) has not been called yet.
    pub fn get_flow_edges(&mut self, infinite_flow: T) -> Vec<FlowResultEdge<T>> {
        if !self.network_solved {
            self.find_maxflow(infinite_flow);
        }
        let mut result = Vec::new();
        for v in 1..self.adj.len() {
            for &e_ind in self.adj[v].iter() {
                let e = &self.edges[e_ind];
                // Make sure that reverse edges from residual network are not
                // included
                if e.flow > T::default() {
                    result.push(FlowResultEdge {
                        source: v,
                        sink: e.sink,
                        flow: e.flow,
                    });
                }
            }
        }
        result
    }
}
/// Represents a node in the Disjoint Set Union (DSU) structure which
/// keeps track of the parent-child relationships in the disjoint sets.
pub struct DSUNode {
    /// The index of the node's parent, or itself if it's the root.
    parent: usize,
    /// The size of the set rooted at this node, used for union by size.
    size: usize,
}

/// Disjoint Set Union (Union-Find) data structure, particularly useful for
/// managing dynamic connectivity problems such as determining
/// if two elements are in the same subset or merging two subsets.
pub struct DisjointSetUnion {
    /// List of DSU nodes where each element's parent and size are tracked.
    nodes: Vec<DSUNode>,
}

impl DisjointSetUnion {
    /// Initializes `num_elements + 1` disjoint sets, each element is its own parent.
    ///
    /// # Parameters
    ///
    /// - `num_elements`: The largest element to manage; valid elements are
    ///   `0` to `num_elements` inclusive.
    ///
    /// # Returns
    ///
    /// A new instance of `DisjointSetUnion` with `num_elements + 1` independent sets.
    pub fn new(num_elements: usize) -> DisjointSetUnion {
        let nodes = (0..=num_elements)
            .map(|idx| DSUNode {
                parent: idx,
                size: 1,
            })
            .collect();

        Self { nodes }
    }

    /// Finds the representative (root) of the set containing `element` with path compression.
    ///
    /// Path compression ensures that future queries are faster by directly linking
    /// all nodes in the path to the root.
    ///
    /// # Parameters
    ///
    /// - `element`: The element whose set representative is being found.
    ///
    /// # Returns
    ///
    /// The root representative of the set containing `element`.
    ///
    /// # Panics
    ///
    /// Panics if `element` is larger than the `num_elements` passed to [`new`](Self::new).
    pub fn find_set(&mut self, element: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = element;
        while self.nodes[root].parent != root {
            root = self.nodes[root].parent;
        }

        // Second pass: re-point every node on the walked path directly at the root.
        let mut current = element;
        while current != root {
            let next = self.nodes[current].parent;
            self.nodes[current].parent = root;
            current = next;
        }

        root
    }

    /// Merges the sets containing `first_elem` and `sec_elem` using union by size.
    ///
    /// The smaller set is always attached to the root of the larger set to ensure balanced trees.
    ///
    /// # Parameters
    ///
    /// - `first_elem`: The first element whose set is to be merged.
    /// - `sec_elem`: The second element whose set is to be merged.
    ///
    /// # Returns
    ///
    /// The root of the merged set, or `usize::MAX` if both elements are already in the same set.
    pub fn merge(&mut self, first_elem: usize, sec_elem: usize) -> usize {
        let mut first_root = self.find_set(first_elem);
        let mut sec_root = self.find_set(sec_elem);

        if first_root == sec_root {
            // Already in the same set, no merge required
            return usize::MAX;
        }

        // Union by size: make `first_root` the root of the larger tree, so the smaller
        // tree is the one that gets attached.
        if self.nodes[first_root].size < self.nodes[sec_root].size {
            std::mem::swap(&mut first_root, &mut sec_root);
        }

        self.nodes[sec_root].parent = first_root;
        self.nodes[first_root].size += self.nodes[sec_root].size;

        first_root
    }
}
/// Finds an Eulerian path in a directed graph.
///
/// # Arguments
///
/// * `node_count` - The number of nodes in the graph.
/// * `edge_list` - A vector of tuples representing directed edges, where each tuple is of the form `(start, end)`.
///
/// # Returns
///
/// An `Option<Vec<usize>>` containing the Eulerian path if it exists; otherwise, `None`.
///
/// # Panics
///
/// Panics if an edge refers to a node index that is not below `node_count`.
pub fn find_eulerian_path(node_count: usize, edge_list: Vec<(usize, usize)>) -> Option<Vec<usize>> {
    let mut adjacency_list = vec![Vec::new(); node_count];
    for (start, end) in edge_list {
        adjacency_list[start].push(end);
    }

    let mut eulerian_solver = EulerianPathSolver::new(adjacency_list);
    eulerian_solver.find_path()
}

/// Struct to represent the solver for finding an Eulerian path in a directed graph.
pub struct EulerianPathSolver {
    node_count: usize,
    edge_count: usize,
    in_degrees: Vec<usize>,
    /// Number of not yet traversed outgoing edges per node; consumed by the search.
    out_degrees: Vec<usize>,
    /// Nodes in the order the search finishes them, i.e. the path in reverse.
    eulerian_path: Vec<usize>,
    adjacency_list: Vec<Vec<usize>>,
}

impl EulerianPathSolver {
    /// Creates a new instance of `EulerianPathSolver`.
    ///
    /// # Arguments
    ///
    /// * `adjacency_list` - The graph represented as an adjacency list.
    ///
    /// # Returns
    ///
    /// A new instance of `EulerianPathSolver`.
    pub fn new(adjacency_list: Vec<Vec<usize>>) -> Self {
        Self {
            node_count: adjacency_list.len(),
            edge_count: 0,
            in_degrees: vec![0; adjacency_list.len()],
            out_degrees: vec![0; adjacency_list.len()],
            eulerian_path: Vec::new(),
            adjacency_list,
        }
    }

    /// Find the Eulerian path if it exists.
    ///
    /// # Returns
    ///
    /// An `Option<Vec<usize>>` containing the Eulerian path if found; otherwise, `None`.
    ///
    /// If multiple Eulerian paths exist, the one found will be returned, but it may not be unique.
    fn find_path(&mut self) -> Option<Vec<usize>> {
        self.initialize_degrees();

        if !self.has_eulerian_path() {
            return None;
        }

        let start_node = self.get_start_node();
        self.depth_first_search(start_node);

        // A walk that did not use every edge means part of the graph is unreachable
        // from the start node.
        if self.eulerian_path.len() != self.edge_count + 1 {
            return None;
        }

        // Nodes were recorded as the search finished them, so the path is stored reversed.
        let mut path = std::mem::take(&mut self.eulerian_path);
        path.reverse();
        Some(path)
    }

    /// Initializes in-degrees and out-degrees for each node and counts total edges.
    fn initialize_degrees(&mut self) {
        for (start_node, neighbors) in self.adjacency_list.iter().enumerate() {
            self.out_degrees[start_node] += neighbors.len();
            self.edge_count += neighbors.len();
            for &end_node in neighbors {
                self.in_degrees[end_node] += 1;
            }
        }
    }

    /// Checks if an Eulerian path exists in the graph.
    ///
    /// Only the degree conditions are checked here; connectivity is verified afterwards
    /// in `find_path` by comparing the length of the walk with the edge count.
    ///
    /// # Returns
    ///
    /// `true` if an Eulerian path exists; otherwise, `false`.
    fn has_eulerian_path(&self) -> bool {
        if self.edge_count == 0 {
            return false;
        }

        let (mut start_nodes, mut end_nodes) = (0, 0);
        for (&in_degree, &out_degree) in self.in_degrees.iter().zip(&self.out_degrees) {
            if out_degree == in_degree + 1 {
                start_nodes += 1;
            } else if in_degree == out_degree + 1 {
                end_nodes += 1;
            } else if in_degree != out_degree {
                // The degrees differ by more than one.
                return false;
            }
        }

        (start_nodes == 0 && end_nodes == 0) || (start_nodes == 1 && end_nodes == 1)
    }

    /// Finds the starting node for the Eulerian path.
    ///
    /// Prefers the first node with more outgoing than incoming edges, then the first node
    /// with any outgoing edge, and falls back to node `0`.
    ///
    /// # Returns
    ///
    /// The index of the starting node.
    fn get_start_node(&self) -> usize {
        (0..self.node_count)
            .find(|&i| self.out_degrees[i] > self.in_degrees[i])
            .or_else(|| (0..self.node_count).find(|&i| self.out_degrees[i] > 0))
            .unwrap_or(0)
    }

    /// Performs depth-first search to construct the Eulerian path.
    ///
    /// The traversal uses an explicit stack instead of recursion, so that long paths
    /// cannot overflow the call stack. Each node's outgoing edges are consumed from the
    /// last to the first, and a node is recorded once it has no unused edge left.
    ///
    /// # Arguments
    ///
    /// * `curr_node` - The node the traversal starts from.
    fn depth_first_search(&mut self, curr_node: usize) {
        let mut stack = vec![curr_node];
        while let Some(&node) = stack.last() {
            if self.out_degrees[node] > 0 {
                self.out_degrees[node] -= 1;
                stack.push(self.adjacency_list[node][self.out_degrees[node]]);
            } else {
                self.eulerian_path.push(node);
                stack.pop();
            }
        }
    }
}
{ + test_eulerian_cycle: ( + 7, + vec![ + (1, 2), + (1, 3), + (2, 2), + (2, 4), + (2, 4), + (3, 1), + (3, 2), + (3, 5), + (4, 3), + (4, 6), + (5, 6), + (6, 3) + ], + Some(vec![1, 3, 5, 6, 3, 2, 4, 3, 1, 2, 2, 4, 6]) + ), + test_simple_path: ( + 5, + vec![ + (0, 1), + (1, 2), + (1, 4), + (1, 3), + (2, 1), + (4, 1) + ], + Some(vec![0, 1, 4, 1, 2, 1, 3]) + ), + test_disconnected_graph: ( + 4, + vec![ + (0, 1), + (2, 3) + ], + None::> + ), + test_single_cycle: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 0) + ], + Some(vec![0, 1, 2, 3, 0]) + ), + test_empty_graph: ( + 3, + vec![], + None::> + ), + test_unbalanced_path: ( + 3, + vec![ + (0, 1), + (1, 2), + (2, 0), + (0, 2) + ], + Some(vec![0, 2, 0, 1, 2]) + ), + test_no_eulerian_path: ( + 3, + vec![ + (0, 1), + (0, 2) + ], + None::> + ), + test_complex_eulerian_path: ( + 6, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 0), + (0, 5), + (5, 0), + (2, 0) + ], + Some(vec![2, 0, 5, 0, 1, 2, 3, 4, 0]) + ), + test_single_node_self_loop: ( + 1, + vec![(0, 0)], + Some(vec![0, 0]) + ), + test_complete_graph: ( + 3, + vec![ + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1) + ], + Some(vec![0, 2, 1, 2, 0, 1, 0]) + ), + test_multiple_disconnected_components: ( + 6, + vec![ + (0, 1), + (2, 3), + (4, 5) + ], + None::> + ), + test_unbalanced_graph_with_path: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 1) + ], + Some(vec![0, 1, 2, 3, 1]) + ), + test_node_with_no_edges: ( + 4, + vec![ + (0, 1), + (1, 2) + ], + Some(vec![0, 1, 2]) + ), + test_multiple_edges_between_same_nodes: ( + 3, + vec![ + (0, 1), + (1, 2), + (1, 2), + (2, 0) + ], + Some(vec![1, 2, 0, 1, 2]) + ), + test_larger_graph_with_eulerian_path: ( + 10, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 6), + (6, 7), + (7, 8), + (8, 9), + (9, 0), + (1, 6), + (6, 3), + (3, 8) + ], + Some(vec![1, 6, 3, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8]) + ), + test_no_edges_multiple_nodes: ( + 5, + vec![], + None::> + ), + 
test_multiple_start_and_end_nodes: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 0), + (0, 2), + (1, 3) + ], + None::> + ), + test_single_edge: ( + 2, + vec![(0, 1)], + Some(vec![0, 1]) + ), + test_multiple_eulerian_paths: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 0), + (0, 3), + (3, 0) + ], + Some(vec![0, 3, 0, 1, 2, 0]) + ), + test_dag_path: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 3) + ], + Some(vec![0, 1, 2, 3]) + ), + test_parallel_edges_case: ( + 2, + vec![ + (0, 1), + (0, 1), + (1, 0) + ], + Some(vec![0, 1, 0, 1]) + ), + } +} diff --git a/src/graph/floyd_warshall.rs b/src/graph/floyd_warshall.rs new file mode 100644 index 00000000000..0a78b992e5d --- /dev/null +++ b/src/graph/floyd_warshall.rs @@ -0,0 +1,191 @@ +use num_traits::Zero; +use std::collections::BTreeMap; +use std::ops::Add; + +type Graph = BTreeMap>; + +/// Performs the Floyd-Warshall algorithm on the input graph.\ +/// The graph is a weighted, directed graph with no negative cycles. +/// +/// Returns a map storing the distance from each node to all the others.\ +/// i.e. For each vertex `u`, `map[u][v] == Some(distance)` means +/// distance is the sum of the weights of the edges on the shortest path +/// from `u` to `v`. 
+/// +/// For a key `v`, if `map[v].len() == 0`, then `v` cannot reach any other vertex, but is in the graph +/// (island node, or sink in the case of a directed graph) +pub fn floyd_warshall + num_traits::Zero>( + graph: &Graph, +) -> BTreeMap> { + let mut map: BTreeMap> = BTreeMap::new(); + for (u, edges) in graph.iter() { + if !map.contains_key(u) { + map.insert(*u, BTreeMap::new()); + } + map.entry(*u).or_default().insert(*u, Zero::zero()); + for (v, weight) in edges.iter() { + if !map.contains_key(v) { + map.insert(*v, BTreeMap::new()); + } + map.entry(*v).or_default().insert(*v, Zero::zero()); + map.entry(*u).and_modify(|mp| { + mp.insert(*v, *weight); + }); + } + } + let keys = map.keys().copied().collect::>(); + for &k in &keys { + for &i in &keys { + if !map[&i].contains_key(&k) { + continue; + } + for &j in &keys { + if i == j { + continue; + } + if !map[&k].contains_key(&j) { + continue; + } + let entry_i_j = map[&i].get(&j); + let entry_i_k = map[&i][&k]; + let entry_k_j = map[&k][&j]; + match entry_i_j { + Some(&e) => { + if e > entry_i_k + entry_k_j { + map.entry(i).or_default().insert(j, entry_i_k + entry_k_j); + } + } + None => { + map.entry(i).or_default().insert(j, entry_i_k + entry_k_j); + } + }; + } + } + } + map +} + +#[cfg(test)] +mod tests { + use super::{floyd_warshall, Graph}; + use std::collections::BTreeMap; + + fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + graph.entry(v1).or_default().insert(v2, c); + } + + fn bi_add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + add_edge(graph, v1, v2, c); + add_edge(graph, v2, v1, c); + } + + #[test] + fn single_vertex() { + let mut graph: Graph = BTreeMap::new(); + graph.insert(0, BTreeMap::new()); + + let mut dists = BTreeMap::new(); + dists.insert(0, BTreeMap::new()); + dists.get_mut(&0).unwrap().insert(0, 0); + assert_eq!(floyd_warshall(&graph), dists); + } + + #[test] + fn single_edge() { + let mut graph = BTreeMap::new(); + bi_add_edge(&mut graph, 0, 1, 2); + bi_add_edge(&mut graph, 1, 
2, 3); + + let mut dists_0 = BTreeMap::new(); + dists_0.insert(0, BTreeMap::new()); + dists_0.insert(1, BTreeMap::new()); + dists_0.insert(2, BTreeMap::new()); + dists_0.get_mut(&0).unwrap().insert(0, 0); + dists_0.get_mut(&1).unwrap().insert(1, 0); + dists_0.get_mut(&2).unwrap().insert(2, 0); + dists_0.get_mut(&1).unwrap().insert(0, 2); + dists_0.get_mut(&0).unwrap().insert(1, 2); + dists_0.get_mut(&1).unwrap().insert(2, 3); + dists_0.get_mut(&2).unwrap().insert(1, 3); + dists_0.get_mut(&2).unwrap().insert(0, 5); + dists_0.get_mut(&0).unwrap().insert(2, 5); + + assert_eq!(floyd_warshall(&graph), dists_0); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'c', 12); + add_edge(&mut graph, 'a', 'd', 60); + add_edge(&mut graph, 'b', 'a', 10); + add_edge(&mut graph, 'c', 'b', 20); + add_edge(&mut graph, 'c', 'd', 32); + add_edge(&mut graph, 'e', 'a', 7); + + let mut dists_a = BTreeMap::new(); + dists_a.insert('d', BTreeMap::new()); + + dists_a.entry('a').or_insert(BTreeMap::new()).insert('a', 0); + dists_a.entry('b').or_insert(BTreeMap::new()).insert('b', 0); + dists_a.entry('c').or_insert(BTreeMap::new()).insert('c', 0); + dists_a.entry('d').or_insert(BTreeMap::new()).insert('d', 0); + dists_a.entry('e').or_insert(BTreeMap::new()).insert('e', 0); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('c', 12); + dists_a + .entry('c') + .or_insert(BTreeMap::new()) + .insert('a', 30); + dists_a + .entry('c') + .or_insert(BTreeMap::new()) + .insert('b', 20); + dists_a + .entry('c') + .or_insert(BTreeMap::new()) + .insert('d', 32); + dists_a.entry('e').or_insert(BTreeMap::new()).insert('a', 7); + dists_a + .entry('b') + .or_insert(BTreeMap::new()) + .insert('a', 10); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('d', 44); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('b', 32); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('b', 32); + dists_a + .entry('b') + 
.or_insert(BTreeMap::new()) + .insert('c', 22); + + dists_a + .entry('b') + .or_insert(BTreeMap::new()) + .insert('d', 54); + dists_a + .entry('e') + .or_insert(BTreeMap::new()) + .insert('c', 19); + dists_a + .entry('e') + .or_insert(BTreeMap::new()) + .insert('d', 51); + dists_a + .entry('e') + .or_insert(BTreeMap::new()) + .insert('b', 39); + + assert_eq!(floyd_warshall(&graph), dists_a); + } +} diff --git a/src/graph/ford_fulkerson.rs b/src/graph/ford_fulkerson.rs new file mode 100644 index 00000000000..c6a2f310ebe --- /dev/null +++ b/src/graph/ford_fulkerson.rs @@ -0,0 +1,314 @@ +//! The Ford-Fulkerson algorithm is a widely used algorithm to solve the maximum flow problem in a flow network. +//! +//! The maximum flow problem involves determining the maximum amount of flow that can be sent from a source vertex to a sink vertex +//! in a directed weighted graph, subject to capacity constraints on the edges. + +use std::collections::VecDeque; + +/// Enum representing the possible errors that can occur when running the Ford-Fulkerson algorithm. +#[derive(Debug, PartialEq)] +pub enum FordFulkersonError { + EmptyGraph, + ImproperGraph, + SourceOutOfBounds, + SinkOutOfBounds, +} + +/// Performs a Breadth-First Search (BFS) on the residual graph to find an augmenting path +/// from the source vertex `source` to the sink vertex `sink`. +/// +/// # Arguments +/// +/// * `graph` - A reference to the residual graph represented as an adjacency matrix. +/// * `source` - The source vertex. +/// * `sink` - The sink vertex. +/// * `parent` - A mutable reference to the parent array used to store the augmenting path. +/// +/// # Returns +/// +/// Returns `true` if an augmenting path is found from `source` to `sink`, `false` otherwise. 
+fn bfs(graph: &[Vec], source: usize, sink: usize, parent: &mut [usize]) -> bool { + let mut visited = vec![false; graph.len()]; + visited[source] = true; + parent[source] = usize::MAX; + + let mut queue = VecDeque::new(); + queue.push_back(source); + + while let Some(current_vertex) = queue.pop_front() { + for (previous_vertex, &capacity) in graph[current_vertex].iter().enumerate() { + if !visited[previous_vertex] && capacity > 0 { + visited[previous_vertex] = true; + parent[previous_vertex] = current_vertex; + if previous_vertex == sink { + return true; + } + queue.push_back(previous_vertex); + } + } + } + + false +} + +/// Validates the input parameters for the Ford-Fulkerson algorithm. +/// +/// This function checks if the provided graph, source vertex, and sink vertex +/// meet the requirements for the Ford-Fulkerson algorithm. It ensures the graph +/// is non-empty, square (each row has the same length as the number of rows), and +/// that the source and sink vertices are within the valid range of vertex indices. +/// +/// # Arguments +/// +/// * `graph` - A reference to the flow network represented as an adjacency matrix. +/// * `source` - The source vertex. +/// * `sink` - The sink vertex. +/// +/// # Returns +/// +/// Returns `Ok(())` if the input parameters are valid, otherwise returns an appropriate +/// `FordFulkersonError`. +fn validate_ford_fulkerson_input( + graph: &[Vec], + source: usize, + sink: usize, +) -> Result<(), FordFulkersonError> { + if graph.is_empty() { + return Err(FordFulkersonError::EmptyGraph); + } + + if graph.iter().any(|row| row.len() != graph.len()) { + return Err(FordFulkersonError::ImproperGraph); + } + + if source >= graph.len() { + return Err(FordFulkersonError::SourceOutOfBounds); + } + + if sink >= graph.len() { + return Err(FordFulkersonError::SinkOutOfBounds); + } + + Ok(()) +} + +/// Applies the Ford-Fulkerson algorithm to find the maximum flow in a flow network +/// represented by a weighted directed graph. 
+/// +/// # Arguments +/// +/// * `graph` - A mutable reference to the flow network represented as an adjacency matrix. +/// * `source` - The source vertex. +/// * `sink` - The sink vertex. +/// +/// # Returns +/// +/// Returns the maximum flow and the residual graph +pub fn ford_fulkerson( + graph: &[Vec], + source: usize, + sink: usize, +) -> Result { + validate_ford_fulkerson_input(graph, source, sink)?; + + let mut residual_graph = graph.to_owned(); + let mut parent = vec![usize::MAX; graph.len()]; + let mut max_flow = 0; + + while bfs(&residual_graph, source, sink, &mut parent) { + let mut path_flow = usize::MAX; + let mut previous_vertex = sink; + + while previous_vertex != source { + let current_vertex = parent[previous_vertex]; + path_flow = path_flow.min(residual_graph[current_vertex][previous_vertex]); + previous_vertex = current_vertex; + } + + previous_vertex = sink; + while previous_vertex != source { + let current_vertex = parent[previous_vertex]; + residual_graph[current_vertex][previous_vertex] -= path_flow; + residual_graph[previous_vertex][current_vertex] += path_flow; + previous_vertex = current_vertex; + } + + max_flow += path_flow; + } + + Ok(max_flow) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_max_flow { + ($($name:ident: $tc:expr,)* ) => { + $( + #[test] + fn $name() { + let (graph, source, sink, expected_result) = $tc; + assert_eq!(ford_fulkerson(&graph, source, sink), expected_result); + } + )* + }; + } + + test_max_flow! 
{ + test_empty_graph: ( + vec![], + 0, + 0, + Err(FordFulkersonError::EmptyGraph), + ), + test_source_out_of_bound: ( + vec![ + vec![0, 8, 0, 0, 3, 0], + vec![0, 0, 9, 0, 0, 0], + vec![0, 0, 0, 0, 7, 2], + vec![0, 0, 0, 0, 0, 5], + vec![0, 0, 7, 4, 0, 0], + vec![0, 0, 0, 0, 0, 0], + ], + 6, + 5, + Err(FordFulkersonError::SourceOutOfBounds), + ), + test_sink_out_of_bound: ( + vec![ + vec![0, 8, 0, 0, 3, 0], + vec![0, 0, 9, 0, 0, 0], + vec![0, 0, 0, 0, 7, 2], + vec![0, 0, 0, 0, 0, 5], + vec![0, 0, 7, 4, 0, 0], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 6, + Err(FordFulkersonError::SinkOutOfBounds), + ), + test_improper_graph: ( + vec![ + vec![0, 8], + vec![0], + ], + 0, + 1, + Err(FordFulkersonError::ImproperGraph), + ), + test_graph_with_small_flow: ( + vec![ + vec![0, 8, 0, 0, 3, 0], + vec![0, 0, 9, 0, 0, 0], + vec![0, 0, 0, 0, 7, 2], + vec![0, 0, 0, 0, 0, 5], + vec![0, 0, 7, 4, 0, 0], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(6), + ), + test_graph_with_medium_flow: ( + vec![ + vec![0, 10, 0, 10, 0, 0], + vec![0, 0, 4, 2, 8, 0], + vec![0, 0, 0, 0, 0, 10], + vec![0, 0, 0, 0, 9, 0], + vec![0, 0, 6, 0, 0, 10], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(19), + ), + test_graph_with_large_flow: ( + vec![ + vec![0, 12, 0, 13, 0, 0], + vec![0, 0, 10, 0, 0, 0], + vec![0, 0, 0, 13, 3, 15], + vec![0, 0, 7, 0, 15, 0], + vec![0, 0, 6, 0, 0, 17], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(23), + ), + test_complex_graph: ( + vec![ + vec![0, 16, 13, 0, 0, 0], + vec![0, 0, 10, 12, 0, 0], + vec![0, 4, 0, 0, 14, 0], + vec![0, 0, 9, 0, 0, 20], + vec![0, 0, 0, 7, 0, 4], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(23), + ), + test_disconnected_graph: ( + vec![ + vec![0, 0, 0, 0], + vec![0, 0, 0, 1], + vec![0, 0, 0, 1], + vec![0, 0, 0, 0], + ], + 0, + 3, + Ok(0), + ), + test_unconnected_sink: ( + vec![ + vec![0, 4, 0, 3, 0, 0], + vec![0, 0, 4, 0, 8, 0], + vec![0, 0, 0, 3, 0, 2], + vec![0, 0, 0, 0, 6, 0], + vec![0, 0, 6, 0, 0, 6], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(7), + 
), + test_no_edges: ( + vec![ + vec![0, 0, 0], + vec![0, 0, 0], + vec![0, 0, 0], + ], + 0, + 2, + Ok(0), + ), + test_single_vertex: ( + vec![ + vec![0], + ], + 0, + 0, + Ok(0), + ), + test_self_loop: ( + vec![ + vec![10, 0], + vec![0, 0], + ], + 0, + 1, + Ok(0), + ), + test_same_source_sink: ( + vec![ + vec![0, 10, 10], + vec![0, 0, 10], + vec![0, 0, 0], + ], + 0, + 0, + Ok(0), + ), + } +} diff --git a/src/graph/graph_enumeration.rs b/src/graph/graph_enumeration.rs new file mode 100644 index 00000000000..24326c84aa7 --- /dev/null +++ b/src/graph/graph_enumeration.rs @@ -0,0 +1,61 @@ +use std::collections::BTreeMap; + +type Graph = BTreeMap>; + +/* +This function creates a graph with vertices numbered from 1 to n for any input +`Graph`. The result is in the form of Vec to make implementing +other algorithms on the graph easier and help with performance. + +We expect that all vertices, even the isolated ones, to have an entry in `adj` +(possibly an empty vector) +*/ +pub fn enumerate_graph(adj: &Graph) -> Vec> { + let mut result = vec![vec![]; adj.len() + 1]; + let ordering: Vec = adj.keys().cloned().collect(); + for (zero_idx, edges) in adj.values().enumerate() { + let idx = zero_idx + 1; + result[idx] = edges + .iter() + .map(|x| ordering.binary_search(x).unwrap() + 1) + .collect(); + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + fn add_edge(graph: &mut Graph, a: V, b: V) { + graph.entry(a.clone()).or_default().push(b.clone()); + graph.entry(b).or_default().push(a); + } + + #[test] + fn string_vertices() { + let mut graph = Graph::new(); + add_edge(&mut graph, "a", "b"); + add_edge(&mut graph, "b", "c"); + add_edge(&mut graph, "c", "a"); + add_edge(&mut graph, "b", "d"); + let mut result = enumerate_graph(&graph); + let expected = vec![vec![], vec![2, 3], vec![1, 3, 4], vec![1, 2], vec![2]]; + + result.iter_mut().for_each(|v| v.sort_unstable()); + assert_eq!(result, expected); + } + + #[test] + fn integer_vertices() { + let mut graph = Graph::new(); 
+ add_edge(&mut graph, 1001, 1002); + add_edge(&mut graph, 1002, 1003); + add_edge(&mut graph, 1003, 1001); + add_edge(&mut graph, 1004, 1002); + let mut result = enumerate_graph(&graph); + let expected = vec![vec![], vec![2, 3], vec![1, 3, 4], vec![1, 2], vec![2]]; + + result.iter_mut().for_each(|v| v.sort_unstable()); + assert_eq!(result, expected); + } +} diff --git a/src/graph/heavy_light_decomposition.rs b/src/graph/heavy_light_decomposition.rs new file mode 100644 index 00000000000..e96c8152a54 --- /dev/null +++ b/src/graph/heavy_light_decomposition.rs @@ -0,0 +1,192 @@ +/* +Heavy Light Decomposition: +It partitions a tree into disjoint paths such that: +1. Each path is a part of some leaf's path to root +2. The number of paths from any vertex to the root is of O(lg(n)) +Such a decomposition can be used to answer many types of queries about vertices +or edges on a particular path. It is often used with some sort of binary tree +to handle different operations on the paths, for example segment tree or +fenwick tree. + +Many members of this struct are made public, because they can either be +supplied by the developer, or can be useful for other parts of the code. + +The implementation assumes that the tree vertices are numbered from 1 to n +and it is represented using (compressed) adjacency matrix. If this is not true, +maybe `graph_enumeration.rs` can help. +*/ + +type Adj = [Vec]; + +pub struct HeavyLightDecomposition { + // Each vertex is assigned a number from 1 to n. For `v` and `u` such that + // u is parent of v, and both are in path `p`, it is true that: + // position[u] = position[v] - 1 + pub position: Vec, + + // The first (closest to root) vertex of the path containing each vertex + pub head: Vec, + + // The "heaviest" child of each vertex, its subtree is at least as big as + // the other ones. 
If `v` is a leaf, big_child[v] = 0 + pub big_child: Vec, + + // Used internally to fill `position` Vec + current_position: usize, +} + +impl HeavyLightDecomposition { + pub fn new(mut num_vertices: usize) -> Self { + num_vertices += 1; + HeavyLightDecomposition { + position: vec![0; num_vertices], + head: vec![0; num_vertices], + big_child: vec![0; num_vertices], + current_position: 1, + } + } + fn dfs(&mut self, v: usize, parent: usize, adj: &Adj) -> usize { + let mut big_child = 0usize; + let mut bc_size = 0usize; // big child size + let mut subtree_size = 1usize; // size of this subtree + for &u in adj[v].iter() { + if u == parent { + continue; + } + let u_size = self.dfs(u, v, adj); + subtree_size += u_size; + if u_size > bc_size { + big_child = u; + bc_size = u_size; + } + } + self.big_child[v] = big_child; + subtree_size + } + pub fn decompose(&mut self, root: usize, adj: &Adj) { + self.current_position = 1; + self.dfs(root, 0, adj); + self.decompose_path(root, 0, root, adj); + } + fn decompose_path(&mut self, v: usize, parent: usize, head: usize, adj: &Adj) { + self.head[v] = head; + self.position[v] = self.current_position; + self.current_position += 1; + let bc = self.big_child[v]; + if bc != 0 { + // Continue this path + self.decompose_path(bc, v, head, adj); + } + for &u in adj[v].iter() { + if u == parent || u == bc { + continue; + } + // Start a new path + self.decompose_path(u, v, u, adj); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct LinearCongruenceGenerator { + // modulus as 2 ^ 32 + multiplier: u32, + increment: u32, + state: u32, + } + + impl LinearCongruenceGenerator { + fn new(multiplier: u32, increment: u32, state: u32) -> Self { + Self { + multiplier, + increment, + state, + } + } + fn next(&mut self) -> u32 { + self.state = + (self.multiplier as u64 * self.state as u64 + self.increment as u64) as u32; + self.state + } + } + + fn get_num_paths( + hld: &HeavyLightDecomposition, + mut v: usize, + parent: &[usize], + ) -> 
(usize, usize) { + // Return height and number of paths + let mut ans = 0usize; + let mut height = 0usize; + let mut prev_head = 0usize; + loop { + height += 1; + let head = hld.head[v]; + if head != prev_head { + ans += 1; + prev_head = head; + } + v = parent[v]; + if v == 0 { + break; + } + } + (ans, height) + } + + #[test] + fn single_path() { + let mut adj = vec![vec![], vec![2], vec![3], vec![4], vec![5], vec![6], vec![]]; + let mut hld = HeavyLightDecomposition::new(6); + hld.decompose(1, &adj); + assert_eq!(hld.head, vec![0, 1, 1, 1, 1, 1, 1]); + assert_eq!(hld.position, vec![0, 1, 2, 3, 4, 5, 6]); + assert_eq!(hld.big_child, vec![0, 2, 3, 4, 5, 6, 0]); + + adj[3].push(2); + adj[2].push(1); + hld.decompose(3, &adj); + assert_eq!(hld.head, vec![0, 2, 2, 3, 3, 3, 3]); + assert_eq!(hld.position, vec![0, 6, 5, 1, 2, 3, 4]); + assert_eq!(hld.big_child, vec![0, 0, 1, 4, 5, 6, 0]); + } + + #[test] + fn random_tree() { + // Let it have 1e4 vertices. It should finish under 100ms even with + // 1e5 vertices + let n = 1e4 as usize; + let threshold = 14; // 2 ^ 14 = 16384 > n + let mut adj: Vec> = vec![vec![]; n + 1]; + let mut parent: Vec = vec![0; n + 1]; + let mut hld = HeavyLightDecomposition::new(n); + let mut lcg = LinearCongruenceGenerator::new(1103515245, 12345, 314); + parent[2] = 1; + adj[1].push(2); + #[allow(clippy::needless_range_loop)] + for i in 3..=n { + // randomly determine the parent of each vertex. 
+ // There will be modulus bias, but it isn't important + let par_max = i - 1; + let par_min = (10 * par_max + 1) / 11; + // Bring par_min closer to par_max to increase expected tree height + let par = (lcg.next() as usize % (par_max - par_min + 1)) + par_min; + adj[par].push(i); + parent[i] = par; + } + // let's get a few leaves + let leaves: Vec = (1..=n) + .rev() + .filter(|&v| adj[v].is_empty()) + .take(100) + .collect(); + hld.decompose(1, &adj); + for l in leaves { + let (p, _h) = get_num_paths(&hld, l, &parent); + assert!(p <= threshold); + } + } +} diff --git a/src/graph/kosaraju.rs b/src/graph/kosaraju.rs new file mode 100644 index 00000000000..842ddd5ffc1 --- /dev/null +++ b/src/graph/kosaraju.rs @@ -0,0 +1,157 @@ +// Kosaraju algorithm, a linear-time algorithm to find the strongly connected components (SCCs) of a directed graph, in Rust. +pub struct Graph { + vertices: usize, + adj_list: Vec>, + transpose_adj_list: Vec>, +} + +impl Graph { + pub fn new(vertices: usize) -> Self { + Graph { + vertices, + adj_list: vec![vec![]; vertices], + transpose_adj_list: vec![vec![]; vertices], + } + } + + pub fn add_edge(&mut self, u: usize, v: usize) { + self.adj_list[u].push(v); + self.transpose_adj_list[v].push(u); + } + + pub fn dfs(&self, node: usize, visited: &mut Vec, stack: &mut Vec) { + visited[node] = true; + for &neighbor in &self.adj_list[node] { + if !visited[neighbor] { + self.dfs(neighbor, visited, stack); + } + } + stack.push(node); + } + + pub fn dfs_scc(&self, node: usize, visited: &mut Vec, scc: &mut Vec) { + visited[node] = true; + scc.push(node); + for &neighbor in &self.transpose_adj_list[node] { + if !visited[neighbor] { + self.dfs_scc(neighbor, visited, scc); + } + } + } +} + +pub fn kosaraju(graph: &Graph) -> Vec> { + let mut visited = vec![false; graph.vertices]; + let mut stack = Vec::new(); + + for i in 0..graph.vertices { + if !visited[i] { + graph.dfs(i, &mut visited, &mut stack); + } + } + + let mut sccs = Vec::new(); + visited = 
vec![false; graph.vertices]; + + while let Some(node) = stack.pop() { + if !visited[node] { + let mut scc = Vec::new(); + graph.dfs_scc(node, &mut visited, &mut scc); + sccs.push(scc); + } + } + + sccs +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kosaraju_single_sccs() { + let vertices = 5; + let mut graph = Graph::new(vertices); + + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 3); + graph.add_edge(2, 4); + graph.add_edge(3, 0); + graph.add_edge(4, 2); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 1); + assert!(sccs.contains(&vec![0, 3, 2, 1, 4])); + } + + #[test] + fn test_kosaraju_multiple_sccs() { + let vertices = 8; + let mut graph = Graph::new(vertices); + + graph.add_edge(1, 0); + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 0); + graph.add_edge(2, 3); + graph.add_edge(3, 4); + graph.add_edge(4, 5); + graph.add_edge(5, 6); + graph.add_edge(6, 7); + graph.add_edge(4, 7); + graph.add_edge(6, 4); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 4); + assert!(sccs.contains(&vec![0, 1, 2])); + assert!(sccs.contains(&vec![3])); + assert!(sccs.contains(&vec![4, 6, 5])); + assert!(sccs.contains(&vec![7])); + } + + #[test] + fn test_kosaraju_multiple_sccs1() { + let vertices = 8; + let mut graph = Graph::new(vertices); + graph.add_edge(0, 2); + graph.add_edge(1, 0); + graph.add_edge(2, 3); + graph.add_edge(3, 4); + graph.add_edge(4, 7); + graph.add_edge(5, 2); + graph.add_edge(5, 6); + graph.add_edge(6, 5); + graph.add_edge(7, 6); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 3); + assert!(sccs.contains(&vec![0])); + assert!(sccs.contains(&vec![1])); + assert!(sccs.contains(&vec![2, 5, 6, 7, 4, 3])); + } + + #[test] + fn test_kosaraju_no_scc() { + let vertices = 4; + let mut graph = Graph::new(vertices); + + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 3); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 4); + for (i, _) in 
sccs.iter().enumerate().take(vertices) { + assert_eq!(sccs[i], vec![i]); + } + } + + #[test] + fn test_kosaraju_empty_graph() { + let vertices = 0; + let graph = Graph::new(vertices); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 0); + } +} diff --git a/src/graph/lee_breadth_first_search.rs b/src/graph/lee_breadth_first_search.rs new file mode 100644 index 00000000000..e0c25fd7906 --- /dev/null +++ b/src/graph/lee_breadth_first_search.rs @@ -0,0 +1,117 @@ +use std::collections::VecDeque; + +// All four potential movements from a cell are listed here. + +fn validate(matrix: &[Vec], visited: &[Vec], row: isize, col: isize) -> bool { + // Check if it is possible to move to the position (row, col) from the current cell. + let (row, col) = (row as usize, col as usize); + row < matrix.len() && col < matrix[0].len() && matrix[row][col] == 1 && !visited[row][col] +} + +pub fn lee(matrix: Vec>, source: (usize, usize), destination: (usize, usize)) -> isize { + const ROW: [isize; 4] = [-1, 0, 0, 1]; + const COL: [isize; 4] = [0, -1, 1, 0]; + let (i, j) = source; + let (x, y) = destination; + + // Base case: invalid input + if matrix.is_empty() || matrix[i][j] == 0 || matrix[x][y] == 0 { + return -1; + } + + let (m, n) = (matrix.len(), matrix[0].len()); + let mut visited = vec![vec![false; n]; m]; + let mut q = VecDeque::new(); + visited[i][j] = true; + q.push_back((i, j, 0)); + let mut min_dist = isize::MAX; + + // Loop until the queue is empty + while let Some((i, j, dist)) = q.pop_front() { + if i == x && j == y { + // If the destination is found, update `min_dist` and stop + min_dist = dist; + break; + } + + // Check for all four possible movements from the current cell + for k in 0..ROW.len() { + let row = i as isize + ROW[k]; + let col = j as isize + COL[k]; + if validate(&matrix, &visited, row, col) { + // Mark the next cell as visited and enqueue it + let (row, col) = (row as usize, col as usize); + visited[row][col] = true; + q.push_back((row, col, dist + 
1)); + } + } + } + + if min_dist != isize::MAX { + min_dist + } else { + -1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_lee_exists() { + let mat: Vec> = vec![ + vec![1, 0, 1, 1, 1], + vec![1, 0, 1, 0, 1], + vec![1, 1, 1, 0, 1], + vec![0, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + ]; + let source = (0, 0); + let dest = (2, 1); + assert_eq!(lee(mat, source, dest), 3); + } + + #[test] + fn test_lee_does_not_exist() { + let mat: Vec> = vec![ + vec![1, 0, 1, 1, 1], + vec![1, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + vec![0, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + ]; + let source = (0, 0); + let dest = (3, 4); + assert_eq!(lee(mat, source, dest), -1); + } + + #[test] + fn test_source_equals_destination() { + let mat: Vec> = vec![ + vec![1, 0, 1, 1, 1], + vec![1, 0, 1, 0, 1], + vec![1, 1, 1, 0, 1], + vec![0, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + ]; + let source = (2, 1); + let dest = (2, 1); + assert_eq!(lee(mat, source, dest), 0); + } + + #[test] + fn test_lee_exists_2() { + let mat: Vec> = vec![ + vec![1, 1, 1, 1, 1, 0, 0], + vec![1, 1, 1, 1, 1, 1, 0], + vec![1, 0, 1, 0, 1, 1, 1], + vec![1, 1, 1, 1, 1, 0, 1], + vec![0, 0, 0, 1, 0, 0, 0], + vec![1, 0, 1, 1, 1, 0, 0], + vec![0, 0, 0, 0, 1, 0, 0], + ]; + let source = (0, 0); + let dest = (3, 2); + assert_eq!(lee(mat, source, dest), 5); + } +} diff --git a/src/graph/lowest_common_ancestor.rs b/src/graph/lowest_common_ancestor.rs new file mode 100644 index 00000000000..2485f9cb216 --- /dev/null +++ b/src/graph/lowest_common_ancestor.rs @@ -0,0 +1,211 @@ +/* + Note: We will assume that here tree vertices are numbered from 1 to n. +If a tree is not enumerated that way or its vertices are not represented +using numbers, it can trivially be converted using Depth First Search +manually or by using `src/graph/graph_enumeration.rs` + + Here we implement two different algorithms: +- The online one is implemented using Sparse Table and has O(n.lg(n)) +time complexity and memory usage. 
It answers each query in O(lg(n)). +- The offline algorithm was discovered by Robert Tarjan. At first each +query should be determined and saved. Then, vertices are visited in +Depth First Search order and queries are answered using Disjoint +Set Union algorithm. The time complexity is O(n.alpha(n) + q) and +memory usage is O(n + q), but time complexity can be considered to be O(n + q), +because alpha(n) < 5 for n < 10 ^ 600 + */ + +use super::DisjointSetUnion; +pub struct LowestCommonAncestorOnline { + // Make members public to allow the user to fill them themself. + pub parents_sparse_table: Vec>, + pub height: Vec, +} + +impl LowestCommonAncestorOnline { + // Should be called once as: + // fill_sparse_table(tree_root, 0, 0, adjacency_list) + #[inline] + fn get_parent(&self, v: usize, i: usize) -> usize { + self.parents_sparse_table[v][i] + } + #[inline] + fn num_parents(&self, v: usize) -> usize { + self.parents_sparse_table[v].len() + } + pub fn new(num_vertices: usize) -> Self { + let mut pars = vec![vec![0]; num_vertices + 1]; + pars[0].clear(); + LowestCommonAncestorOnline { + parents_sparse_table: pars, + height: vec![0; num_vertices + 1], + } + } + pub fn fill_sparse_table( + &mut self, + vertex: usize, + parent: usize, + height: usize, + adj: &[Vec], + ) { + self.parents_sparse_table[vertex][0] = parent; + self.height[vertex] = height; + let mut level = 1; + let mut current_parent = parent; + while self.num_parents(current_parent) >= level { + current_parent = self.get_parent(current_parent, level - 1); + level += 1; + self.parents_sparse_table[vertex].push(current_parent); + } + for &child in adj[vertex].iter() { + if child == parent { + // It isn't a child! 
+ continue; + } + self.fill_sparse_table(child, vertex, height + 1, adj); + } + } + + pub fn get_ancestor(&self, mut v: usize, mut u: usize) -> usize { + if self.height[v] < self.height[u] { + std::mem::swap(&mut v, &mut u); + } + // Bring v up to so that it has the same height as u + let height_diff = self.height[v] - self.height[u]; + for i in 0..63 { + let bit = 1 << i; + if bit > height_diff { + break; + } + if height_diff & bit != 0 { + v = self.get_parent(v, i); + } + } + if u == v { + return u; + } + // `self.num_parents` of u and v should be equal + for i in (0..self.num_parents(v)).rev() { + let nv = self.get_parent(v, i); + let nu = self.get_parent(u, i); + if nv != nu { + u = nu; + v = nv; + } + } + self.get_parent(v, 0) + } +} + +#[derive(Clone, Copy)] +pub struct LCAQuery { + other: usize, + query_id: usize, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct QueryAnswer { + query_id: usize, + answer: usize, +} + +pub struct LowestCommonAncestorOffline { + pub queries: Vec>, + dsu: DisjointSetUnion, + /* + The LSB of dsu_parent[v] determines whether it was visited or not. + The rest of the number determines the vertex that represents a + particular set in DSU. + */ + dsu_parent: Vec, +} + +impl LowestCommonAncestorOffline { + pub fn new(num_vertices: usize) -> Self { + LowestCommonAncestorOffline { + queries: vec![vec![]; num_vertices + 1], + dsu: DisjointSetUnion::new(num_vertices), + dsu_parent: vec![0; num_vertices + 1], + } + } + pub fn add_query(&mut self, u: usize, v: usize, query_id: usize) { + // We should add this query to both vertices, and it will be answered + // the second time it is seen in DFS. 
+ self.queries[u].push(LCAQuery { other: v, query_id }); + if u == v { + return; + } + self.queries[v].push(LCAQuery { other: u, query_id }); + } + + fn calculate_answers( + &mut self, + vertex: usize, + parent: usize, + adj: &[Vec], + answers: &mut Vec, + ) { + self.dsu_parent[vertex] = (vertex as u64) << 1; + for &child in adj[vertex].iter() { + if child == parent { + continue; + } + self.calculate_answers(child, vertex, adj, answers); + self.dsu.merge(child, vertex); + let set = self.dsu.find_set(vertex); + self.dsu_parent[set] = ((vertex as u64) << 1) | (self.dsu_parent[set] & 1); + } + self.dsu_parent[vertex] |= 0b1; + for &query in self.queries[vertex].iter() { + if self.dsu_parent[query.other] & 1 != 0 { + // It has been visited + answers.push(QueryAnswer { + query_id: query.query_id, + answer: (self.dsu_parent[self.dsu.find_set(query.other)] >> 1) as usize, + }); + } + } + } + pub fn answer_queries(&mut self, root: usize, adj: &[Vec]) -> Vec { + let mut answers = Vec::new(); + self.calculate_answers(root, 0, adj, &mut answers); + answers + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn small_binary_tree() { + let num_verts = 127; + let mut tree: Vec> = vec![vec![]; num_verts + 1]; + for i in 1..=num_verts >> 1 { + let left_child = i << 1; + let right_child = left_child + 1; + tree[i].push(left_child); + tree[i].push(right_child); + tree[left_child].push(i); + tree[right_child].push(i); + } + let mut online_answers: Vec = Vec::new(); + let mut online = LowestCommonAncestorOnline::new(num_verts); + let mut offline = LowestCommonAncestorOffline::new(num_verts); + let mut query_id = 314; // A random number, doesn't matter + online.fill_sparse_table(1, 0, 0, &tree); + for i in 1..=num_verts { + for j in 1..i { + // Query every possible pair + online_answers.push(QueryAnswer { + query_id, + answer: online.get_ancestor(i, j), + }); + offline.add_query(i, j, query_id); + query_id += 1; + } + } + let mut offline_answers = offline.answer_queries(1, 
&tree); + offline_answers.sort_unstable_by(|a1, a2| a1.query_id.cmp(&a2.query_id)); + assert_eq!(offline_answers, online_answers); + } +} diff --git a/src/graph/minimum_spanning_tree.rs b/src/graph/minimum_spanning_tree.rs new file mode 100644 index 00000000000..9d36cafb303 --- /dev/null +++ b/src/graph/minimum_spanning_tree.rs @@ -0,0 +1,159 @@ +//! This module implements Kruskal's algorithm to find the Minimum Spanning Tree (MST) +//! of an undirected, weighted graph using a Disjoint Set Union (DSU) for cycle detection. + +use crate::graph::DisjointSetUnion; + +/// Represents an edge in the graph with a source, destination, and associated cost. +#[derive(Debug, PartialEq, Eq)] +pub struct Edge { + /// The starting vertex of the edge. + source: usize, + /// The ending vertex of the edge. + destination: usize, + /// The cost associated with the edge. + cost: usize, +} + +impl Edge { + /// Creates a new edge with the specified source, destination, and cost. + pub fn new(source: usize, destination: usize, cost: usize) -> Self { + Self { + source, + destination, + cost, + } + } +} + +/// Executes Kruskal's algorithm to compute the Minimum Spanning Tree (MST) of a graph. +/// +/// # Parameters +/// +/// - `edges`: A vector of `Edge` instances representing all edges in the graph. +/// - `num_vertices`: The total number of vertices in the graph. +/// +/// # Returns +/// +/// An `Option` containing a tuple with: +/// +/// - The total cost of the MST (usize). +/// - A vector of edges that are included in the MST. +/// +/// Returns `None` if the graph is disconnected. +/// +/// # Complexity +/// +/// The time complexity is O(E log E), where E is the number of edges. 
+pub fn kruskal(mut edges: Vec, num_vertices: usize) -> Option<(usize, Vec)> { + let mut dsu = DisjointSetUnion::new(num_vertices); + let mut mst_cost: usize = 0; + let mut mst_edges: Vec = Vec::with_capacity(num_vertices - 1); + + // Sort edges by cost in ascending order + edges.sort_unstable_by_key(|edge| edge.cost); + + for edge in edges { + if mst_edges.len() == num_vertices - 1 { + break; + } + + // Attempt to merge the sets containing the edge’s vertices + if dsu.merge(edge.source, edge.destination) != usize::MAX { + mst_cost += edge.cost; + mst_edges.push(edge); + } + } + + // Return MST if it includes exactly num_vertices - 1 edges, otherwise None for disconnected graphs + (mst_edges.len() == num_vertices - 1).then_some((mst_cost, mst_edges)) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (edges, num_vertices, expected_result) = $test_case; + let actual_result = kruskal(edges, num_vertices); + assert_eq!(actual_result, expected_result); + } + )* + }; + } + + test_cases! 
{ + test_seven_vertices_eleven_edges: ( + vec![ + Edge::new(0, 1, 7), + Edge::new(0, 3, 5), + Edge::new(1, 2, 8), + Edge::new(1, 3, 9), + Edge::new(1, 4, 7), + Edge::new(2, 4, 5), + Edge::new(3, 4, 15), + Edge::new(3, 5, 6), + Edge::new(4, 5, 8), + Edge::new(4, 6, 9), + Edge::new(5, 6, 11), + ], + 7, + Some((39, vec![ + Edge::new(0, 3, 5), + Edge::new(2, 4, 5), + Edge::new(3, 5, 6), + Edge::new(0, 1, 7), + Edge::new(1, 4, 7), + Edge::new(4, 6, 9), + ])) + ), + test_ten_vertices_twenty_edges: ( + vec![ + Edge::new(0, 1, 3), + Edge::new(0, 3, 6), + Edge::new(0, 4, 9), + Edge::new(1, 2, 2), + Edge::new(1, 3, 4), + Edge::new(1, 4, 9), + Edge::new(2, 3, 2), + Edge::new(2, 5, 8), + Edge::new(2, 6, 9), + Edge::new(3, 6, 9), + Edge::new(4, 5, 8), + Edge::new(4, 9, 18), + Edge::new(5, 6, 7), + Edge::new(5, 8, 9), + Edge::new(5, 9, 10), + Edge::new(6, 7, 4), + Edge::new(6, 8, 5), + Edge::new(7, 8, 1), + Edge::new(7, 9, 4), + Edge::new(8, 9, 3), + ], + 10, + Some((38, vec![ + Edge::new(7, 8, 1), + Edge::new(1, 2, 2), + Edge::new(2, 3, 2), + Edge::new(0, 1, 3), + Edge::new(8, 9, 3), + Edge::new(6, 7, 4), + Edge::new(5, 6, 7), + Edge::new(2, 5, 8), + Edge::new(4, 5, 8), + ])) + ), + test_disconnected_graph: ( + vec![ + Edge::new(0, 1, 4), + Edge::new(0, 2, 6), + Edge::new(3, 4, 2), + ], + 5, + None + ), + } +} diff --git a/src/graph/mod.rs b/src/graph/mod.rs new file mode 100644 index 00000000000..d4b0b0d00cb --- /dev/null +++ b/src/graph/mod.rs @@ -0,0 +1,55 @@ +mod astar; +mod bellman_ford; +mod bipartite_matching; +mod breadth_first_search; +mod centroid_decomposition; +mod decremental_connectivity; +mod depth_first_search; +mod depth_first_search_tic_tac_toe; +mod detect_cycle; +mod dijkstra; +mod dinic_maxflow; +mod disjoint_set_union; +mod eulerian_path; +mod floyd_warshall; +mod ford_fulkerson; +mod graph_enumeration; +mod heavy_light_decomposition; +mod kosaraju; +mod lee_breadth_first_search; +mod lowest_common_ancestor; +mod minimum_spanning_tree; +mod prim; +mod 
prufer_code; +mod strongly_connected_components; +mod tarjans_ssc; +mod topological_sort; +mod two_satisfiability; + +pub use self::astar::astar; +pub use self::bellman_ford::bellman_ford; +pub use self::bipartite_matching::BipartiteMatching; +pub use self::breadth_first_search::breadth_first_search; +pub use self::centroid_decomposition::CentroidDecomposition; +pub use self::decremental_connectivity::DecrementalConnectivity; +pub use self::depth_first_search::depth_first_search; +pub use self::depth_first_search_tic_tac_toe::minimax; +pub use self::detect_cycle::DetectCycle; +pub use self::dijkstra::dijkstra; +pub use self::dinic_maxflow::DinicMaxFlow; +pub use self::disjoint_set_union::DisjointSetUnion; +pub use self::eulerian_path::find_eulerian_path; +pub use self::floyd_warshall::floyd_warshall; +pub use self::ford_fulkerson::ford_fulkerson; +pub use self::graph_enumeration::enumerate_graph; +pub use self::heavy_light_decomposition::HeavyLightDecomposition; +pub use self::kosaraju::kosaraju; +pub use self::lee_breadth_first_search::lee; +pub use self::lowest_common_ancestor::{LowestCommonAncestorOffline, LowestCommonAncestorOnline}; +pub use self::minimum_spanning_tree::kruskal; +pub use self::prim::{prim, prim_with_start}; +pub use self::prufer_code::{prufer_decode, prufer_encode}; +pub use self::strongly_connected_components::StronglyConnectedComponents; +pub use self::tarjans_ssc::tarjan_scc; +pub use self::topological_sort::topological_sort; +pub use self::two_satisfiability::solve_two_satisfiability; diff --git a/src/graph/prim.rs b/src/graph/prim.rs new file mode 100644 index 00000000000..2fac8883572 --- /dev/null +++ b/src/graph/prim.rs @@ -0,0 +1,199 @@ +use std::cmp::Reverse; +use std::collections::{BTreeMap, BinaryHeap}; +use std::ops::Add; + +type Graph = BTreeMap>; + +fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default().insert(v1, c); +} + +// selects a start and run the 
algorithm from it +pub fn prim( + graph: &Graph, +) -> Graph { + match graph.keys().next() { + Some(v) => prim_with_start(graph, *v), + None => BTreeMap::new(), + } +} + +// only works for a connected graph +// if the given graph is not connected it will return the MST of the connected subgraph +pub fn prim_with_start( + graph: &Graph, + start: V, +) -> Graph { + // will contain the MST + let mut mst: Graph = Graph::new(); + // a priority queue based on a binary heap, used to get the cheapest edge + // the elements are an edge: the cost, destination and source + let mut prio = BinaryHeap::new(); + + mst.insert(start, BTreeMap::new()); + + for (v, c) in &graph[&start] { + // the heap is a max heap, we have to use Reverse when adding to simulate a min heap + prio.push(Reverse((*c, v, start))); + } + + while let Some(Reverse((dist, t, prev))) = prio.pop() { + // the destination of the edge has already been seen + if mst.contains_key(t) { + continue; + } + + // the destination is a new vertex + add_edge(&mut mst, prev, *t, dist); + + for (v, c) in &graph[t] { + if !mst.contains_key(v) { + prio.push(Reverse((*c, v, *t))); + } + } + } + + mst +} + +#[cfg(test)] +mod tests { + use super::{add_edge, prim, Graph}; + use std::collections::BTreeMap; + + #[test] + fn empty() { + assert_eq!(prim::(&BTreeMap::new()), BTreeMap::new()); + } + + #[test] + fn single_vertex() { + let mut graph: Graph = BTreeMap::new(); + graph.insert(42, BTreeMap::new()); + + assert_eq!(prim(&graph), graph); + } + + #[test] + fn single_edge() { + let mut graph = BTreeMap::new(); + + add_edge(&mut graph, 42, 666, 12); + + assert_eq!(prim(&graph), graph); + } + + #[test] + fn tree_1() { + let mut graph = BTreeMap::new(); + + add_edge(&mut graph, 0, 1, 10); + add_edge(&mut graph, 0, 2, 11); + add_edge(&mut graph, 2, 3, 12); + add_edge(&mut graph, 2, 4, 13); + add_edge(&mut graph, 1, 5, 14); + add_edge(&mut graph, 1, 6, 15); + add_edge(&mut graph, 3, 7, 16); + + assert_eq!(prim(&graph), graph); + } + + 
#[test] + fn tree_2() { + let mut graph = BTreeMap::new(); + + add_edge(&mut graph, 1, 2, 11); + add_edge(&mut graph, 2, 3, 12); + add_edge(&mut graph, 2, 4, 13); + add_edge(&mut graph, 4, 5, 14); + add_edge(&mut graph, 4, 6, 15); + add_edge(&mut graph, 6, 7, 16); + + assert_eq!(prim(&graph), graph); + } + + #[test] + fn tree_3() { + let mut graph = BTreeMap::new(); + + for i in 1..100 { + add_edge(&mut graph, i, 2 * i, i); + add_edge(&mut graph, i, 2 * i + 1, -i); + } + + assert_eq!(prim(&graph), graph); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'b', 6); + add_edge(&mut graph, 'a', 'c', 7); + add_edge(&mut graph, 'a', 'e', 2); + add_edge(&mut graph, 'a', 'f', 3); + add_edge(&mut graph, 'b', 'c', 5); + add_edge(&mut graph, 'c', 'e', 5); + add_edge(&mut graph, 'd', 'e', 4); + add_edge(&mut graph, 'd', 'f', 1); + add_edge(&mut graph, 'e', 'f', 2); + + let mut ans = BTreeMap::new(); + add_edge(&mut ans, 'd', 'f', 1); + add_edge(&mut ans, 'e', 'f', 2); + add_edge(&mut ans, 'a', 'e', 2); + add_edge(&mut ans, 'b', 'c', 5); + add_edge(&mut ans, 'c', 'e', 5); + + assert_eq!(prim(&graph), ans); + } + + #[test] + fn graph_2() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 1, 2, 6); + add_edge(&mut graph, 1, 3, 1); + add_edge(&mut graph, 1, 4, 5); + add_edge(&mut graph, 2, 3, 5); + add_edge(&mut graph, 2, 5, 3); + add_edge(&mut graph, 3, 4, 5); + add_edge(&mut graph, 3, 5, 6); + add_edge(&mut graph, 3, 6, 4); + add_edge(&mut graph, 4, 6, 2); + add_edge(&mut graph, 5, 6, 6); + + let mut ans = BTreeMap::new(); + add_edge(&mut ans, 1, 3, 1); + add_edge(&mut ans, 4, 6, 2); + add_edge(&mut ans, 2, 5, 3); + add_edge(&mut ans, 2, 3, 5); + add_edge(&mut ans, 3, 6, 4); + + assert_eq!(prim(&graph), ans); + } + + #[test] + fn graph_3() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, "v1", "v2", 1); + add_edge(&mut graph, "v1", "v3", 3); + add_edge(&mut graph, "v1", "v5", 6); + add_edge(&mut graph, "v2", "v3", 
2); + add_edge(&mut graph, "v2", "v4", 3); + add_edge(&mut graph, "v2", "v5", 5); + add_edge(&mut graph, "v3", "v4", 5); + add_edge(&mut graph, "v3", "v6", 2); + add_edge(&mut graph, "v4", "v5", 2); + add_edge(&mut graph, "v4", "v6", 4); + add_edge(&mut graph, "v5", "v6", 1); + + let mut ans = BTreeMap::new(); + add_edge(&mut ans, "v1", "v2", 1); + add_edge(&mut ans, "v5", "v6", 1); + add_edge(&mut ans, "v2", "v3", 2); + add_edge(&mut ans, "v3", "v6", 2); + add_edge(&mut ans, "v4", "v5", 2); + + assert_eq!(prim(&graph), ans); + } +} diff --git a/src/graph/prufer_code.rs b/src/graph/prufer_code.rs new file mode 100644 index 00000000000..0c965b8cb50 --- /dev/null +++ b/src/graph/prufer_code.rs @@ -0,0 +1,127 @@ +use std::collections::{BTreeMap, BTreeSet, BinaryHeap}; + +type Graph = BTreeMap>; + +pub fn prufer_encode(tree: &Graph) -> Vec { + if tree.len() <= 2 { + return vec![]; + } + let mut result: Vec = Vec::with_capacity(tree.len() - 2); + let mut queue = BinaryHeap::new(); + let mut in_tree = BTreeSet::new(); + let mut degree = BTreeMap::new(); + for (vertex, adj) in tree { + in_tree.insert(*vertex); + degree.insert(*vertex, adj.len()); + if adj.len() == 1 { + queue.push(*vertex); + } + } + for _ in 2..tree.len() { + let v = queue.pop().unwrap(); + in_tree.remove(&v); + let u = tree[&v].iter().find(|u| in_tree.contains(u)).unwrap(); + result.push(*u); + *degree.get_mut(u).unwrap() -= 1; + if degree[u] == 1 { + queue.push(*u); + } + } + result +} + +#[inline] +fn add_directed_edge(tree: &mut Graph, a: V, b: V) { + tree.entry(a).or_default().push(b); +} + +#[inline] +fn add_edge(tree: &mut Graph, a: V, b: V) { + add_directed_edge(tree, a, b); + add_directed_edge(tree, b, a); +} + +pub fn prufer_decode(code: &[V], vertex_list: &[V]) -> Graph { + // For many cases, this function won't fail even if given unsuitable code + // array. As such, returning really unlikely errors doesn't make much sense. 
+ let mut result = BTreeMap::new(); + let mut list_count: BTreeMap = BTreeMap::new(); + for vertex in code { + *list_count.entry(*vertex).or_insert(0) += 1; + } + let mut queue = BinaryHeap::from( + vertex_list + .iter() + .filter(|v| !list_count.contains_key(v)) + .cloned() + .collect::>(), + ); + for vertex in code { + let child = queue.pop().unwrap(); + add_edge(&mut result, child, *vertex); + let cnt = list_count.get_mut(vertex).unwrap(); + *cnt -= 1; + if *cnt == 0 { + queue.push(*vertex); + } + } + let u = queue.pop().unwrap(); + let v = queue.pop().unwrap(); + add_edge(&mut result, u, v); + result +} + +#[cfg(test)] +mod tests { + use super::{add_edge, prufer_decode, prufer_encode, Graph}; + + fn equal_graphs(g1: &mut Graph, g2: &mut Graph) -> bool { + for adj in g1.values_mut() { + adj.sort(); + } + for adj in g2.values_mut() { + adj.sort(); + } + g1 == g2 + } + + #[test] + fn small_trees() { + let mut g: Graph = Graph::new(); + // Binary tree with 7 vertices + let edges = vec![(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]; + for (u, v) in edges { + add_edge(&mut g, u, v); + } + let code = prufer_encode(&g); + let vertices = g.keys().cloned().collect::>(); + let mut decoded = prufer_decode(&code, &vertices); + assert_eq!(code, vec![3, 3, 2, 2, 1]); + assert!(equal_graphs(&mut g, &mut decoded)); + + g.clear(); + // A path of length 10 + for v in 2..=9 { + g.insert(v, vec![v - 1, v + 1]); + } + g.insert(1, vec![2]); + g.insert(10, vec![9]); + let code = prufer_encode(&g); + let vertices = g.keys().cloned().collect::>(); + let mut decoded = prufer_decode(&code, &vertices); + assert_eq!(code, vec![9, 8, 7, 6, 5, 4, 3, 2]); + assert!(equal_graphs(&mut g, &mut decoded)); + + g.clear(); + // 7-5-3-1-2-4-6 + let edges = vec![(1, 2), (2, 4), (4, 6), (1, 3), (3, 5), (5, 7)]; + for (u, v) in edges { + add_edge(&mut g, u, v); + } + let code = prufer_encode(&g); + let vertices = g.keys().cloned().collect::>(); + let mut decoded = prufer_decode(&code, &vertices); + 
assert_eq!(code, vec![5, 4, 3, 2, 1]); + assert!(equal_graphs(&mut g, &mut decoded)); + } +} diff --git a/src/graph/strongly_connected_components.rs b/src/graph/strongly_connected_components.rs new file mode 100644 index 00000000000..b821b51d4aa --- /dev/null +++ b/src/graph/strongly_connected_components.rs @@ -0,0 +1,164 @@ +/* +Tarjan's algorithm to find Strongly Connected Components (SCCs): +It runs in O(n + m) (so it is optimal) and as a by-product, it returns the +components in some (reverse) topologically sorted order. + +We assume that graph is represented using (compressed) adjacency matrix +and its vertices are numbered from 1 to n. If this is not the case, one +can use `src/graph/graph_enumeration.rs` to convert their graph. +*/ + +pub struct StronglyConnectedComponents { + // The number of the SCC the vertex is in, starting from 1 + pub component: Vec, + + // The discover time of the vertex with minimum discover time reachable + // from this vertex. The MSB of the numbers are used to save whether the + // vertex has been visited (but the MSBs are cleared after + // the algorithm is done) + pub state: Vec, + + // The total number of SCCs + pub num_components: usize, + + // The stack of vertices that DFS has seen (used internally) + stack: Vec, + // Used internally during DFS to know the current discover time + current_time: usize, +} + +// Some functions to help with DRY and code readability +const NOT_DONE: u64 = 1 << 63; + +#[inline] +fn set_done(vertex_state: &mut u64) { + *vertex_state ^= NOT_DONE; +} + +#[inline] +fn is_in_stack(vertex_state: u64) -> bool { + vertex_state != 0 && (vertex_state & NOT_DONE) != 0 +} + +#[inline] +fn is_unvisited(vertex_state: u64) -> bool { + vertex_state == NOT_DONE +} + +#[inline] +fn get_discover_time(vertex_state: u64) -> u64 { + vertex_state ^ NOT_DONE +} + +impl StronglyConnectedComponents { + pub fn new(mut num_vertices: usize) -> Self { + num_vertices += 1; // Vertices are numbered from 1, not 0 + 
StronglyConnectedComponents { + component: vec![0; num_vertices], + state: vec![NOT_DONE; num_vertices], + num_components: 0, + stack: vec![], + current_time: 1, + } + } + fn dfs(&mut self, v: usize, adj: &[Vec]) -> u64 { + let mut min_disc = self.current_time as u64; + // self.state[v] = NOT_DONE + min_disc + self.state[v] ^= min_disc; + self.current_time += 1; + self.stack.push(v); + + for &u in adj[v].iter() { + if is_unvisited(self.state[u]) { + min_disc = std::cmp::min(self.dfs(u, adj), min_disc); + } else if is_in_stack(self.state[u]) { + min_disc = std::cmp::min(get_discover_time(self.state[u]), min_disc); + } + } + + // No vertex with a lower discovery time is reachable from this one + // So it should be "the head" of a new SCC. + if min_disc == get_discover_time(self.state[v]) { + self.num_components += 1; + loop { + let u = self.stack.pop().unwrap(); + self.component[u] = self.num_components; + set_done(&mut self.state[u]); + if u == v { + break; + } + } + } + + min_disc + } + pub fn find_components(&mut self, adj: &[Vec]) { + self.state[0] = 0; + for v in 1..adj.len() { + if is_unvisited(self.state[v]) { + self.dfs(v, adj); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn acyclic() { + let mut sccs = StronglyConnectedComponents::new(5); + let adj = vec![vec![], vec![2, 4], vec![3, 4], vec![5], vec![5], vec![]]; + sccs.find_components(&adj); + assert_eq!(sccs.component, vec![0, 5, 4, 2, 3, 1]); + assert_eq!(sccs.state, vec![0, 1, 2, 3, 5, 4]); + assert_eq!(sccs.num_components, 5); + } + + #[test] + fn cycle() { + let mut sccs = StronglyConnectedComponents::new(4); + let adj = vec![vec![], vec![2], vec![3], vec![4], vec![1]]; + sccs.find_components(&adj); + assert_eq!(sccs.component, vec![0, 1, 1, 1, 1]); + assert_eq!(sccs.state, vec![0, 1, 2, 3, 4]); + assert_eq!(sccs.num_components, 1); + } + + #[test] + fn dumbbell() { + let mut sccs = StronglyConnectedComponents::new(6); + let adj = vec![ + vec![], + vec![2], + vec![3, 4], + 
vec![1], + vec![5], + vec![6], + vec![4], + ]; + sccs.find_components(&adj); + assert_eq!(sccs.component, vec![0, 2, 2, 2, 1, 1, 1]); + assert_eq!(sccs.state, vec![0, 1, 2, 3, 4, 5, 6]); + assert_eq!(sccs.num_components, 2); + } + + #[test] + fn connected_dumbbell() { + let mut sccs = StronglyConnectedComponents::new(6); + let adj = vec![ + vec![], + vec![2], + vec![3, 4], + vec![1], + vec![5, 1], + vec![6], + vec![4], + ]; + sccs.find_components(&adj); + assert_eq!(sccs.component, vec![0, 1, 1, 1, 1, 1, 1]); + assert_eq!(sccs.state, vec![0, 1, 2, 3, 4, 5, 6]); + assert_eq!(sccs.num_components, 1); + } +} diff --git a/src/graph/tarjans_ssc.rs b/src/graph/tarjans_ssc.rs new file mode 100644 index 00000000000..1f8e258614a --- /dev/null +++ b/src/graph/tarjans_ssc.rs @@ -0,0 +1,173 @@ +pub struct Graph { + n: usize, + adj_list: Vec>, +} + +impl Graph { + pub fn new(n: usize) -> Self { + Self { + n, + adj_list: vec![vec![]; n], + } + } + + pub fn add_edge(&mut self, u: usize, v: usize) { + self.adj_list[u].push(v); + } +} +pub fn tarjan_scc(graph: &Graph) -> Vec> { + struct TarjanState { + index: i32, + stack: Vec, + on_stack: Vec, + index_of: Vec, + lowlink_of: Vec, + components: Vec>, + } + + let mut state = TarjanState { + index: 0, + stack: Vec::new(), + on_stack: vec![false; graph.n], + index_of: vec![-1; graph.n], + lowlink_of: vec![-1; graph.n], + components: Vec::new(), + }; + + fn strong_connect(v: usize, graph: &Graph, state: &mut TarjanState) { + state.index_of[v] = state.index; + state.lowlink_of[v] = state.index; + state.index += 1; + state.stack.push(v); + state.on_stack[v] = true; + + for &w in &graph.adj_list[v] { + if state.index_of[w] == -1 { + strong_connect(w, graph, state); + state.lowlink_of[v] = state.lowlink_of[v].min(state.lowlink_of[w]); + } else if state.on_stack[w] { + state.lowlink_of[v] = state.lowlink_of[v].min(state.index_of[w]); + } + } + + if state.lowlink_of[v] == state.index_of[v] { + let mut component: Vec = Vec::new(); + while let 
Some(w) = state.stack.pop() { + state.on_stack[w] = false; + component.push(w); + if w == v { + break; + } + } + state.components.push(component); + } + } + + for v in 0..graph.n { + if state.index_of[v] == -1 { + strong_connect(v, graph, &mut state); + } + } + + state.components +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tarjan_scc() { + // Test 1: A graph with multiple strongly connected components + let n_vertices = 11; + let edges = vec![ + (0, 1), + (0, 3), + (1, 2), + (1, 4), + (2, 0), + (2, 6), + (3, 2), + (4, 5), + (4, 6), + (5, 6), + (5, 7), + (5, 8), + (5, 9), + (6, 4), + (7, 9), + (8, 9), + (9, 8), + ]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + assert_eq!( + components, + vec![ + vec![8, 9], + vec![7], + vec![5, 4, 6], + vec![3, 2, 1, 0], + vec![10], + ] + ); + + // Test 2: A graph with no edges + let n_vertices = 5; + let edges: Vec<(usize, usize)> = vec![]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + + // Each node is its own SCC + assert_eq!( + components, + vec![vec![0], vec![1], vec![2], vec![3], vec![4]] + ); + + // Test 3: A graph with single strongly connected component + let n_vertices = 5; + let edges = vec![(0, 1), (1, 2), (2, 3), (2, 4), (3, 0), (4, 2)]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + assert_eq!(components, vec![vec![4, 3, 2, 1, 0]]); + + // Test 4: A graph with multiple strongly connected component + let n_vertices = 7; + let edges = vec![ + (0, 1), + (1, 2), + (2, 0), + (1, 3), + (1, 4), + (1, 6), + (3, 5), + (4, 5), + ]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + assert_eq!( + components, + vec![vec![5], vec![3], 
vec![4], vec![6], vec![2, 1, 0],] + ); + } +} diff --git a/src/graph/topological_sort.rs b/src/graph/topological_sort.rs new file mode 100644 index 00000000000..887758287ea --- /dev/null +++ b/src/graph/topological_sort.rs @@ -0,0 +1,123 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::hash::Hash; + +#[derive(Debug, Eq, PartialEq)] +pub enum TopoligicalSortError { + CycleDetected, +} + +type TopologicalSortResult = Result, TopoligicalSortError>; + +/// Given a directed graph, modeled as a list of edges from source to destination +/// Uses Kahn's algorithm to either: +/// return the topological sort of the graph +/// or detect if there's any cycle +pub fn topological_sort( + edges: &Vec<(Node, Node)>, +) -> TopologicalSortResult { + // Preparation: + // Build a map of edges, organised from source to destinations + // Also, count the number of incoming edges by node + let mut edges_by_source: HashMap> = HashMap::default(); + let mut incoming_edges_count: HashMap = HashMap::default(); + for (source, destination) in edges { + incoming_edges_count.entry(*source).or_insert(0); // if we haven't seen this node yet, mark it as having 0 incoming nodes + edges_by_source // add destination to the list of outgoing edges from source + .entry(*source) + .or_default() + .push(*destination); + // then make destination have one more incoming edge + *incoming_edges_count.entry(*destination).or_insert(0) += 1; + } + + // Now Kahn's algorithm: + // Add nodes that have no incoming edges to a queue + let mut no_incoming_edges_q = VecDeque::default(); + for (node, count) in &incoming_edges_count { + if *count == 0 { + no_incoming_edges_q.push_back(*node); + } + } + // For each node in this "O-incoming-edge-queue" + let mut sorted = Vec::default(); + while let Some(no_incoming_edges) = no_incoming_edges_q.pop_back() { + sorted.push(no_incoming_edges); // since the node has no dependency, it can be safely pushed to the sorted result + 
incoming_edges_count.remove(&no_incoming_edges); + // For each node having this one as dependency + for neighbour in edges_by_source.get(&no_incoming_edges).unwrap_or(&vec![]) { + if let Some(count) = incoming_edges_count.get_mut(neighbour) { + *count -= 1; // decrement the count of incoming edges for the dependent node + if *count == 0 { + // `node` was the last node `neighbour` was dependent on + incoming_edges_count.remove(neighbour); // let's remove it from the map, so that we can know if we covered the whole graph + no_incoming_edges_q.push_front(*neighbour); // it has no incoming edges anymore => push it to the queue + } + } + } + } + if incoming_edges_count.is_empty() { + // we have visited every node + Ok(sorted) + } else { + // some nodes haven't been visited, meaning there's a cycle in the graph + Err(TopoligicalSortError::CycleDetected) + } +} + +#[cfg(test)] +mod tests { + use super::topological_sort; + use crate::graph::topological_sort::TopoligicalSortError; + + fn is_valid_sort(sorted: &[Node], graph: &[(Node, Node)]) -> bool { + for (source, dest) in graph { + let source_pos = sorted.iter().position(|node| node == source); + let dest_pos = sorted.iter().position(|node| node == dest); + match (source_pos, dest_pos) { + (Some(src), Some(dst)) if src < dst => {} + _ => { + return false; + } + }; + } + true + } + + #[test] + fn it_works() { + let graph = vec![(1, 2), (1, 3), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]; + let sort = topological_sort(&graph); + assert!(sort.is_ok()); + let sort = sort.unwrap(); + assert!(is_valid_sort(&sort, &graph)); + assert_eq!(sort, vec![1, 2, 3, 4, 5, 6, 7]); + } + + #[test] + fn test_wikipedia_example() { + let graph = vec![ + (5, 11), + (7, 11), + (7, 8), + (3, 8), + (3, 10), + (11, 2), + (11, 9), + (11, 10), + (8, 9), + ]; + let sort = topological_sort(&graph); + assert!(sort.is_ok()); + let sort = sort.unwrap(); + assert!(is_valid_sort(&sort, &graph)); + } + + #[test] + fn test_cyclic_graph() { + let graph = vec![(1, 
2), (2, 3), (3, 4), (4, 5), (4, 2)]; + let sort = topological_sort(&graph); + assert!(sort.is_err()); + assert_eq!(sort.err().unwrap(), TopoligicalSortError::CycleDetected); + } +} diff --git a/src/graph/two_satisfiability.rs b/src/graph/two_satisfiability.rs new file mode 100644 index 00000000000..a3e727f9323 --- /dev/null +++ b/src/graph/two_satisfiability.rs @@ -0,0 +1,116 @@ +use super::strongly_connected_components::StronglyConnectedComponents as SCCs; + +pub type Condition = (i64, i64); +type Graph = Vec>; + +#[inline] +fn variable(var: i64) -> usize { + if var < 0 { + (((-var) << 1) + 1) as usize + } else { + (var << 1) as usize + } +} + +/// Returns an assignment that satisfies all the constraints, or a variable that makes such an assignment impossible.\ +/// Variables should be numbered from 1 to `n`, and a negative number `-m` corresponds to the negated variable `m`.\ +/// For more information about this problem, please visit: +pub fn solve_two_satisfiability( + expression: &[Condition], + num_variables: usize, +) -> Result, i64> { + let num_verts = (num_variables + 1) << 1; + let mut result = Vec::new(); + let mut sccs = SCCs::new(num_verts); + let mut adj = Graph::new(); + adj.resize(num_verts, vec![]); + for cond in expression.iter() { + let v1 = variable(cond.0); + let v2 = variable(cond.1); + adj[v1 ^ 1].push(v2); + adj[v2 ^ 1].push(v1); + } + sccs.find_components(&adj); + result.resize(num_variables + 1, false); + for var in (2..num_verts).step_by(2) { + if sccs.component[var] == sccs.component[var ^ 1] { + return Err((var >> 1) as i64); + } + // if a variable isn't + if sccs.component[var] < sccs.component[var ^ 1] { + result[var >> 1] = true; + } + } + Ok(result) +} + +#[cfg(test)] +mod tests { + use std::thread; + + use super::*; + + fn check_answer(expression: &[Condition], answers: &[bool]) -> bool { + let mut ok = true; + for &(c1, c2) in expression { + let mut cv = false; + if c1 < 0 { + cv |= !answers[-c1 as usize]; + } else { + cv |= 
answers[c1 as usize]; + } + if c2 < 0 { + cv |= !answers[-c2 as usize]; + } else { + cv |= answers[c2 as usize]; + } + ok &= cv; + } + ok + } + #[test] + fn basic_test() { + let conds = vec![(1, 1), (2, 2)]; + let res = solve_two_satisfiability(&conds, 2); + assert!(res.is_ok()); + assert!(check_answer(&conds, &res.unwrap())); + + let conds = vec![(1, 2), (-2, -2)]; + let res = solve_two_satisfiability(&conds, 2); + assert!(res.is_ok()); + assert!(check_answer(&conds, &res.unwrap())); + + let conds = vec![]; + let res = solve_two_satisfiability(&conds, 2); + assert!(res.is_ok()); + assert!(check_answer(&conds, &res.unwrap())); + + let conds = vec![(-1, -1), (-2, -2), (1, 2)]; + let res = solve_two_satisfiability(&conds, 2); + assert!(res.is_err()); + } + + #[test] + #[ignore] + fn big_test() { + // We should spawn a new thread and set its stack size to something + // big (256MB in this case), because doing DFS (for finding SCCs) is + // a stack-intensive operation. 256MB should be enough for 3e5 + // variables though. 
+ let builder = thread::Builder::new().stack_size(256 * 1024 * 1024); + let handler = builder + .spawn(|| { + let num_conds = 3e5 as i64; + let mut conds = vec![]; + for i in 1..num_conds { + conds.push((i, -(i + 1))); + } + conds.push((num_conds, num_conds)); + let res = solve_two_satisfiability(&conds, num_conds as usize); + assert!(res.is_ok()); + assert!(check_answer(&conds, &res.unwrap())); + }) + .unwrap(); + handler.join().unwrap(); + } +} diff --git a/src/greedy/mod.rs b/src/greedy/mod.rs new file mode 100644 index 00000000000..e718c149f42 --- /dev/null +++ b/src/greedy/mod.rs @@ -0,0 +1,3 @@ +mod stable_matching; + +pub use self::stable_matching::stable_matching; diff --git a/src/greedy/stable_matching.rs b/src/greedy/stable_matching.rs new file mode 100644 index 00000000000..9b8f603d3d0 --- /dev/null +++ b/src/greedy/stable_matching.rs @@ -0,0 +1,276 @@ +use std::collections::{HashMap, VecDeque}; + +fn initialize_men( + men_preferences: &HashMap>, +) -> (VecDeque, HashMap) { + let mut free_men = VecDeque::new(); + let mut next_proposal = HashMap::new(); + + for man in men_preferences.keys() { + free_men.push_back(man.clone()); + next_proposal.insert(man.clone(), 0); + } + + (free_men, next_proposal) +} + +fn initialize_women( + women_preferences: &HashMap>, +) -> HashMap> { + let mut current_partner = HashMap::new(); + for woman in women_preferences.keys() { + current_partner.insert(woman.clone(), None); + } + current_partner +} + +fn precompute_woman_ranks( + women_preferences: &HashMap>, +) -> HashMap> { + let mut woman_ranks = HashMap::new(); + for (woman, preferences) in women_preferences { + let mut rank_map = HashMap::new(); + for (rank, man) in preferences.iter().enumerate() { + rank_map.insert(man.clone(), rank); + } + woman_ranks.insert(woman.clone(), rank_map); + } + woman_ranks +} + +fn process_proposal( + man: &str, + free_men: &mut VecDeque, + current_partner: &mut HashMap>, + man_engaged: &mut HashMap>, + next_proposal: &mut HashMap, + 
men_preferences: &HashMap>, + woman_ranks: &HashMap>, +) { + let man_pref_list = &men_preferences[man]; + let next_woman_idx = next_proposal[man]; + let woman = &man_pref_list[next_woman_idx]; + + // Update man's next proposal index + next_proposal.insert(man.to_string(), next_woman_idx + 1); + + if let Some(current_man) = current_partner[woman].clone() { + // Woman is currently engaged, check if she prefers the new man + if woman_prefers_new_man(woman, man, ¤t_man, woman_ranks) { + engage_man( + man, + woman, + free_men, + current_partner, + man_engaged, + Some(current_man), + ); + } else { + // Woman rejects the proposal, so the man remains free + free_men.push_back(man.to_string()); + } + } else { + // Woman is not engaged, so engage her with this man + engage_man(man, woman, free_men, current_partner, man_engaged, None); + } +} + +fn woman_prefers_new_man( + woman: &str, + man1: &str, + man2: &str, + woman_ranks: &HashMap>, +) -> bool { + let ranks = &woman_ranks[woman]; + ranks[man1] < ranks[man2] +} + +fn engage_man( + man: &str, + woman: &str, + free_men: &mut VecDeque, + current_partner: &mut HashMap>, + man_engaged: &mut HashMap>, + current_man: Option, +) { + man_engaged.insert(man.to_string(), Some(woman.to_string())); + current_partner.insert(woman.to_string(), Some(man.to_string())); + + if let Some(current_man) = current_man { + // The current man is now free + free_men.push_back(current_man); + } +} + +fn finalize_matches(man_engaged: HashMap>) -> HashMap { + let mut stable_matches = HashMap::new(); + for (man, woman_option) in man_engaged { + if let Some(woman) = woman_option { + stable_matches.insert(man, woman); + } + } + stable_matches +} + +pub fn stable_matching( + men_preferences: &HashMap>, + women_preferences: &HashMap>, +) -> HashMap { + let (mut free_men, mut next_proposal) = initialize_men(men_preferences); + let mut current_partner = initialize_women(women_preferences); + let mut man_engaged = HashMap::new(); + + let woman_ranks = 
precompute_woman_ranks(women_preferences); + + while let Some(man) = free_men.pop_front() { + process_proposal( + &man, + &mut free_men, + &mut current_partner, + &mut man_engaged, + &mut next_proposal, + men_preferences, + &woman_ranks, + ); + } + + finalize_matches(man_engaged) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn test_stable_matching_scenario_1() { + let men_preferences = HashMap::from([ + ( + "A".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ( + "B".to_string(), + vec!["Y".to_string(), "X".to_string(), "Z".to_string()], + ), + ( + "C".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ]); + + let women_preferences = HashMap::from([ + ( + "X".to_string(), + vec!["B".to_string(), "A".to_string(), "C".to_string()], + ), + ( + "Y".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ( + "Z".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ]); + + let matches = stable_matching(&men_preferences, &women_preferences); + + let expected_matches1 = HashMap::from([ + ("A".to_string(), "Y".to_string()), + ("B".to_string(), "X".to_string()), + ("C".to_string(), "Z".to_string()), + ]); + + let expected_matches2 = HashMap::from([ + ("A".to_string(), "X".to_string()), + ("B".to_string(), "Y".to_string()), + ("C".to_string(), "Z".to_string()), + ]); + + assert!(matches == expected_matches1 || matches == expected_matches2); + } + + #[test] + fn test_stable_matching_empty() { + let men_preferences = HashMap::new(); + let women_preferences = HashMap::new(); + + let matches = stable_matching(&men_preferences, &women_preferences); + assert!(matches.is_empty()); + } + + #[test] + fn test_stable_matching_duplicate_preferences() { + let men_preferences = HashMap::from([ + ("A".to_string(), vec!["X".to_string(), "X".to_string()]), // Man with duplicate preferences + ("B".to_string(), vec!["Y".to_string()]), + ]); 
+ + let women_preferences = HashMap::from([ + ("X".to_string(), vec!["A".to_string(), "B".to_string()]), + ("Y".to_string(), vec!["B".to_string()]), + ]); + + let matches = stable_matching(&men_preferences, &women_preferences); + let expected_matches = HashMap::from([ + ("A".to_string(), "X".to_string()), + ("B".to_string(), "Y".to_string()), + ]); + + assert_eq!(matches, expected_matches); + } + + #[test] + fn test_stable_matching_single_pair() { + let men_preferences = HashMap::from([("A".to_string(), vec!["X".to_string()])]); + let women_preferences = HashMap::from([("X".to_string(), vec!["A".to_string()])]); + + let matches = stable_matching(&men_preferences, &women_preferences); + let expected_matches = HashMap::from([("A".to_string(), "X".to_string())]); + + assert_eq!(matches, expected_matches); + } + #[test] + fn test_woman_prefers_new_man() { + let men_preferences = HashMap::from([ + ( + "A".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ( + "B".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ( + "C".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ]); + + let women_preferences = HashMap::from([ + ( + "X".to_string(), + vec!["B".to_string(), "A".to_string(), "C".to_string()], + ), + ( + "Y".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ( + "Z".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ]); + + let matches = stable_matching(&men_preferences, &women_preferences); + + let expected_matches = HashMap::from([ + ("A".to_string(), "Y".to_string()), + ("B".to_string(), "X".to_string()), + ("C".to_string(), "Z".to_string()), + ]); + + assert_eq!(matches, expected_matches); + } +} diff --git a/src/lib.rs b/src/lib.rs index 59147adbd11..910bf05de06 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,28 +1,39 @@ +pub mod backtracking; +pub mod big_integer; +pub mod bit_manipulation; pub mod ciphers; +pub mod 
/// Computes the lower-triangular Cholesky factor `L` of a square matrix `A`, so that
/// `L * Lᵀ = A` when `A` is symmetric positive definite.
///
/// # Arguments
///
/// * `mat` - The matrix in row-major order; entry `(i, j)` is read from `mat[i * n + j]`.
/// * `n` - The number of rows/columns.
///
/// # Returns
///
/// A vector of the same length as `mat` holding `L` in row-major order (entries above
/// the diagonal are `0.0`). An empty vector is returned when `mat` is empty or `n == 0`.
///
/// Entries whose computation is undefined (NaN) are replaced by `0.0`. This covers the
/// `0 / 0` case of an all-zero matrix and the square root of a negative pivot of a
/// matrix that is not positive definite.
///
/// # Panics
///
/// Panics if `mat` holds fewer than `n * n` elements.
pub fn cholesky(mat: Vec<f64>, n: usize) -> Vec<f64> {
    if mat.is_empty() || n == 0 {
        return vec![];
    }
    let mut res = vec![0.0; mat.len()];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already computed parts of rows `i` and `j`.
            let s = (0..j).fold(0.0, |acc, k| acc + res[i * n + k] * res[j * n + k]);
            let value: f64 = if i == j {
                (mat[i * n + i] - s).sqrt()
            } else {
                1.0 / res[j * n + j] * (mat[i * n + j] - s)
            };
            // Guard the *final* value: checking before `sqrt` would miss the NaN
            // produced by a negative pivot.
            res[i * n + j] = if value.is_nan() { 0.0 } else { value };
        }
    }
    res
}
- b).abs() < 1e-6)); + } + + fn transpose_matrix(mat: &[f64], n: usize) -> Vec { + (0..n) + .flat_map(|i| (0..n).map(move |j| mat[j * n + i])) + .collect() + } + + fn matrix_multiply(mat1: &[f64], mat2: &[f64], n: usize) -> Vec { + (0..n) + .flat_map(|i| { + (0..n).map(move |j| { + (0..n).fold(0.0, |acc, k| acc + mat1[i * n + k] * mat2[k * n + j]) + }) + }) + .collect() + } + + #[test] + fn test_matrix_operations() { + // Test case 1: Transposition + let mat1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let transposed_mat1 = transpose_matrix(&mat1, 3); + let expected_transposed_mat1 = vec![1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0]; + assert_eq!(transposed_mat1, expected_transposed_mat1); + + // Test case 2: Matrix multiplication + let mat2 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let mat3 = vec![9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]; + let multiplied_mat = matrix_multiply(&mat2, &mat3, 3); + let expected_multiplied_mat = vec![30.0, 24.0, 18.0, 84.0, 69.0, 54.0, 138.0, 114.0, 90.0]; + assert_eq!(multiplied_mat, expected_multiplied_mat); + } + + #[test] + fn empty_matrix() { + let mat = vec![]; + let res = cholesky(mat, 0); + assert_eq!(res, vec![]); + } + + #[test] + fn matrix_with_all_zeros() { + let mat3 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let res3 = cholesky(mat3, 3); + let expected3 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + assert_eq!(res3, expected3); + } +} diff --git a/src/machine_learning/k_means.rs b/src/machine_learning/k_means.rs new file mode 100644 index 00000000000..cd892d64424 --- /dev/null +++ b/src/machine_learning/k_means.rs @@ -0,0 +1,90 @@ +use rand::random; + +fn get_distance(p1: &(f64, f64), p2: &(f64, f64)) -> f64 { + let dx: f64 = p1.0 - p2.0; + let dy: f64 = p1.1 - p2.1; + + ((dx * dx) + (dy * dy)).sqrt() +} + +fn find_nearest(data_point: &(f64, f64), centroids: &[(f64, f64)]) -> u32 { + let mut cluster: u32 = 0; + + for (i, c) in centroids.iter().enumerate() { + let d1 = 
get_distance(data_point, c); + let d2 = get_distance(data_point, ¢roids[cluster as usize]); + + if d1 < d2 { + cluster = i as u32; + } + } + + cluster +} + +pub fn k_means(data_points: Vec<(f64, f64)>, n_clusters: usize, max_iter: i32) -> Option> { + if data_points.len() < n_clusters { + return None; + } + + let mut centroids: Vec<(f64, f64)> = Vec::new(); + let mut labels: Vec = vec![0; data_points.len()]; + + for _ in 0..n_clusters { + let x: f64 = random::(); + let y: f64 = random::(); + + centroids.push((x, y)); + } + + let mut count_iter: i32 = 0; + + while count_iter < max_iter { + let mut new_centroids_position: Vec<(f64, f64)> = vec![(0.0, 0.0); n_clusters]; + let mut new_centroids_num: Vec = vec![0; n_clusters]; + + for (i, d) in data_points.iter().enumerate() { + let nearest_cluster = find_nearest(d, ¢roids); + labels[i] = nearest_cluster; + + new_centroids_position[nearest_cluster as usize].0 += d.0; + new_centroids_position[nearest_cluster as usize].1 += d.1; + new_centroids_num[nearest_cluster as usize] += 1; + } + + for i in 0..centroids.len() { + if new_centroids_num[i] == 0 { + continue; + } + + let new_x: f64 = new_centroids_position[i].0 / new_centroids_num[i] as f64; + let new_y: f64 = new_centroids_position[i].1 / new_centroids_num[i] as f64; + + centroids[i] = (new_x, new_y); + } + + count_iter += 1; + } + + Some(labels) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_k_means() { + let mut data_points: Vec<(f64, f64)> = vec![]; + let n_points: usize = 1000; + + for _ in 0..n_points { + let x: f64 = random::() * 100.0; + let y: f64 = random::() * 100.0; + + data_points.push((x, y)); + } + + println!("{:?}", k_means(data_points, 10, 100).unwrap_or_default()); + } +} diff --git a/src/machine_learning/linear_regression.rs b/src/machine_learning/linear_regression.rs new file mode 100644 index 00000000000..22e0840e58b --- /dev/null +++ b/src/machine_learning/linear_regression.rs @@ -0,0 +1,48 @@ +/// Returns the parameters of the 
/// Fits the line `y = a + b * x` to `data_points` by simple (ordinary least squares)
/// linear regression.
///
/// # Returns
///
/// `Some((a, b))` with the y-intercept `a` and the slope `b`.
///
/// `None` is returned when `data_points` is empty or when all `x` values are identical
/// (zero variance in `x`, e.g. a single point), because the slope is undefined then.
pub fn linear_regression(data_points: Vec<(f64, f64)>) -> Option<(f64, f64)> {
    if data_points.is_empty() {
        return None;
    }

    let count = data_points.len() as f64;
    let mean_x = data_points.iter().map(|point| point.0).sum::<f64>() / count;
    let mean_y = data_points.iter().map(|point| point.1).sum::<f64>() / count;

    // Sums of centred products; the common 1/n factor cancels in the slope.
    let (covariance, variance_x) =
        data_points
            .iter()
            .fold((0.0, 0.0), |(cov, var), &(x, y)| {
                let dx = x - mean_x;
                (cov + dx * (y - mean_y), var + dx * dx)
            });

    if variance_x == 0.0 {
        return None;
    }

    // slope = cov(x, y) / var(x). Going through Pearson's coefficient
    // (cov / (sx * sy) * sy / sx) is algebraically the same but yields NaN for
    // constant `y` and loses precision.
    let b = covariance / variance_x;
    let a = mean_y - b * mean_x;

    Some((a, b))
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_linear_regression() {
        assert_eq!(
            linear_regression(vec![(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]),
            Some((0.0, 1.0))
        );
    }

    #[test]
    fn test_constant_y_linear_regression() {
        assert_eq!(
            linear_regression(vec![(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)]),
            Some((1.0, 0.0))
        );
    }

    #[test]
    fn test_constant_x_linear_regression() {
        assert_eq!(linear_regression(vec![(1.0, 1.0), (1.0, 2.0)]), None);
    }

    #[test]
    fn test_empty_list_linear_regression() {
        assert_eq!(linear_regression(vec![]), None);
    }
}
+pub fn logistic_regression( + data_points: Vec<(Vec, f64)>, + iterations: usize, + learning_rate: f64, +) -> Option> { + if data_points.is_empty() { + return None; + } + + let num_features = data_points[0].0.len() + 1; + let mut params = vec![0.0; num_features]; + + let derivative_fn = |params: &[f64]| derivative(params, &data_points); + + gradient_descent(derivative_fn, &mut params, learning_rate, iterations as i32); + + Some(params) +} + +fn derivative(params: &[f64], data_points: &[(Vec, f64)]) -> Vec { + let num_features = params.len(); + let mut gradients = vec![0.0; num_features]; + + for (features, y_i) in data_points { + let z = params[0] + + params[1..] + .iter() + .zip(features) + .map(|(p, x)| p * x) + .sum::(); + let prediction = 1.0 / (1.0 + E.powf(-z)); + + gradients[0] += prediction - y_i; + for (i, x_i) in features.iter().enumerate() { + gradients[i + 1] += (prediction - y_i) * x_i; + } + } + + gradients +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_logistic_regression_simple() { + let data = vec![ + (vec![0.0], 0.0), + (vec![1.0], 0.0), + (vec![2.0], 0.0), + (vec![3.0], 1.0), + (vec![4.0], 1.0), + (vec![5.0], 1.0), + ]; + + let result = logistic_regression(data, 10000, 0.05); + assert!(result.is_some()); + + let params = result.unwrap(); + assert!((params[0] + 17.65).abs() < 1.0); + assert!((params[1] - 7.13).abs() < 1.0); + } + + #[test] + fn test_logistic_regression_extreme_data() { + let data = vec![ + (vec![-100.0], 0.0), + (vec![-10.0], 0.0), + (vec![0.0], 0.0), + (vec![10.0], 1.0), + (vec![100.0], 1.0), + ]; + + let result = logistic_regression(data, 10000, 0.05); + assert!(result.is_some()); + + let params = result.unwrap(); + assert!((params[0] + 6.20).abs() < 1.0); + assert!((params[1] - 5.5).abs() < 1.0); + } + + #[test] + fn test_logistic_regression_no_data() { + let result = logistic_regression(vec![], 5000, 0.1); + assert_eq!(result, None); + } +} diff --git 
/// Average margin ranking loss.
///
/// Calculates the margin ranking loss, a loss function used for ranking problems in
/// machine learning.
///
/// ## Formula
///
/// For a pair of values `x_first` and `x_second`, a `margin`, and `y_true`, the loss
/// of the pair is
///
/// - loss = `max(0, -y_true * (x_first - x_second) + margin)`.
///
/// The function returns the average of the per-pair losses.
///
/// ## Arguments
///
/// * `x_first`, `x_second` - The two score vectors; they must have the same, non-zero length.
/// * `margin` - The margin; must not be negative.
/// * `y_true` - Either `1.0` or `-1.0`.
///
/// ## Errors
///
/// * [`MarginalRankingLossError::InputsHaveDifferentLength`] if the slices differ in length.
/// * [`MarginalRankingLossError::EmptyInputs`] if both slices are empty.
/// * [`MarginalRankingLossError::NegativeMargin`] if `margin < 0.0`.
/// * [`MarginalRankingLossError::InvalidValues`] if `y_true` is neither `1.0` nor `-1.0`.
///
/// ## References
///
/// * PyTorch implementation:
///   <https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html>
/// * <https://gombru.github.io/2019/04/03/ranking_loss/>
/// * <https://vinija.ai/concepts/loss/#pairwise-ranking-loss>
pub fn average_margin_ranking_loss(
    x_first: &[f64],
    x_second: &[f64],
    margin: f64,
    y_true: f64,
) -> Result<f64, MarginalRankingLossError> {
    check_input(x_first, x_second, margin, y_true)?;

    let total_loss: f64 = x_first
        .iter()
        .zip(x_second.iter())
        .map(|(f, s)| (margin - y_true * (f - s)).max(0.0))
        .sum();
    Ok(total_loss / (x_first.len() as f64))
}

/// Validates the arguments of [`average_margin_ranking_loss`].
///
/// The checks run in a fixed order (length mismatch, emptiness, margin, `y_true`), so
/// the first violated condition determines the reported error.
fn check_input(
    x_first: &[f64],
    x_second: &[f64],
    margin: f64,
    y_true: f64,
) -> Result<(), MarginalRankingLossError> {
    if x_first.len() != x_second.len() {
        return Err(MarginalRankingLossError::InputsHaveDifferentLength);
    }
    if x_first.is_empty() {
        return Err(MarginalRankingLossError::EmptyInputs);
    }
    if margin < 0.0 {
        return Err(MarginalRankingLossError::NegativeMargin);
    }
    if y_true != 1.0 && y_true != -1.0 {
        return Err(MarginalRankingLossError::InvalidValues);
    }

    Ok(())
}

/// Reasons why [`average_margin_ranking_loss`] rejects its input.
#[derive(Debug, PartialEq, Eq)]
pub enum MarginalRankingLossError {
    /// `x_first` and `x_second` have different lengths.
    InputsHaveDifferentLength,
    /// Both inputs are empty.
    EmptyInputs,
    /// `y_true` is neither `1.0` nor `-1.0`.
    InvalidValues,
    /// `margin` is negative.
    NegativeMargin,
}

impl std::fmt::Display for MarginalRankingLossError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InputsHaveDifferentLength => "inputs have different lengths",
            Self::EmptyInputs => "inputs are empty",
            Self::InvalidValues => "y_true must be either 1.0 or -1.0",
            Self::NegativeMargin => "margin must not be negative",
        };
        f.write_str(message)
    }
}

impl std::error::Error for MarginalRankingLossError {}
/// Hinge loss, a loss function used for classification problems in machine learning.
///
/// For every pair of an actual value from `y_true` and a predicted value from `y_pred`
/// the loss is `max(0, 1 - y_true * y_pred)`.
///
/// The per-pair losses are summed over the zipped inputs and the sum is divided by
/// `y_pred.len()`, so the result is NaN when `y_pred` is empty.
pub fn hng_loss(y_true: &[f64], y_pred: &[f64]) -> f64 {
    let summed = y_true
        .iter()
        .zip(y_pred)
        .fold(0.0_f64, |running, (truth, guess)| {
            running + (1.0 - truth * guess).max(0.0)
        });
    let n_elements = y_pred.len() as f64;
    summed / n_elements
}
/// Computes the Huber loss between arrays of true and predicted values.
///
/// For a residual `r = |y_true - y_pred|` the per-element loss is
/// `0.5 * r^2` when `r <= delta` and `delta * r - 0.5 * delta^2` otherwise.
///
/// # Arguments
///
/// * `y_true` - An array of true values.
/// * `y_pred` - An array of predicted values.
/// * `delta` - The threshold parameter that controls the linear behavior of the loss function.
///
/// # Returns
///
/// `Some(loss)` with the average Huber loss over all pairs of true and predicted
/// values, or `None` when the inputs differ in length or are empty.
pub fn huber_loss(y_true: &[f64], y_pred: &[f64], delta: f64) -> Option<f64> {
    if y_true.len() != y_pred.len() || y_pred.is_empty() {
        return None;
    }

    let loss: f64 = y_true
        .iter()
        .zip(y_pred.iter())
        .map(|(&true_val, &pred_val)| {
            let residual = (true_val - pred_val).abs();
            if residual <= delta {
                // Quadratic region around zero.
                0.5 * residual.powi(2)
            } else {
                // Linear region for large residuals.
                delta * residual - 0.5 * delta.powi(2)
            }
        })
        .sum();

    Some(loss / (y_pred.len() as f64))
}
/// Kullback-Leibler divergence loss between the distributions `actual` and `predicted`.
///
/// The result is the sum over all zipped pairs of
/// `(actual[i] + eps) * ln((actual[i] + eps) / (predicted[i] + eps))`,
/// where `eps = 0.00001` keeps the expression defined when an element is zero.
pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 {
    // Small offset added to every element so zeros do not break the logarithm.
    const SMOOTHING: f64 = 0.00001;

    let pair_term = |observed: f64, estimated: f64| {
        let shifted = observed + SMOOTHING;
        shifted * (shifted / (estimated + SMOOTHING)).ln()
    };

    actual
        .iter()
        .zip(predicted)
        .map(|(&observed, &estimated)| pair_term(observed, estimated))
        .sum()
}
/// Mean absolute error loss.
///
/// Sums `|predicted[i] - actual[i]|` over the zipped inputs and divides the sum by
/// `predicted.len()`, so the result is NaN when `predicted` is empty.
pub fn mae_loss(predicted: &[f64], actual: &[f64]) -> f64 {
    let summed = predicted
        .iter()
        .zip(actual)
        .fold(0.0_f64, |running, (guess, truth)| {
            running + (guess - truth).abs()
        });
    let n_elements = predicted.len() as f64;
    summed / n_elements
}
/// Mean squared error loss.
///
/// Sums `(predicted[i] - actual[i])^2` over the zipped inputs and divides the sum by
/// `predicted.len()`, so the result is NaN when `predicted` is empty.
pub fn mse_loss(predicted: &[f64], actual: &[f64]) -> f64 {
    let summed = predicted
        .iter()
        .zip(actual)
        .fold(0.0_f64, |running, (guess, truth)| {
            let gap = guess - truth;
            running + gap * gap
        });
    let n_elements = predicted.len() as f64;
    summed / n_elements
}
/// Negative log likelihood (binary cross-entropy) loss, used for classification
/// problems in machine learning.
///
/// For each pair of an actual value `y_true[i]` and a predicted value `y_pred[i]` the
/// loss is `-y_true * ln(y_pred) - (1 - y_true) * ln(1 - y_pred)`; the function returns
/// the average over all pairs.
///
/// A term whose weight is exactly zero contributes nothing, so a perfect prediction
/// (e.g. `y_true = 1`, `y_pred = 1`) yields `0` rather than the NaN that
/// `0 * ln(0)` produces in floating point. A confidently wrong prediction
/// (e.g. `y_true = 1`, `y_pred = 0`) yields an infinite loss.
///
/// # Errors
///
/// * [`NegativeLogLikelihoodLossError::InputsHaveDifferentLength`] if the slices differ in length.
/// * [`NegativeLogLikelihoodLossError::EmptyInputs`] if both slices are empty.
/// * [`NegativeLogLikelihoodLossError::InvalidValues`] if any value lies outside `[0, 1]`.
pub fn neg_log_likelihood(
    y_true: &[f64],
    y_pred: &[f64],
) -> Result<f64, NegativeLogLikelihoodLossError> {
    // Checks if the length of the actual and predicted values are equal
    if y_true.len() != y_pred.len() {
        return Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength);
    }
    // Checks if the inputs are empty
    if y_pred.is_empty() {
        return Err(NegativeLogLikelihoodLossError::EmptyInputs);
    }
    // Checks values are between 0 and 1
    if !are_all_values_in_range(y_true) || !are_all_values_in_range(y_pred) {
        return Err(NegativeLogLikelihoodLossError::InvalidValues);
    }

    let mut total_loss: f64 = 0.0;
    for (&p, &a) in y_pred.iter().zip(y_true.iter()) {
        total_loss += -weighted_ln(a, p) - weighted_ln(1.0 - a, 1.0 - p);
    }
    Ok(total_loss / (y_pred.len() as f64))
}

/// Returns `weight * ln(probability)`, treating a zero weight as a zero contribution
/// (avoids the NaN of `0 * ln(0)`).
fn weighted_ln(weight: f64, probability: f64) -> f64 {
    if weight == 0.0 {
        0.0
    } else {
        weight * probability.ln()
    }
}

/// Reasons why [`neg_log_likelihood`] rejects its input.
#[derive(Debug, PartialEq, Eq)]
pub enum NegativeLogLikelihoodLossError {
    /// `y_true` and `y_pred` have different lengths.
    InputsHaveDifferentLength,
    /// Both inputs are empty.
    EmptyInputs,
    /// At least one value lies outside `[0, 1]`.
    InvalidValues,
}

impl std::fmt::Display for NegativeLogLikelihoodLossError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InputsHaveDifferentLength => "inputs have different lengths",
            Self::EmptyInputs => "inputs are empty",
            Self::InvalidValues => "all values must lie in the range [0, 1]",
        };
        f.write_str(message)
    }
}

impl std::error::Error for NegativeLogLikelihoodLossError {}

/// Returns `true` when every value lies in `[0, 1]` (NaN counts as out of range).
fn are_all_values_in_range(values: &[f64]) -> bool {
    values.iter().all(|&x| (0.0..=1.0).contains(&x))
}
test_with_wrong_inputs { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (values_a, values_b, expected_error) = $inputs; + assert_eq!(neg_log_likelihood(&values_a, &values_b), expected_error); + assert_eq!(neg_log_likelihood(&values_b, &values_a), expected_error); + } + )* + } + } + + test_with_wrong_inputs! { + different_length: (vec![0.9, 0.0, 0.8], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)), + different_length_one_empty: (vec![], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)), + value_greater_than_1: (vec![1.1, 0.0, 0.8], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)), + value_greater_smaller_than_0: (vec![0.9, 0.0, -0.1], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)), + empty_input: (vec![], vec![], Err(NegativeLogLikelihoodLossError::EmptyInputs)), + } + + macro_rules! test_neg_log_likelihood { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (actual_values, predicted_values, expected) = $inputs; + assert_eq!(neg_log_likelihood(&actual_values, &predicted_values).unwrap(), expected); + } + )* + } + } + + test_neg_log_likelihood! 
{ + set_0: (vec![1.0, 0.0, 1.0], vec![0.9, 0.1, 0.8], 0.14462152754328741), + set_1: (vec![1.0, 0.0, 1.0], vec![0.1, 0.2, 0.3], 1.2432338162113972), + set_2: (vec![0.0, 1.0, 0.0], vec![0.1, 0.2, 0.3], 0.6904911240102196), + set_3: (vec![1.0, 0.0, 1.0, 0.0], vec![0.9, 0.1, 0.8, 0.2], 0.164252033486018), + } +} diff --git a/src/machine_learning/mod.rs b/src/machine_learning/mod.rs new file mode 100644 index 00000000000..534326d2121 --- /dev/null +++ b/src/machine_learning/mod.rs @@ -0,0 +1,20 @@ +mod cholesky; +mod k_means; +mod linear_regression; +mod logistic_regression; +mod loss_function; +mod optimization; + +pub use self::cholesky::cholesky; +pub use self::k_means::k_means; +pub use self::linear_regression::linear_regression; +pub use self::logistic_regression::logistic_regression; +pub use self::loss_function::average_margin_ranking_loss; +pub use self::loss_function::hng_loss; +pub use self::loss_function::huber_loss; +pub use self::loss_function::kld_loss; +pub use self::loss_function::mae_loss; +pub use self::loss_function::mse_loss; +pub use self::loss_function::neg_log_likelihood; +pub use self::optimization::gradient_descent; +pub use self::optimization::Adam; diff --git a/src/machine_learning/optimization/adam.rs b/src/machine_learning/optimization/adam.rs new file mode 100644 index 00000000000..6fbebc6d39d --- /dev/null +++ b/src/machine_learning/optimization/adam.rs @@ -0,0 +1,288 @@ +//! # Adam (Adaptive Moment Estimation) optimizer +//! +//! The `Adam (Adaptive Moment Estimation)` optimizer is an adaptive learning rate algorithm used +//! in gradient descent and machine learning, such as for training neural networks to solve deep +//! learning problems. Boasting memory-efficient fast convergence rates, it sets and iteratively +//! updates learning rates individually for each model parameter based on the gradient history. +//! +//! ## Algorithm: +//! +//! Given: +//! - α is the learning rate +//! 
- (β_1, β_2) are the exponential decay rates for moment estimates +//! - ϵ is any small value to prevent division by zero +//! - g_t are the gradients at time step t +//! - m_t are the biased first moment estimates of the gradient at time step t +//! - v_t are the biased second raw moment estimates of the gradient at time step t +//! - θ_t are the model parameters at time step t +//! - t is the time step +//! +//! Required: +//! θ_0 +//! +//! Initialize: +//! m_0 <- 0 +//! v_0 <- 0 +//! t <- 0 +//! +//! while θ_t not converged do +//! t = t + 1 +//! m_t = β_1 * m_{t−1} + (1 − β_1) * g_t +//! v_t = β_2 * v_{t−1} + (1 − β_2) * g_t^2 +//! m_hat_t = m_t / (1 − β_1^t) +//! v_hat_t = v_t / (1 − β_2^t) +//! θ_t = θ_{t−1} − α * m_hat_t / (sqrt(v_hat_t) + ϵ) +//! +//! ## Resources: +//! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba): +//! - [https://arxiv.org/abs/1412.6980] +//! - PyTorch Adam optimizer: +//! - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam] +//! 
/// State of an Adam (Adaptive Moment Estimation) optimizer for a fixed number of
/// model parameters.
pub struct Adam {
    learning_rate: f64, // alpha: initial step size for iterative optimization
    betas: (f64, f64),  // betas: exponential decay rates for moment estimates
    epsilon: f64,       // epsilon: prevent division by zero
    m: Vec<f64>,        // m: biased first moment estimate of the gradient vector
    v: Vec<f64>,        // v: biased second raw moment estimate of the gradient vector
    t: usize,           // t: time step
}

impl Adam {
    /// Creates an optimizer for `params_len` parameters.
    ///
    /// Each hyper-parameter falls back to a default when `None` is given:
    /// `learning_rate = 1e-3`, `betas = (0.9, 0.999)`, `epsilon = 1e-8`.
    /// Both moment vectors start at zero and the time step starts at `0`.
    pub fn new(
        learning_rate: Option<f64>,
        betas: Option<(f64, f64)>,
        epsilon: Option<f64>,
        params_len: usize,
    ) -> Self {
        Adam {
            learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr
            betas: betas.unwrap_or((0.9, 0.999)), // typical good default decay rates
            epsilon: epsilon.unwrap_or(1e-8),     // typical good default epsilon
            m: vec![0.0; params_len], // first moment vector elements all initialized to zero
            v: vec![0.0; params_len], // second moment vector elements all initialized to zero
            t: 0,                     // time step initialized to zero
        }
    }

    /// Advances the optimizer by one time step using `gradients`.
    ///
    /// Updates the stored moment estimates and returns, for every gradient, the value
    /// `-learning_rate * m_hat / (sqrt(v_hat) + epsilon)`. The returned vector starts
    /// from zero on every call, i.e. it is the update to add to the caller's
    /// parameters, not an accumulated parameter vector.
    ///
    /// # Panics
    ///
    /// Panics if `gradients` has more elements than the `params_len` given to
    /// [`Adam::new`].
    pub fn step(&mut self, gradients: &[f64]) -> Vec<f64> {
        assert!(
            gradients.len() <= self.m.len(),
            "received {} gradients but the optimizer was created for {} parameters",
            gradients.len(),
            self.m.len()
        );
        self.t += 1;

        let (beta1, beta2) = self.betas;
        let learning_rate = self.learning_rate;
        let epsilon = self.epsilon;

        // The bias corrections depend only on the time step, so compute them once.
        // The exponent saturates at `i32::MAX` instead of wrapping; beta^t is already
        // indistinguishable from its limit long before that.
        let exponent = self.t.min(i32::MAX as usize) as i32;
        let m_correction = 1.0 - beta1.powi(exponent);
        let v_correction = 1.0 - beta2.powi(exponent);

        gradients
            .iter()
            .zip(self.m.iter_mut().zip(self.v.iter_mut()))
            .map(|(&gradient, (m, v))| {
                // update biased first moment estimate and second raw moment estimate
                *m = beta1 * *m + (1.0 - beta1) * gradient;
                *v = beta2 * *v + (1.0 - beta2) * (gradient * gradient);

                // compute bias-corrected first moment estimate and second raw moment estimate
                let m_hat = *m / m_correction;
                let v_hat = *v / v_correction;

                // parameter update for this element
                -(learning_rate * m_hat / (v_hat.sqrt() + epsilon))
            })
            .collect()
    }
}
0.999)); + assert_eq!(optimizer.epsilon, 1e-8); + assert_eq!(optimizer.m, vec![0.0; 1]); + assert_eq!(optimizer.v, vec![0.0; 1]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_custom_lr_value() { + let optimizer = Adam::new(Some(0.9), None, None, 2); + + assert_eq!(optimizer.learning_rate, 0.9); + assert_eq!(optimizer.betas, (0.9, 0.999)); + assert_eq!(optimizer.epsilon, 1e-8); + assert_eq!(optimizer.m, vec![0.0; 2]); + assert_eq!(optimizer.v, vec![0.0; 2]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_custom_betas_value() { + let optimizer = Adam::new(None, Some((0.8, 0.899)), None, 3); + + assert_eq!(optimizer.learning_rate, 0.001); + assert_eq!(optimizer.betas, (0.8, 0.899)); + assert_eq!(optimizer.epsilon, 1e-8); + assert_eq!(optimizer.m, vec![0.0; 3]); + assert_eq!(optimizer.v, vec![0.0; 3]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_custom_epsilon_value() { + let optimizer = Adam::new(None, None, Some(1e-10), 4); + + assert_eq!(optimizer.learning_rate, 0.001); + assert_eq!(optimizer.betas, (0.9, 0.999)); + assert_eq!(optimizer.epsilon, 1e-10); + assert_eq!(optimizer.m, vec![0.0; 4]); + assert_eq!(optimizer.v, vec![0.0; 4]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_all_custom_values() { + let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), 5); + + assert_eq!(optimizer.learning_rate, 1.0); + assert_eq!(optimizer.betas, (0.001, 0.099)); + assert_eq!(optimizer.epsilon, 1e-1); + assert_eq!(optimizer.m, vec![0.0; 5]); + assert_eq!(optimizer.v, vec![0.0; 5]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_step_default_params() { + let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0]; + + let mut optimizer = Adam::new(None, None, None, 8); + let updated_params = optimizer.step(&gradients); + + assert_eq!( + updated_params, + vec![ + 0.0009999999900000003, + -0.000999999995, + 0.0009999999966666666, + -0.0009999999975, + 
0.000999999998, + -0.0009999999983333334, + 0.0009999999985714286, + -0.00099999999875 + ] + ); + } + + #[test] + fn test_adam_step_custom_params() { + let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0]; + + let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), 9); + let updated_params = optimizer.step(&gradients); + + assert_eq!( + updated_params, + vec![ + -0.004999994444450618, + 0.004999993750007813, + -0.004999992857153062, + 0.004999991666680556, + -0.004999990000020001, + 0.004999987500031251, + -0.004999983333388888, + 0.004999975000124999, + -0.0049999500004999945 + ] + ); + } + + #[test] + fn test_adam_step_empty_gradients_array() { + let gradients = vec![]; + + let mut optimizer = Adam::new(None, None, None, 0); + let updated_params = optimizer.step(&gradients); + + assert_eq!(updated_params, vec![]); + } + + #[ignore] + #[test] + fn test_adam_step_iteratively_until_convergence_with_default_params() { + const CONVERGENCE_THRESHOLD: f64 = 1e-5; + let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + + let mut optimizer = Adam::new(None, None, None, 6); + + let mut model_params = vec![0.0; 6]; + let mut updated_params = optimizer.step(&gradients); + + while (updated_params + .iter() + .zip(model_params.iter()) + .map(|(x, y)| x - y) + .collect::>()) + .iter() + .map(|&x| x.powi(2)) + .sum::() + .sqrt() + > CONVERGENCE_THRESHOLD + { + model_params = updated_params; + updated_params = optimizer.step(&gradients); + } + + assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]); + assert_ne!(updated_params, model_params); + assert_eq!( + updated_params, + vec![ + -0.0009999999899999931, + -0.0009999999949999929, + -0.0009999999966666597, + -0.0009999999974999929, + -0.0009999999979999927, + -0.0009999999983333263 + ] + ); + } + + #[ignore] + #[test] + fn test_adam_step_iteratively_until_convergence_with_custom_params() { + const CONVERGENCE_THRESHOLD: f64 = 1e-7; + let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 
13.0]; + + let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), 7); + + let mut model_params = vec![0.0; 7]; + let mut updated_params = optimizer.step(&gradients); + + while (updated_params + .iter() + .zip(model_params.iter()) + .map(|(x, y)| x - y) + .collect::>()) + .iter() + .map(|&x| x.powi(2)) + .sum::() + .sqrt() + > CONVERGENCE_THRESHOLD + { + model_params = updated_params; + updated_params = optimizer.step(&gradients); + } + + assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]); + assert_ne!(updated_params, model_params); + assert_eq!( + updated_params, + vec![ + -0.004999992857153061, + 0.004999993750007814, + -0.0049999944444506185, + 0.004999995000005001, + -0.004999995454549587, + 0.004999995833336807, + -0.004999996153849113 + ] + ); + } +} diff --git a/src/machine_learning/optimization/gradient_descent.rs b/src/machine_learning/optimization/gradient_descent.rs new file mode 100644 index 00000000000..fd322a23ff3 --- /dev/null +++ b/src/machine_learning/optimization/gradient_descent.rs @@ -0,0 +1,86 @@ +/// Gradient Descent Optimization +/// +/// Gradient descent is an iterative optimization algorithm used to find the minimum of a function. +/// It works by updating the parameters (in this case, elements of the vector `x`) in the direction of +/// the steepest decrease in the function's value. This is achieved by subtracting the gradient of +/// the function at the current point from the current point. The learning rate controls the step size. +/// +/// The equation for a single parameter (univariate) is: +/// x_{k+1} = x_k - learning_rate * derivative_of_function(x_k) +/// +/// For multivariate functions, it extends to each parameter: +/// x_{k+1} = x_k - learning_rate * gradient_of_function(x_k) +/// +/// # Arguments +/// +/// * `derivative_fn` - The function that calculates the gradient of the objective function at a given point. +/// * `x` - The initial parameter vector to be optimized. 
/// Minimises a function by plain gradient descent, updating `x` in place.
///
/// Each iteration evaluates `derivative_fn` at the current point and applies
/// `x_k -= learning_rate * grad_k` to every coordinate.
///
/// # Arguments
///
/// * `derivative_fn` - Computes the gradient of the objective at a given point.
/// * `x` - The starting point; overwritten with the result.
/// * `learning_rate` - Step size applied to the gradient on each iteration.
/// * `num_iterations` - Number of update rounds. Zero or a negative value
///   leaves `x` untouched, because the iteration range is then empty.
///
/// # Returns
///
/// The same mutable reference that was passed in as `x`.
///
/// If `derivative_fn` returns fewer entries than `x` has, only the leading
/// coordinates are updated; surplus gradient entries are ignored.
pub fn gradient_descent(
    derivative_fn: impl Fn(&[f64]) -> Vec<f64>,
    x: &mut Vec<f64>,
    learning_rate: f64,
    num_iterations: i32,
) -> &mut Vec<f64> {
    for _ in 0..num_iterations {
        let gradient = derivative_fn(x);
        for (coordinate, slope) in x.iter_mut().zip(gradient) {
            *coordinate -= learning_rate * slope;
        }
    }

    x
}
* x).collect() + } + + let mut x: Vec = vec![5.0, 6.0]; + let learning_rate: f64 = 0.03; + let num_iterations: i32 = 10; + + let minimized_vector = + gradient_descent(derivative_of_square, &mut x, learning_rate, num_iterations); + + let test_vector = [0.0, 0.0]; + + let tolerance = 1e-6; + for (minimized_value, test_value) in minimized_vector.iter().zip(test_vector.iter()) { + assert!((minimized_value - test_value).abs() >= tolerance); + } + } +} diff --git a/src/machine_learning/optimization/mod.rs b/src/machine_learning/optimization/mod.rs new file mode 100644 index 00000000000..7a962993beb --- /dev/null +++ b/src/machine_learning/optimization/mod.rs @@ -0,0 +1,5 @@ +mod adam; +mod gradient_descent; + +pub use self::adam::Adam; +pub use self::gradient_descent::gradient_descent; diff --git a/src/math/abs.rs b/src/math/abs.rs new file mode 100644 index 00000000000..a10a4b30361 --- /dev/null +++ b/src/math/abs.rs @@ -0,0 +1,38 @@ +/// This function returns the absolute value of a number.\ +/// The absolute value of a number is the non-negative value of the number, regardless of its sign.\ +/// +/// Wikipedia: +pub fn abs(num: T) -> T +where + T: std::ops::Neg + PartialOrd + Copy + num_traits::Zero, +{ + if num < T::zero() { + return -num; + } + num +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_negative_number_i32() { + assert_eq!(69, abs(-69)); + } + + #[test] + fn test_negative_number_f64() { + assert_eq!(69.69, abs(-69.69)); + } + + #[test] + fn zero() { + assert_eq!(0.0, abs(0.0)); + } + + #[test] + fn positive_number() { + assert_eq!(69.69, abs(69.69)); + } +} diff --git a/src/math/aliquot_sum.rs b/src/math/aliquot_sum.rs new file mode 100644 index 00000000000..28bf5981a5e --- /dev/null +++ b/src/math/aliquot_sum.rs @@ -0,0 +1,56 @@ +/// Aliquot sum of a number is defined as the sum of the proper divisors of a number.\ +/// i.e. all the divisors of a number apart from the number itself. 
/// Returns every amicable pair `(low, high)` with `low < high < n`, ordered by `low`.
///
/// Two numbers are amicable when each equals the sum of the proper divisors of
/// the other. Returns `None` when no such pair exists below `n` (including
/// `n == 0`).
///
/// Wikipedia reference: https://en.wikipedia.org/wiki/Amicable_numbers
pub fn amicable_pairs_under_n(n: u32) -> Option<Vec<(u32, u32)>> {
    let limit = n as usize;

    // factor_sums[j] accumulates the sum of the proper divisors of j.
    // Indices are handled as usize so that `2 * i` cannot overflow u32.
    let mut factor_sums = vec![0u32; limit];
    for i in 1..limit {
        for j in (2 * i..limit).step_by(i) {
            factor_sums[j] += i as u32;
        }
    }

    let pairs: Vec<(u32, u32)> = factor_sums
        .iter()
        .enumerate()
        .filter_map(|(i, &sum)| {
            let i = i as u32;
            // `sum < n` is tested before indexing so the lookup stays in bounds;
            // `sum > i` reports each pair once, smaller member first.
            let amicable = sum > i && sum < n && factor_sums[sum as usize] == i;
            amicable.then(|| (i, sum))
        })
        .collect();

    if pairs.is_empty() {
        None
    } else {
        Some(pairs)
    }
}
/// Approximates the unsigned area between `func` and the x-axis on
/// `[start, end]` with the trapezoidal rule.
///
/// The interval is cut into `step_count` equal steps and `func` is evaluated
/// at each step boundary. Reversed bounds are swapped, so the result does not
/// depend on their order.
///
/// When the sampled values at the two ends of a step have opposite signs, the
/// straight segment between them crosses the axis; the step is then counted as
/// the two triangles on either side of the crossing instead of letting the
/// positive and negative parts cancel.
///
/// # Panics
///
/// Panics if `step_count` is zero.
pub fn area_under_curve(start: f64, end: f64, func: fn(f64) -> f64, step_count: usize) -> f64 {
    assert!(step_count > 0);

    // swap if bounds reversed
    let (start, end) = if start > end {
        (end, start)
    } else {
        (start, end)
    };

    let step_length: f64 = (end - start) / step_count as f64;
    let mut area = 0f64;
    let mut left = func(start);

    for step in 1..=step_count {
        let right = func(step as f64 * step_length + start);
        let crosses_axis = (left < 0.0 && right > 0.0) || (left > 0.0 && right < 0.0);

        area += if crosses_axis {
            // Two triangles that split the step at the zero of the chord:
            // h/2 * (l^2 + r^2) / (|l| + |r|)
            0.5 * step_length * (left * left + right * right) / (left.abs() + right.abs())
        } else {
            (right + left).abs() * step_length * 0.5
        };
        left = right;
    }

    area
}
/// Returns `true` when `number` is an Armstrong (narcissistic) number, i.e. it
/// equals the sum of its decimal digits each raised to the number of digits.
///
/// The digit powers are summed in `u64`: for ten-digit inputs the sum can be
/// as large as `10 * 9^10`, which does not fit in `u32`.
pub fn is_armstrong_number(number: u32) -> bool {
    let mut digits: Vec<u32> = Vec::new();
    let mut remaining = number;

    // Collect the decimal digits; zero contributes the single digit 0.
    loop {
        digits.push(remaining % 10);
        remaining /= 10;
        if remaining == 0 {
            break;
        }
    }

    let exponent = digits.len() as u32;
    let sum_nth_power_of_digits: u64 = digits
        .iter()
        .map(|&digit| u64::from(digit).pow(exponent))
        .sum();

    sum_nth_power_of_digits == u64::from(number)
}
+This is what is most often meant by an average. The median is the middle value in a list ordered from smallest to largest. +The mode is the most frequently occurring value on the list. + +Reference: https://www.britannica.com/science/mean-median-and-mode + +This program approximates the mean, median and mode of a finite sequence. +Note: Floats sequences are not allowed for `mode` function. +"] +use std::collections::HashMap; +use std::collections::HashSet; + +use num_traits::Num; + +fn sum(sequence: Vec) -> T { + sequence.iter().fold(T::zero(), |acc, x| acc + *x) +} + +/// # Argument +/// +/// * `sequence` - A vector of numbers. +/// Returns mean of `sequence`. +pub fn mean(sequence: Vec) -> Option { + let len = sequence.len(); + if len == 0 { + return None; + } + Some(sum(sequence) / (T::from_usize(len).unwrap())) +} + +fn mean_of_two(a: T, b: T) -> T { + (a + b) / (T::one() + T::one()) +} + +/// # Argument +/// +/// * `sequence` - A vector of numbers. +/// Returns median of `sequence`. + +pub fn median(mut sequence: Vec) -> Option { + if sequence.is_empty() { + return None; + } + sequence.sort_by(|a, b| a.partial_cmp(b).unwrap()); + if sequence.len() % 2 == 1 { + let k = (sequence.len() + 1) / 2; + Some(sequence[k - 1]) + } else { + let j = (sequence.len()) / 2; + Some(mean_of_two(sequence[j - 1], sequence[j])) + } +} + +fn histogram(sequence: Vec) -> HashMap { + sequence.into_iter().fold(HashMap::new(), |mut res, val| { + *res.entry(val).or_insert(0) += 1; + res + }) +} + +/// # Argument +/// +/// * `sequence` - The input vector. +/// Returns mode of `sequence`. 
+pub fn mode(sequence: Vec) -> Option> { + if sequence.is_empty() { + return None; + } + let hist = histogram(sequence); + let max_count = *hist.values().max().unwrap(); + Some( + hist.into_iter() + .filter(|(_, count)| *count == max_count) + .map(|(value, _)| value) + .collect(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn median_test() { + assert_eq!(median(vec![4, 53, 2, 1, 9, 0, 2, 3, 6]).unwrap(), 3); + assert_eq!(median(vec![-9, -8, 0, 1, 2, 2, 3, 4, 6, 9, 53]).unwrap(), 2); + assert_eq!(median(vec![2, 3]).unwrap(), 2); + assert_eq!(median(vec![3.0, 2.0]).unwrap(), 2.5); + assert_eq!(median(vec![1.0, 700.0, 5.0]).unwrap(), 5.0); + assert!(median(Vec::::new()).is_none()); + assert!(median(Vec::::new()).is_none()); + } + #[test] + fn mode_test() { + assert_eq!( + mode(vec![4, 53, 2, 1, 9, 0, 2, 3, 6]).unwrap(), + HashSet::from([2]) + ); + assert_eq!( + mode(vec![-9, -8, 0, 1, 2, 2, 3, -1, -1, 9, -1, -9]).unwrap(), + HashSet::from([-1]) + ); + assert_eq!(mode(vec!["a", "b", "a"]).unwrap(), HashSet::from(["a"])); + assert_eq!(mode(vec![1, 2, 2, 1]).unwrap(), HashSet::from([1, 2])); + assert_eq!(mode(vec![1, 2, 2, 1, 3]).unwrap(), HashSet::from([1, 2])); + assert_eq!(mode(vec![1]).unwrap(), HashSet::from([1])); + assert!(mode(Vec::::new()).is_none()); + } + #[test] + fn mean_test() { + assert_eq!(mean(vec![2023.1112]).unwrap(), 2023.1112); + assert_eq!(mean(vec![0.0, 1.0, 2.0, 3.0, 4.0]).unwrap(), 2.0); + assert_eq!( + mean(vec![-7.0, 4.0, 53.0, 2.0, 1.0, -9.0, 0.0, 2.0, 3.0, -6.0]).unwrap(), + 4.3 + ); + assert_eq!(mean(vec![1, 2]).unwrap(), 1); + assert!(mean(Vec::::new()).is_none()); + assert!(mean(Vec::::new()).is_none()); + } +} diff --git a/src/math/baby_step_giant_step.rs b/src/math/baby_step_giant_step.rs new file mode 100644 index 00000000000..1c4d1cc74c7 --- /dev/null +++ b/src/math/baby_step_giant_step.rs @@ -0,0 +1,83 @@ +use crate::math::greatest_common_divisor; +/// Baby-step Giant-step algorithm +/// +/// Solving discrete 
logarithm problem: +/// a^x = b (mod n) , with respect to gcd(a, n) == 1 +/// with O(sqrt(n)) time complexity. +/// +/// Wikipedia reference: https://en.wikipedia.org/wiki/Baby-step_giant-step +/// When a is the primitive root modulo n, the answer is unique. +/// Otherwise it will return the smallest positive solution +use std::collections::HashMap; + +pub fn baby_step_giant_step(a: usize, b: usize, n: usize) -> Option { + if greatest_common_divisor::greatest_common_divisor_stein(a as u64, n as u64) != 1 { + return None; + } + + let mut h_map = HashMap::new(); + let m = (n as f64).sqrt().ceil() as usize; + // baby step + let mut step = 1; + for i in 0..m { + h_map.insert((step * b) % n, i); + step = (step * a) % n; + } + // Now step = a^m (mod n), giant step + let giant_step = step; + for i in (m..=n).step_by(m) { + if let Some(v) = h_map.get(&step) { + return Some(i - v); + } + step = (step * giant_step) % n; + } + None +} + +#[cfg(test)] +mod tests { + use super::baby_step_giant_step; + + #[test] + fn small_numbers() { + assert_eq!(baby_step_giant_step(5, 3, 11), Some(2)); + assert_eq!(baby_step_giant_step(3, 83, 100), Some(9)); + assert_eq!(baby_step_giant_step(9, 1, 61), Some(5)); + assert_eq!(baby_step_giant_step(5, 1, 67), Some(22)); + assert_eq!(baby_step_giant_step(7, 1, 45), Some(12)); + } + + #[test] + fn primitive_root_tests() { + assert_eq!( + baby_step_giant_step(3, 311401496, 998244353), + Some(178105253) + ); + assert_eq!( + baby_step_giant_step(5, 324637211, 1000000007), + Some(976653449) + ); + } + + #[test] + fn random_numbers() { + assert_eq!(baby_step_giant_step(174857, 48604, 150991), Some(177)); + assert_eq!(baby_step_giant_step(912103, 53821, 75401), Some(2644)); + assert_eq!(baby_step_giant_step(448447, 365819, 671851), Some(23242)); + assert_eq!( + baby_step_giant_step(220757103, 92430653, 434948279), + Some(862704) + ); + assert_eq!( + baby_step_giant_step(176908456, 23538399, 142357679), + Some(14215560) + ); + } + + #[test] + fn 
no_solution() { + assert!(baby_step_giant_step(7, 6, 45).is_none()); + assert!(baby_step_giant_step(23, 15, 85).is_none()); + assert!(baby_step_giant_step(2, 1, 84).is_none()); + } +} diff --git a/src/math/bell_numbers.rs b/src/math/bell_numbers.rs new file mode 100644 index 00000000000..9a66c83087e --- /dev/null +++ b/src/math/bell_numbers.rs @@ -0,0 +1,147 @@ +use num_bigint::BigUint; +use num_traits::{One, Zero}; +use std::sync::RwLock; + +/// Returns the number of ways you can select r items given n options +fn n_choose_r(n: u32, r: u32) -> BigUint { + if r == n || r == 0 { + return One::one(); + } + + if r > n { + return Zero::zero(); + } + + // Any combination will only need to be computed once, thus giving no need to + // memoize this function + + let product: BigUint = (0..r).fold(BigUint::one(), |acc, x| { + (acc * BigUint::from(n - x)) / BigUint::from(x + 1) + }); + + product +} + +/// A memoization table for storing previous results +struct MemTable { + buffer: Vec, +} + +impl MemTable { + const fn new() -> Self { + MemTable { buffer: Vec::new() } + } + + fn get(&self, n: usize) -> Option { + if n == 0 || n == 1 { + Some(BigUint::one()) + } else if let Some(entry) = self.buffer.get(n) { + if *entry == BigUint::zero() { + None + } else { + Some(entry.clone()) + } + } else { + None + } + } + + fn set(&mut self, n: usize, b: BigUint) { + self.buffer[n] = b; + } + + #[inline] + fn capacity(&self) -> usize { + self.buffer.capacity() + } + + #[inline] + fn resize(&mut self, new_size: usize) { + if new_size > self.buffer.len() { + self.buffer.resize(new_size, Zero::zero()); + } + } +} + +// Implemented with RwLock so it is accessible across threads +static LOOKUP_TABLE_LOCK: RwLock = RwLock::new(MemTable::new()); + +pub fn bell_number(n: u32) -> BigUint { + let needs_resize; + + // Check if number is already in lookup table + { + let lookup_table = LOOKUP_TABLE_LOCK.read().unwrap(); + + if let Some(entry) = lookup_table.get(n as usize) { + return entry; + } + + 
needs_resize = (n + 1) as usize > lookup_table.capacity(); + } + + // Resize table before recursion so that if more values need to be added during recursion the table isn't + // reallocated every single time + if needs_resize { + let mut lookup_table = LOOKUP_TABLE_LOCK.write().unwrap(); + + lookup_table.resize((n + 1) as usize); + } + + let new_bell_number: BigUint = (0..n).map(|x| bell_number(x) * n_choose_r(n - 1, x)).sum(); + + // Add new number to lookup table + { + let mut lookup_table = LOOKUP_TABLE_LOCK.write().unwrap(); + + lookup_table.set(n as usize, new_bell_number.clone()); + } + + new_bell_number +} + +#[cfg(test)] +pub mod tests { + use super::*; + use std::str::FromStr; + + #[test] + fn test_choose_zero() { + for i in 1..100 { + assert_eq!(n_choose_r(i, 0), One::one()); + } + } + + #[test] + fn test_combination() { + let five_choose_1 = BigUint::from(5u32); + assert_eq!(n_choose_r(5, 1), five_choose_1); + assert_eq!(n_choose_r(5, 4), five_choose_1); + + let ten_choose_3 = BigUint::from(120u32); + assert_eq!(n_choose_r(10, 3), ten_choose_3); + assert_eq!(n_choose_r(10, 7), ten_choose_3); + + let fourty_two_choose_thirty = BigUint::from_str("11058116888").unwrap(); + assert_eq!(n_choose_r(42, 30), fourty_two_choose_thirty); + assert_eq!(n_choose_r(42, 12), fourty_two_choose_thirty); + } + + #[test] + fn test_bell_numbers() { + let bell_one = BigUint::from(1u32); + assert_eq!(bell_number(1), bell_one); + + let bell_three = BigUint::from(5u32); + assert_eq!(bell_number(3), bell_three); + + let bell_eight = BigUint::from(4140u32); + assert_eq!(bell_number(8), bell_eight); + + let bell_six = BigUint::from(203u32); + assert_eq!(bell_number(6), bell_six); + + let bell_twenty_six = BigUint::from_str("49631246523618756274").unwrap(); + assert_eq!(bell_number(26), bell_twenty_six); + } +} diff --git a/src/math/binary_exponentiation.rs b/src/math/binary_exponentiation.rs new file mode 100644 index 00000000000..d7661adbea8 --- /dev/null +++ 
/// Computes `n^p` by binary exponentiation (square-and-multiply).
///
/// The exponent is consumed bit by bit, lowest bit first: whenever the current
/// bit is set, the running result is multiplied by the current power of `n`,
/// and that power is squared before moving to the next bit. This needs
/// O(log p) multiplications instead of `p - 1`.
///
/// The base is squared only while exponent bits remain. Squaring once more
/// after the last bit would produce a value that is never used and could
/// overflow `u64` even though the result itself fits (e.g. `2^63`).
///
/// Arithmetic is plain `u64` multiplication, so a result that does not fit in
/// `u64` overflows.
pub fn binary_exponentiation(mut n: u64, mut p: u32) -> u64 {
    let mut result_pow: u64 = 1;
    while p > 0 {
        if p & 1 == 1 {
            result_pow *= n;
        }
        p >>= 1;
        if p > 0 {
            n *= n;
        }
    }
    result_pow
}
+/// +/// This function computes the binomial coefficient C(n, k) using BigInt +/// for arbitrary precision arithmetic. +/// +/// Formula: +/// C(n, k) = n! / (k! * (n - k)!) +/// +/// Reference: +/// [Binomial Coefficient - Wikipedia](https://en.wikipedia.org/wiki/Binomial_coefficient) +/// +/// # Arguments +/// +/// * `n` - The total number of items. +/// * `k` - The number of items to choose from `n`. +/// +/// # Returns +/// +/// Returns the binomial coefficient C(n, k) as a BigInt. +pub fn binom(n: u64, k: u64) -> BigInt { + let mut res = BigInt::from_u64(1).unwrap(); + for i in 0..k { + res = (res * BigInt::from_u64(n - i).unwrap()) / BigInt::from_u64(i + 1).unwrap(); + } + res +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_binom_5_2() { + assert_eq!(binom(5, 2), BigInt::from(10)); + } + + #[test] + fn test_binom_10_5() { + assert_eq!(binom(10, 5), BigInt::from(252)); + } + + #[test] + fn test_binom_0_0() { + assert_eq!(binom(0, 0), BigInt::from(1)); + } + + #[test] + fn test_binom_large_n_small_k() { + assert_eq!(binom(1000, 2), BigInt::from(499500)); + } + + #[test] + fn test_binom_random_1() { + // Random test case 1 + assert_eq!(binom(7, 4), BigInt::from(35)); + } + + #[test] + fn test_binom_random_2() { + // Random test case 2 + assert_eq!(binom(12, 3), BigInt::from(220)); + } + + #[test] + fn test_binom_random_3() { + // Random test case 3 + assert_eq!(binom(20, 10), BigInt::from(184_756)); + } +} diff --git a/src/math/catalan_numbers.rs b/src/math/catalan_numbers.rs new file mode 100644 index 00000000000..4aceec3289e --- /dev/null +++ b/src/math/catalan_numbers.rs @@ -0,0 +1,59 @@ +// Introduction to Catalan Numbers: +// Catalan numbers are a sequence of natural numbers with many applications in combinatorial mathematics. +// They are named after the Belgian mathematician Eugène Charles Catalan, who contributed to their study. 
/// Maps `x` to the least integer greater than or equal to `x`.
///
/// Source: https://en.wikipedia.org/wiki/Floor_and_ceiling_functions
///
/// NaN, the infinities and every value whose magnitude is at least 2^52 are
/// returned unchanged. At that magnitude an `f64` has no fractional bits left,
/// so the value is already an integer. Everything else is truncated through
/// `i64`, which cannot saturate for magnitudes below 2^52.
pub fn ceil(x: f64) -> f64 {
    // 2^52: from here on every finite f64 is integral.
    const INTEGRAL_FROM: f64 = 4_503_599_627_370_496.0;

    if !x.is_finite() || x.abs() >= INTEGRAL_FROM {
        return x;
    }

    let x_rounded_towards_zero = x as i64 as f64;
    // Truncation already rounds negative values up; positive non-integers
    // need one more step.
    if x < 0. || x_rounded_towards_zero == x {
        x_rounded_towards_zero
    } else {
        x_rounded_towards_zero + 1_f64
    }
}
/// Calculates the number of combinations of `k` elements from a set of `n`
/// elements.
///
/// Returns `0` when `k > n`.
///
/// The running product is kept in `i128`, so the multiplication that precedes
/// each exact division cannot overflow while the coefficient itself still fits
/// in `i64`.
///
/// # Panics
///
/// Panics with "Please insert positive values" if `n` or `k` is negative, and
/// panics if the result does not fit in `i64`.
pub fn combinations(n: i64, k: i64) -> i64 {
    if n < 0 || k < 0 {
        panic!("Please insert positive values");
    }
    if k > n {
        return 0;
    }

    // C(n, k) == C(n, n - k); the smaller index needs fewer steps and keeps
    // every intermediate coefficient no larger than the final one.
    let steps = k.min(n - k);

    let mut res: i128 = 1;
    for i in 0..steps {
        // C(n, i + 1) = C(n, i) * (n - i) / (i + 1); the division is exact.
        res = res * i128::from(n - i) / i128::from(i + 1);
        assert!(
            res <= i128::from(i64::MAX),
            "combinations({}, {}) does not fit in i64",
            n,
            k
        );
    }

    res as i64
}
{ + assert_eq!(combinations(10, 5), 252); + } + + // Test case for combinations(6, 3) + #[test] + fn test_combinations_6_choose_3() { + assert_eq!(combinations(6, 3), 20); + } + + // Test case for combinations(20, 5) + #[test] + fn test_combinations_20_choose_5() { + assert_eq!(combinations(20, 5), 15504); + } + + // Test case for invalid input (negative values) + #[test] + #[should_panic(expected = "Please insert positive values")] + fn test_combinations_invalid_input() { + combinations(-5, 10); + } +} diff --git a/src/math/cross_entropy_loss.rs b/src/math/cross_entropy_loss.rs new file mode 100644 index 00000000000..1dbf48022eb --- /dev/null +++ b/src/math/cross_entropy_loss.rs @@ -0,0 +1,40 @@ +//! # Cross-Entropy Loss Function +//! +//! The `cross_entropy_loss` function calculates the cross-entropy loss between the actual and predicted probability distributions. +//! +//! Cross-entropy loss is commonly used in machine learning and deep learning to measure the dissimilarity between two probability distributions. It is often used in classification problems. +//! +//! ## Formula +//! +//! For a pair of actual and predicted probability distributions represented as vectors `actual` and `predicted`, the cross-entropy loss is calculated as: +//! +//! `L = -Σ(actual[i] * ln(predicted[i]))` for all `i` in the range of the vectors +//! +//! Where `ln` is the natural logarithm function, and `Σ` denotes the summation over all elements of the vectors. +//! +//! ## Cross-Entropy Loss Function Implementation +//! +//! This implementation takes two references to vectors of f64 values, `actual` and `predicted`, and returns the cross-entropy loss between them. +//! 
+pub fn cross_entropy_loss(actual: &[f64], predicted: &[f64]) -> f64 { + let mut loss: Vec = Vec::new(); + for (a, p) in actual.iter().zip(predicted.iter()) { + loss.push(-a * p.ln()); + } + loss.iter().sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cross_entropy_loss() { + let test_vector_actual = vec![0., 1., 0., 0., 0., 0.]; + let test_vector = vec![0.1, 0.7, 0.1, 0.05, 0.05, 0.1]; + assert_eq!( + cross_entropy_loss(&test_vector_actual, &test_vector), + 0.35667494393873245 + ); + } +} diff --git a/src/math/decimal_to_fraction.rs b/src/math/decimal_to_fraction.rs new file mode 100644 index 00000000000..5963562b59c --- /dev/null +++ b/src/math/decimal_to_fraction.rs @@ -0,0 +1,67 @@ +pub fn decimal_to_fraction(decimal: f64) -> (i64, i64) { + // Calculate the fractional part of the decimal number + let fractional_part = decimal - decimal.floor(); + + // If the fractional part is zero, the number is already an integer + if fractional_part == 0.0 { + (decimal as i64, 1) + } else { + // Calculate the number of decimal places in the fractional part + let number_of_frac_digits = decimal.to_string().split('.').nth(1).unwrap_or("").len(); + + // Calculate the numerator and denominator using integer multiplication + let numerator = (decimal * 10f64.powi(number_of_frac_digits as i32)) as i64; + let denominator = 10i64.pow(number_of_frac_digits as u32); + + // Find the greatest common divisor (GCD) using Euclid's algorithm + let mut divisor = denominator; + let mut dividend = numerator; + while divisor != 0 { + let r = dividend % divisor; + dividend = divisor; + divisor = r; + } + + // Reduce the fraction by dividing both numerator and denominator by the GCD + let gcd = dividend.abs(); + let numerator = numerator / gcd; + let denominator = denominator / gcd; + + (numerator, denominator) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decimal_to_fraction_1() { + assert_eq!(decimal_to_fraction(2.0), (2, 1)); + } + + 
#[test] + fn test_decimal_to_fraction_2() { + assert_eq!(decimal_to_fraction(89.45), (1789, 20)); + } + + #[test] + fn test_decimal_to_fraction_3() { + assert_eq!(decimal_to_fraction(67.), (67, 1)); + } + + #[test] + fn test_decimal_to_fraction_4() { + assert_eq!(decimal_to_fraction(45.2), (226, 5)); + } + + #[test] + fn test_decimal_to_fraction_5() { + assert_eq!(decimal_to_fraction(1.5), (3, 2)); + } + + #[test] + fn test_decimal_to_fraction_6() { + assert_eq!(decimal_to_fraction(6.25), (25, 4)); + } +} diff --git a/src/math/doomsday.rs b/src/math/doomsday.rs new file mode 100644 index 00000000000..3d43f2666bd --- /dev/null +++ b/src/math/doomsday.rs @@ -0,0 +1,36 @@ +const T: [i32; 12] = [0, 3, 2, 5, 0, 3, 5, 1, 4, 6, 2, 4]; + +pub fn doomsday(y: i32, m: i32, d: i32) -> i32 { + let y = if m < 3 { y - 1 } else { y }; + (y + y / 4 - y / 100 + y / 400 + T[(m - 1) as usize] + d) % 7 +} + +pub fn get_week_day(y: i32, m: i32, d: i32) -> String { + let day = doomsday(y, m, d); + let day_str = match day { + 0 => "Sunday", + 1 => "Monday", + 2 => "Tuesday", + 3 => "Wednesday", + 4 => "Thursday", + 5 => "Friday", + 6 => "Saturday", + _ => "Unknown", + }; + + day_str.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn doomsday_test() { + assert_eq!(get_week_day(1990, 3, 21), "Wednesday"); + assert_eq!(get_week_day(2000, 8, 24), "Thursday"); + assert_eq!(get_week_day(2000, 10, 13), "Friday"); + assert_eq!(get_week_day(2001, 4, 18), "Wednesday"); + assert_eq!(get_week_day(2002, 3, 19), "Tuesday"); + } +} diff --git a/src/math/elliptic_curve.rs b/src/math/elliptic_curve.rs new file mode 100644 index 00000000000..641b759dc8b --- /dev/null +++ b/src/math/elliptic_curve.rs @@ -0,0 +1,405 @@ +use std::collections::HashSet; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::ops::{Add, Neg, Sub}; + +use crate::math::field::{Field, PrimeField}; +use crate::math::quadratic_residue::legendre_symbol; + +/// Elliptic curve defined by `y^2 = x^3 + Ax + B` 
over a prime field `F` of +/// characteristic != 2, 3 +/// +/// The coefficients of the elliptic curve are the constant parameters `A` and `B`. +/// +/// Points form an abelian group with the neutral element [`EllipticCurve::infinity`]. The points +/// are represented via affine coordinates ([`EllipticCurve::new`]) except for the points +/// at infinity ([`EllipticCurve::infinity`]). +/// +/// # Example +/// +/// ``` +/// use the_algorithms_rust::math::{EllipticCurve, PrimeField}; +/// type E = EllipticCurve, 1, 0>; +/// let P = E::new(0, 0).expect("not on curve E"); +/// assert_eq!(P + P, E::infinity()); +/// ``` +#[derive(Clone, Copy)] +pub struct EllipticCurve { + infinity: bool, + x: F, + y: F, +} + +impl EllipticCurve { + /// Point at infinity also the neutral element of the group + pub fn infinity() -> Self { + Self::check_invariants(); + Self { + infinity: true, + x: F::ZERO, + y: F::ZERO, + } + } + + /// Affine point + /// + /// + /// Return `None` if the coordinates are not on the curve + pub fn new(x: impl Into, y: impl Into) -> Option { + Self::check_invariants(); + let x = x.into(); + let y = y.into(); + if Self::contains(x, y) { + Some(Self { + infinity: false, + x, + y, + }) + } else { + None + } + } + + /// Return `true` if this is the point at infinity + pub fn is_infinity(&self) -> bool { + self.infinity + } + + /// The affine x-coordinate of the point + pub fn x(&self) -> &F { + &self.x + } + + /// The affine y-coordinate of the point + pub fn y(&self) -> &F { + &self.y + } + + /// The discrimant of the elliptic curve + pub const fn discriminant() -> i64 { + // Note: we can't return an element of F here, because it is not + // possible to declare a trait function as const (cf. 
+ // ) + (-16 * (4 * A * A * A + 27 * B * B)) % (F::CHARACTERISTIC as i64) + } + + fn contains(x: F, y: F) -> bool { + y * y == x * x * x + x.integer_mul(A) + F::ONE.integer_mul(B) + } + + const fn check_invariants() { + assert!(F::CHARACTERISTIC != 2); + assert!(F::CHARACTERISTIC != 3); + assert!(Self::discriminant() != 0); + } +} + +/// Elliptic curve methods over a prime field +impl EllipticCurve, A, B> { + /// Naive calculation of points via enumeration + // TODO: Implement via generators + pub fn points() -> impl Iterator { + std::iter::once(Self::infinity()).chain( + PrimeField::elements() + .flat_map(|x| PrimeField::elements().filter_map(move |y| Self::new(x, y))), + ) + } + + /// Number of points on the elliptic curve over `F`, that is, `#E(F)` + pub fn cardinality() -> usize { + // TODO: implement counting for big P + Self::cardinality_counted_legendre() + } + + /// Number of points on the elliptic curve over `F`, that is, `#E(F)` + /// + /// We simply count the number of points for each x coordinate and sum them up. + /// For that, we first precompute the table of all squares in `F`. + /// + /// Time complexity: O(P)
+ /// Space complexity: O(P) + /// + /// Only fast for small fields. + pub fn cardinality_counted_table() -> usize { + let squares: HashSet<_> = PrimeField::

::elements().map(|x| x * x).collect(); + 1 + PrimeField::elements() + .map(|x| { + let y_square = x * x * x + x.integer_mul(A) + PrimeField::from_integer(B); + if y_square == PrimeField::ZERO { + 1 + } else if squares.contains(&y_square) { + 2 + } else { + 0 + } + }) + .sum::() + } + + /// Number of points on the elliptic curve over `F`, that is, `#E(F)` + /// + /// We count the number of points for each x coordinate by using the [Legendre symbol] _(X | + /// P)_: + /// + /// _1 + (x^3 + Ax + B | P),_ + /// + /// The total number of points is then: + /// + /// _#E(F) = 1 + P + Σ_x (x^3 + Ax + B | P)_ for _x_ in _F_. + /// + /// Time complexity: O(P)
+ /// Space complexity: O(1) + /// + /// Only fast for small fields. + /// + /// [Legendre symbol]: https://en.wikipedia.org/wiki/Legendre_symbol + pub fn cardinality_counted_legendre() -> usize { + let cardinality: i64 = 1 + + P as i64 + + PrimeField::

::elements() + .map(|x| { + let y_square = x * x * x + x.integer_mul(A) + PrimeField::from_integer(B); + let y_square_int = y_square.to_integer(); + legendre_symbol(y_square_int, P) + }) + .sum::(); + cardinality + .try_into() + .expect("invalid legendre cardinality") + } +} + +/// Group law +impl Add for EllipticCurve { + type Output = Self; + + fn add(self, p: Self) -> Self::Output { + if self.infinity { + p + } else if p.infinity { + self + } else if self.x == p.x && self.y == -p.y { + // mirrored + Self::infinity() + } else { + let slope = if self.x != p.x { + (self.y - p.y) / (self.x - p.x) + } else { + ((self.x * self.x).integer_mul(3) + F::from_integer(A)) / self.y.integer_mul(2) + }; + let x = slope * slope - self.x - p.x; + let y = -self.y + slope * (self.x - x); + Self::new(x, y).expect("elliptic curve group law failed") + } + } +} + +/// Inverse +impl Neg for EllipticCurve { + type Output = Self; + + fn neg(self) -> Self::Output { + if self.infinity { + self + } else { + Self::new(self.x, -self.y).expect("elliptic curves are x-axis symmetric") + } + } +} + +/// Difference +impl Sub for EllipticCurve { + type Output = Self; + + fn sub(self, p: Self) -> Self::Output { + self + (-p) + } +} + +/// Debug representation via projective coordinates +impl fmt::Debug for EllipticCurve { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.infinity { + f.write_str("(0:0:1)") + } else { + write!(f, "({:?}:{:?}:1)", self.x, self.y) + } + } +} + +/// Equality of the elliptic curve points (short-circuit at infinity) +impl PartialEq for EllipticCurve { + fn eq(&self, other: &Self) -> bool { + (self.infinity && other.infinity) + || (self.infinity == other.infinity && self.x == other.x && self.y == other.y) + } +} + +impl Eq for EllipticCurve {} + +impl Hash for EllipticCurve { + fn hash(&self, state: &mut H) { + if self.infinity { + state.write_u8(1); + F::ZERO.hash(state); + F::ZERO.hash(state); + } else { + state.write_u8(0); + self.x.hash(state); + 
self.y.hash(state); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::time::Instant; + + use super::*; + + #[test] + #[should_panic] + fn test_char_2_panic() { + EllipticCurve::, -1, 1>::infinity(); + } + + #[test] + #[should_panic] + fn test_char_3_panic() { + EllipticCurve::, -1, 1>::infinity(); + } + + #[test] + #[should_panic] + fn test_singular_panic() { + EllipticCurve::, 0, 0>::infinity(); + } + + #[test] + fn e_5_1_0_group_table() { + type F = PrimeField<5>; + type E = EllipticCurve; + + assert_eq!(E::points().count(), 4); + let [a, b, c, d] = [ + E::new(0, 0).unwrap(), + E::infinity(), + E::new(2, 0).unwrap(), + E::new(3, 0).unwrap(), + ]; + + assert_eq!(a + a, b); + assert_eq!(a + b, a); + assert_eq!(a + c, d); + assert_eq!(a + d, c); + assert_eq!(b + a, a); + assert_eq!(b + b, b); + assert_eq!(b + c, c); + assert_eq!(b + d, d); + assert_eq!(c + a, d); + assert_eq!(c + b, c); + assert_eq!(c + c, b); + assert_eq!(c + d, a); + assert_eq!(d + a, c); + assert_eq!(d + b, d); + assert_eq!(d + c, a); + assert_eq!(d + d, b); + } + + #[test] + fn group_law() { + fn test() { + type E = EllipticCurve, 1, 0>; + + let o = E::

::infinity(); + assert_eq!(-o, o); + + let points: Vec<_> = E::points().collect(); + for &p in &points { + assert_eq!(p + (-p), o); // inverse + assert_eq!((-p) + p, o); // inverse + assert_eq!(p - p, o); //inverse + assert_eq!(p + o, p); // neutral + assert_eq!(o + p, p); //neutral + + for &q in &points { + assert_eq!(p + q, q + p); // commutativity + + // associativity + for &s in &points { + assert_eq!((p + q) + s, p + (q + s)); + } + } + } + } + test::<5>(); + test::<7>(); + test::<11>(); + test::<13>(); + test::<17>(); + test::<19>(); + test::<23>(); + } + + #[test] + fn cardinality() { + fn test(expected: usize) { + type E = EllipticCurve, 1, 0>; + assert_eq!(E::

::cardinality(), expected); + assert_eq!(E::

::cardinality_counted_table(), expected); + assert_eq!(E::

::cardinality_counted_legendre(), expected); + } + test::<5>(4); + test::<7>(8); + test::<11>(12); + test::<13>(20); + test::<17>(16); + test::<19>(20); + test::<23>(24); + } + + #[test] + #[ignore = "slow test for measuring time"] + fn cardinality_perf() { + const P: u64 = 1000003; + type E = EllipticCurve, 1, 0>; + const EXPECTED: usize = 1000004; + + let now = Instant::now(); + assert_eq!(E::cardinality_counted_table(), EXPECTED); + println!("cardinality_counted_table : {:?}", now.elapsed()); + let now = Instant::now(); + assert_eq!(E::cardinality_counted_legendre(), EXPECTED); + println!("cardinality_counted_legendre : {:?}", now.elapsed()); + } + + #[test] + #[ignore = "slow test showing that cadinality is not yet feasible to compute for a large prime"] + fn cardinality_large_prime() { + const P: u64 = 2_u64.pow(63) - 25; // largest prime fitting into i64 + type E = EllipticCurve, 1, 0>; + const EXPECTED: usize = 9223372041295506260; + + let now = Instant::now(); + assert_eq!(E::cardinality(), EXPECTED); + println!("cardinality: {:?}", now.elapsed()); + } + + #[test] + fn test_points() { + type F = PrimeField<5>; + type E = EllipticCurve; + + let points: HashSet<_> = E::points().collect(); + let expected: HashSet<_> = [ + E::infinity(), + E::new(0, 0).unwrap(), + E::new(2, 0).unwrap(), + E::new(3, 0).unwrap(), + ] + .into_iter() + .collect(); + assert_eq!(points, expected); + } +} diff --git a/src/math/euclidean_distance.rs b/src/math/euclidean_distance.rs new file mode 100644 index 00000000000..0be7459f042 --- /dev/null +++ b/src/math/euclidean_distance.rs @@ -0,0 +1,41 @@ +// Author : cyrixninja +// Calculate the Euclidean distance between two vectors +// Wikipedia : https://en.wikipedia.org/wiki/Euclidean_distance + +pub fn euclidean_distance(vector_1: &Vector, vector_2: &Vector) -> f64 { + // Calculate the Euclidean distance using the provided vectors. 
+ let squared_sum: f64 = vector_1 + .iter() + .zip(vector_2.iter()) + .map(|(&a, &b)| (a - b).powi(2)) + .sum(); + + squared_sum.sqrt() +} + +type Vector = Vec; + +#[cfg(test)] +mod tests { + use super::*; + + // Define a test function for the euclidean_distance function. + #[test] + fn test_euclidean_distance() { + // First test case: 2D vectors + let vec1_2d = vec![1.0, 2.0]; + let vec2_2d = vec![4.0, 6.0]; + + // Calculate the Euclidean distance + let result_2d = euclidean_distance(&vec1_2d, &vec2_2d); + assert_eq!(result_2d, 5.0); + + // Second test case: 4D vectors + let vec1_4d = vec![1.0, 2.0, 3.0, 4.0]; + let vec2_4d = vec![5.0, 6.0, 7.0, 8.0]; + + // Calculate the Euclidean distance + let result_4d = euclidean_distance(&vec1_4d, &vec2_4d); + assert_eq!(result_4d, 8.0); + } +} diff --git a/src/math/exponential_linear_unit.rs b/src/math/exponential_linear_unit.rs new file mode 100644 index 00000000000..f6143c97881 --- /dev/null +++ b/src/math/exponential_linear_unit.rs @@ -0,0 +1,60 @@ +//! # Exponential Linear Unit (ELU) Function +//! +//! The `exponential_linear_unit` function computes the Exponential Linear Unit (ELU) values of a given vector +//! of f64 numbers with a specified alpha parameter. +//! +//! The ELU activation function is commonly used in neural networks as an alternative to the Leaky ReLU function. +//! It introduces a small negative slope (controlled by the alpha parameter) for the negative input values and has +//! an exponential growth for positive values, which can help mitigate the vanishing gradient problem. +//! +//! ## Formula +//! +//! For a given input vector `x` and an alpha parameter `alpha`, the ELU function computes the output +//! `y` as follows: +//! +//! `y_i = { x_i if x_i >= 0, alpha * (e^x_i - 1) if x_i < 0 }` +//! +//! Where `e` is the mathematical constant (approximately 2.71828). +//! +//! ## Exponential Linear Unit (ELU) Function Implementation +//! +//! 
This implementation takes a reference to a vector of f64 values and an alpha parameter, and returns a new +//! vector with the ELU transformation applied to each element. The input vector is not altered. +//! + +use std::f64::consts::E; + +pub fn exponential_linear_unit(vector: &Vec, alpha: f64) -> Vec { + let mut _vector = vector.to_owned(); + + for value in &mut _vector { + if value < &mut 0. { + *value *= alpha * (E.powf(*value) - 1.); + } + } + + _vector +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exponential_linear_unit() { + let test_vector = vec![-10., 2., -3., 4., -5., 10., 0.05]; + let alpha = 0.01; + assert_eq!( + exponential_linear_unit(&test_vector, alpha), + vec![ + 0.09999546000702375, + 2.0, + 0.028506387948964082, + 4.0, + 0.049663102650045726, + 10.0, + 0.05 + ] + ); + } +} diff --git a/src/math/extended_euclidean_algorithm.rs b/src/math/extended_euclidean_algorithm.rs new file mode 100644 index 00000000000..ff9d87ad552 --- /dev/null +++ b/src/math/extended_euclidean_algorithm.rs @@ -0,0 +1,37 @@ +fn update_step(a: &mut i32, old_a: &mut i32, quotient: i32) { + let temp = *a; + *a = *old_a - quotient * temp; + *old_a = temp; +} + +pub fn extended_euclidean_algorithm(a: i32, b: i32) -> (i32, i32, i32) { + let (mut old_r, mut rem) = (a, b); + let (mut old_s, mut coeff_s) = (1, 0); + let (mut old_t, mut coeff_t) = (0, 1); + + while rem != 0 { + let quotient = old_r / rem; + + update_step(&mut rem, &mut old_r, quotient); + update_step(&mut coeff_s, &mut old_s, quotient); + update_step(&mut coeff_t, &mut old_t, quotient); + } + + (old_r, old_s, old_t) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + assert_eq!(extended_euclidean_algorithm(101, 13), (1, 4, -31)); + assert_eq!(extended_euclidean_algorithm(123, 19), (1, -2, 13)); + assert_eq!(extended_euclidean_algorithm(25, 36), (1, 13, -9)); + assert_eq!(extended_euclidean_algorithm(69, 54), (3, -7, 9)); + assert_eq!(extended_euclidean_algorithm(55, 
79), (1, 23, -16)); + assert_eq!(extended_euclidean_algorithm(33, 44), (11, -1, 1)); + assert_eq!(extended_euclidean_algorithm(50, 70), (10, 3, -2)); + } +} diff --git a/src/math/factorial.rs b/src/math/factorial.rs new file mode 100644 index 00000000000..b6fbc831450 --- /dev/null +++ b/src/math/factorial.rs @@ -0,0 +1,68 @@ +use num_bigint::BigUint; +use num_traits::One; +#[allow(unused_imports)] +use std::str::FromStr; + +pub fn factorial(number: u64) -> u64 { + // Base cases: 0! and 1! are both equal to 1 + if number == 0 || number == 1 { + 1 + } else { + // Calculate factorial using the product of the range from 2 to the given number (inclusive) + (2..=number).product() + } +} + +pub fn factorial_recursive(n: u64) -> u64 { + // Base cases: 0! and 1! are both equal to 1 + if n == 0 || n == 1 { + 1 + } else { + // Calculate factorial recursively by multiplying the current number with factorial of (n - 1) + n * factorial_recursive(n - 1) + } +} + +pub fn factorial_bigmath(num: u32) -> BigUint { + let mut result: BigUint = One::one(); + for i in 1..=num { + result *= i; + } + result +} + +// Module for tests +#[cfg(test)] +mod tests { + use super::*; + + // Test cases for the iterative factorial function + #[test] + fn test_factorial() { + assert_eq!(factorial(0), 1); + assert_eq!(factorial(1), 1); + assert_eq!(factorial(6), 720); + assert_eq!(factorial(10), 3628800); + assert_eq!(factorial(20), 2432902008176640000); + } + + // Test cases for the recursive factorial function + #[test] + fn test_factorial_recursive() { + assert_eq!(factorial_recursive(0), 1); + assert_eq!(factorial_recursive(1), 1); + assert_eq!(factorial_recursive(6), 720); + assert_eq!(factorial_recursive(10), 3628800); + assert_eq!(factorial_recursive(20), 2432902008176640000); + } + + #[test] + fn basic_factorial() { + assert_eq!(factorial_bigmath(10), BigUint::from_str("3628800").unwrap()); + assert_eq!( + factorial_bigmath(50), + 
BigUint::from_str("30414093201713378043612608166064768844377641568960512000000000000") + .unwrap() + ); + } +} diff --git a/src/math/factors.rs b/src/math/factors.rs new file mode 100644 index 00000000000..5131642dffa --- /dev/null +++ b/src/math/factors.rs @@ -0,0 +1,46 @@ +/// Factors are natural numbers which can divide a given natural number to give a remainder of zero +/// Hence 1, 2, 3 and 6 are all factors of 6, as they divide the number 6 completely, +/// leaving no remainder. +/// This function is to list out all the factors of a given number 'n' + +pub fn factors(number: u64) -> Vec { + let mut factors: Vec = Vec::new(); + + for i in 1..=((number as f64).sqrt() as u64) { + if number % i == 0 { + factors.push(i); + if i != number / i { + factors.push(number / i); + } + } + } + + factors.sort(); + factors +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prime_number() { + assert_eq!(vec![1, 59], factors(59)); + } + + #[test] + fn highly_composite_number() { + assert_eq!( + vec![ + 1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 18, 20, 24, 30, 36, 40, 45, 60, 72, 90, 120, + 180, 360 + ], + factors(360) + ); + } + + #[test] + fn composite_number() { + assert_eq!(vec![1, 3, 23, 69], factors(69)); + } +} diff --git a/src/math/fast_fourier_transform.rs b/src/math/fast_fourier_transform.rs new file mode 100644 index 00000000000..6ed81e7db6a --- /dev/null +++ b/src/math/fast_fourier_transform.rs @@ -0,0 +1,220 @@ +use std::ops::{Add, Mul, MulAssign, Sub}; + +// f64 complex +#[derive(Clone, Copy, Debug)] +pub struct Complex64 { + pub re: f64, + pub im: f64, +} + +impl Complex64 { + #[inline] + pub fn new(re: f64, im: f64) -> Self { + Self { re, im } + } + + #[inline] + pub fn square_norm(&self) -> f64 { + self.re * self.re + self.im * self.im + } + + #[inline] + pub fn norm(&self) -> f64 { + self.square_norm().sqrt() + } + + #[inline] + pub fn inverse(&self) -> Complex64 { + let nrm = self.square_norm(); + Complex64 { + re: self.re / nrm, + im: -self.im / nrm, 
+ } + } +} + +impl Default for Complex64 { + #[inline] + fn default() -> Self { + Self { re: 0.0, im: 0.0 } + } +} + +impl Add for Complex64 { + type Output = Complex64; + + #[inline] + fn add(self, other: Complex64) -> Complex64 { + Complex64 { + re: self.re + other.re, + im: self.im + other.im, + } + } +} + +impl Sub for Complex64 { + type Output = Complex64; + + #[inline] + fn sub(self, other: Complex64) -> Complex64 { + Complex64 { + re: self.re - other.re, + im: self.im - other.im, + } + } +} + +impl Mul for Complex64 { + type Output = Complex64; + + #[inline] + fn mul(self, other: Complex64) -> Complex64 { + Complex64 { + re: self.re * other.re - self.im * other.im, + im: self.re * other.im + self.im * other.re, + } + } +} + +impl MulAssign for Complex64 { + #[inline] + fn mul_assign(&mut self, other: Complex64) { + let tmp = self.re * other.im + self.im * other.re; + self.re = self.re * other.re - self.im * other.im; + self.im = tmp; + } +} + +pub fn fast_fourier_transform_input_permutation(length: usize) -> Vec { + let mut result = Vec::new(); + result.reserve_exact(length); + for i in 0..length { + result.push(i); + } + let mut reverse = 0_usize; + let mut position = 1_usize; + while position < length { + let mut bit = length >> 1; + while bit & reverse != 0 { + reverse ^= bit; + bit >>= 1; + } + reverse ^= bit; + // This is equivalent to adding 1 to a reversed number + if position < reverse { + // Only swap each element once + result.swap(position, reverse); + } + position += 1; + } + result +} + +pub fn fast_fourier_transform(input: &[f64], input_permutation: &[usize]) -> Vec { + let n = input.len(); + let mut result = Vec::new(); + result.reserve_exact(n); + for position in input_permutation { + result.push(Complex64::new(input[*position], 0.0)); + } + let mut segment_length = 1_usize; + while segment_length < n { + segment_length <<= 1; + let angle: f64 = std::f64::consts::TAU / segment_length as f64; + let w_len = Complex64::new(angle.cos(), 
angle.sin()); + for segment_start in (0..n).step_by(segment_length) { + let mut w = Complex64::new(1.0, 0.0); + for position in segment_start..(segment_start + segment_length / 2) { + let a = result[position]; + let b = result[position + segment_length / 2] * w; + result[position] = a + b; + result[position + segment_length / 2] = a - b; + w *= w_len; + } + } + } + result +} + +pub fn inverse_fast_fourier_transform( + input: &[Complex64], + input_permutation: &[usize], +) -> Vec { + let n = input.len(); + let mut result = Vec::new(); + result.reserve_exact(n); + for position in input_permutation { + result.push(input[*position]); + } + let mut segment_length = 1_usize; + while segment_length < n { + segment_length <<= 1; + let angle: f64 = -std::f64::consts::TAU / segment_length as f64; + let w_len = Complex64::new(angle.cos(), angle.sin()); + for segment_start in (0..n).step_by(segment_length) { + let mut w = Complex64::new(1.0, 0.0); + for position in segment_start..(segment_start + segment_length / 2) { + let a = result[position]; + let b = result[position + segment_length / 2] * w; + result[position] = a + b; + result[position + segment_length / 2] = a - b; + w *= w_len; + } + } + } + let scale = 1.0 / n as f64; + result.iter().map(|x| x.re * scale).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + fn almost_equal(a: f64, b: f64, epsilon: f64) -> bool { + (a - b).abs() < epsilon + } + + const EPSILON: f64 = 1e-6; + + #[test] + fn small_polynomial_returns_self() { + let polynomial = vec![1.0f64, 1.0, 0.0, 2.5]; + let permutation = fast_fourier_transform_input_permutation(polynomial.len()); + let fft = fast_fourier_transform(&polynomial, &permutation); + let ifft = inverse_fast_fourier_transform(&fft, &permutation); + for (x, y) in ifft.iter().zip(polynomial.iter()) { + assert!(almost_equal(*x, *y, EPSILON)); + } + } + + #[test] + fn square_small_polynomial() { + let mut polynomial = vec![1.0f64, 1.0, 0.0, 2.0]; + polynomial.append(&mut vec![0.0; 4]); + 
let permutation = fast_fourier_transform_input_permutation(polynomial.len()); + let mut fft = fast_fourier_transform(&polynomial, &permutation); + fft.iter_mut().for_each(|num| *num *= *num); + let ifft = inverse_fast_fourier_transform(&fft, &permutation); + let expected = [1.0, 2.0, 1.0, 4.0, 4.0, 0.0, 4.0, 0.0, 0.0]; + for (x, y) in ifft.iter().zip(expected.iter()) { + assert!(almost_equal(*x, *y, EPSILON)); + } + } + + #[test] + #[ignore] + fn square_big_polynomial() { + // This test case takes ~1050ms on my machine in unoptimized mode, + // but it takes ~70ms in release mode. + let n = 1 << 17; // ~100_000 + let mut polynomial = vec![1.0f64; n]; + polynomial.append(&mut vec![0.0f64; n]); + let permutation = fast_fourier_transform_input_permutation(polynomial.len()); + let mut fft = fast_fourier_transform(&polynomial, &permutation); + fft.iter_mut().for_each(|num| *num *= *num); + let ifft = inverse_fast_fourier_transform(&fft, &permutation); + let expected = (0..((n << 1) - 1)).map(|i| std::cmp::min(i + 1, (n << 1) - 1 - i) as f64); + for (&x, y) in ifft.iter().zip(expected) { + assert!(almost_equal(x, y, EPSILON)); + } + } +} diff --git a/src/math/fast_power.rs b/src/math/fast_power.rs new file mode 100644 index 00000000000..c76466d0c4e --- /dev/null +++ b/src/math/fast_power.rs @@ -0,0 +1,29 @@ +/// fast_power returns the result of base^power mod modulus +pub fn fast_power(mut base: usize, mut power: usize, modulus: usize) -> usize { + assert!(base >= 1); + + let mut res = 1; + while power > 0 { + if power & 1 == 1 { + res = (res * base) % modulus; + } + base = (base * base) % modulus; + power >>= 1; + } + res +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + const MOD: usize = 1000000007; + assert_eq!(fast_power(2, 1, MOD), 2); + assert_eq!(fast_power(2, 2, MOD), 4); + assert_eq!(fast_power(2, 4, MOD), 16); + assert_eq!(fast_power(3, 4, MOD), 81); + assert_eq!(fast_power(2, 100, MOD), 976371285); + } +} diff --git 
a/src/math/faster_perfect_numbers.rs b/src/math/faster_perfect_numbers.rs new file mode 100644 index 00000000000..28a0f857f71 --- /dev/null +++ b/src/math/faster_perfect_numbers.rs @@ -0,0 +1,40 @@ +use super::{mersenne_primes::is_mersenne_prime, prime_numbers::prime_numbers}; +use std::convert::TryInto; + +/* + Generates a list of perfect numbers till `num` using the Lucas Lehmer test algorithm. + url : https://en.wikipedia.org/wiki/Lucas%E2%80%93Lehmer_primality_test +*/ +pub fn generate_perfect_numbers(num: usize) -> Vec { + let mut results = Vec::new(); + let prime_limit = get_prime_limit(num); + + for i in prime_numbers(prime_limit).iter() { + let prime = *i; + if is_mersenne_prime(prime) { + results.push( + (2_usize.pow(prime.try_into().unwrap()) - 1) + * (2_usize.pow((prime - 1).try_into().unwrap())), + ); + } + } + results.into_iter().filter(|x| *x <= num).collect() +} + +// Gets an approximate limit for the generate_perfect_numbers function +fn get_prime_limit(num: usize) -> usize { + (((num * 8 + 1) as f64).log2() as usize) / 2_usize +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn perfect_numbers_till_n() { + let n = 335564540; + assert_eq!(generate_perfect_numbers(n), [6, 28, 496, 8128, 33550336]); + assert_eq!(generate_perfect_numbers(40), [6, 28]); + assert_eq!(generate_perfect_numbers(0), []); + } +} diff --git a/src/math/field.rs b/src/math/field.rs new file mode 100644 index 00000000000..9fb26965cd6 --- /dev/null +++ b/src/math/field.rs @@ -0,0 +1,333 @@ +use core::fmt; +use std::hash::{Hash, Hasher}; +use std::ops::{Add, Div, Mul, Neg, Sub}; + +/// A field +/// +/// +pub trait Field: + Neg + + Add + + Sub + + Mul + + Div + + Eq + + Copy + + fmt::Debug +{ + const CHARACTERISTIC: u64; + const ZERO: Self; + const ONE: Self; + + /// Multiplicative inverse + fn inverse(self) -> Self; + + /// Z-mod structure + fn integer_mul(self, a: i64) -> Self; + fn from_integer(a: i64) -> Self { + Self::ONE.integer_mul(a) + } + + /// Iterate over all 
elements in this field + /// + /// The iterator finishes only for finite fields. + type ElementsIter: Iterator; + fn elements() -> Self::ElementsIter; +} + +/// Prime field of order `P`, that is, finite field `GF(P) = ℤ/Pℤ` +/// +/// Only primes `P` <= 2^63 - 25 are supported, because the field elements are represented by `i64`. +// TODO: Extend field implementation for any prime `P` by e.g. using u32 blocks. +#[derive(Clone, Copy)] +pub struct PrimeField { + a: i64, +} + +impl PrimeField

{ + /// Reduces the representation into the range [0, p) + fn reduce(self) -> Self { + let Self { a } = self; + let p: i64 = P.try_into().expect("module not fitting into signed 64 bit"); + let a = a.rem_euclid(p); + assert!(a >= 0); + Self { a } + } + + /// Returns the positive integer in the range [0, p) representing this element + pub fn to_integer(&self) -> u64 { + self.reduce().a as u64 + } +} + +impl From for PrimeField

{ + fn from(a: i64) -> Self { + Self { a } + } +} + +impl PartialEq for PrimeField

{ + fn eq(&self, other: &Self) -> bool { + self.reduce().a == other.reduce().a + } +} + +impl Eq for PrimeField

{} + +impl Neg for PrimeField

{ + type Output = Self; + + fn neg(self) -> Self::Output { + Self { a: -self.a } + } +} + +impl Add for PrimeField

{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + a: self.a.checked_add(rhs.a).unwrap_or_else(|| { + let x = self.reduce(); + let y = rhs.reduce(); + x.a + y.a + }), + } + } +} + +impl Sub for PrimeField

{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self { + a: self.a.checked_sub(rhs.a).unwrap_or_else(|| { + let x = self.reduce(); + let y = rhs.reduce(); + x.a - y.a + }), + } + } +} + +impl Mul for PrimeField

{ + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self { + a: self.a.checked_mul(rhs.a).unwrap_or_else(|| { + let x = self.reduce(); + let y = rhs.reduce(); + x.a * y.a + }), + } + } +} + +impl Div for PrimeField

{ + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } +} + +impl fmt::Debug for PrimeField

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let x = self.reduce(); + write!(f, "{}", x.reduce().a) + } +} + +impl Field for PrimeField

{ + const CHARACTERISTIC: u64 = P; + const ZERO: Self = Self { a: 0 }; + const ONE: Self = Self { a: 1 }; + + fn inverse(self) -> Self { + assert_ne!(self.a, 0); + Self { + a: mod_inverse( + self.a, + P.try_into().expect("module not fitting into signed 64 bit"), + ), + } + } + + fn integer_mul(self, mut n: i64) -> Self { + if n == 0 { + return Self::ZERO; + } + let mut x = self; + if n < 0 { + x = -x; + n = -n; + } + let mut y = Self::ZERO; + while n > 1 { + if n % 2 == 1 { + y = y + x; + n -= 1; + } + x = x + x; + n /= 2; + } + x + y + } + + type ElementsIter = PrimeFieldElementsIter

; + + fn elements() -> Self::ElementsIter { + PrimeFieldElementsIter::default() + } +} + +#[derive(Default)] +pub struct PrimeFieldElementsIter { + x: i64, +} + +impl Iterator for PrimeFieldElementsIter

{ + type Item = PrimeField

; + + fn next(&mut self) -> Option { + if self.x as u64 == P { + None + } else { + let res = PrimeField::from_integer(self.x); + self.x += 1; + Some(res) + } + } +} + +impl Hash for PrimeField

{ + fn hash(&self, state: &mut H) { + let Self { a } = self.reduce(); + state.write_i64(a); + } +} + +// TODO: should we use extended_euclidean_algorithm adjusted to i64? +fn mod_inverse(mut a: i64, mut b: i64) -> i64 { + let mut s = 1; + let mut t = 0; + let step = |x, y, q| (y, x - q * y); + while b != 0 { + let q = a / b; + (a, b) = step(a, b, q); + (s, t) = step(s, t, q); + } + assert!(a == 1 || a == -1); + a * s +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_field_elements() { + fn test() { + let expected: HashSet> = (0..P as i64).map(Into::into).collect(); + for gen in 1..P - 1 { + // every field element != 0 generates the whole field additively + let gen = PrimeField::from(gen as i64); + let mut generated: HashSet> = std::iter::once(gen).collect(); + let mut x = gen; + for _ in 0..P { + x = x + gen; + generated.insert(x); + } + assert_eq!(generated, expected); + } + } + test::<5>(); + test::<7>(); + test::<11>(); + test::<13>(); + test::<17>(); + test::<19>(); + test::<23>(); + test::<71>(); + test::<101>(); + } + + #[test] + fn large_prime_field() { + const P: u64 = 2_u64.pow(63) - 25; // largest prime fitting into i64 + type F = PrimeField

; + let x = F::from(P as i64 - 1); + let y = x.inverse(); + assert_eq!(x * y, F::ONE); + } + + #[test] + fn inverse() { + fn test() { + for x in -7..7 { + let x = PrimeField::

::from(x); + if x != PrimeField::ZERO { + // multiplicative + assert_eq!(x.inverse() * x, PrimeField::ONE); + assert_eq!(x * x.inverse(), PrimeField::ONE); + assert_eq!((x.inverse().a * x.a).rem_euclid(P as i64), 1); + assert_eq!(x / x, PrimeField::ONE); + } + // additive + assert_eq!(x + (-x), PrimeField::ZERO); + assert_eq!((-x) + x, PrimeField::ZERO); + assert_eq!(x - x, PrimeField::ZERO); + } + } + test::<5>(); + test::<7>(); + test::<11>(); + test::<13>(); + test::<17>(); + test::<19>(); + test::<23>(); + test::<71>(); + test::<101>(); + } + + #[test] + fn test_mod_inverse() { + assert_eq!(mod_inverse(-6, 7), 1); + assert_eq!(mod_inverse(-5, 7), -3); + assert_eq!(mod_inverse(-4, 7), -2); + assert_eq!(mod_inverse(-3, 7), 2); + assert_eq!(mod_inverse(-2, 7), 3); + assert_eq!(mod_inverse(-1, 7), -1); + assert_eq!(mod_inverse(1, 7), 1); + assert_eq!(mod_inverse(2, 7), -3); + assert_eq!(mod_inverse(3, 7), -2); + assert_eq!(mod_inverse(4, 7), 2); + assert_eq!(mod_inverse(5, 7), 3); + assert_eq!(mod_inverse(6, 7), -1); + } + + #[test] + fn integer_mul() { + type F = PrimeField<23>; + for x in 0..23 { + let x = F { a: x }; + for n in -7..7 { + assert_eq!(x.integer_mul(n), F { a: n * x.a }); + } + } + } + + #[test] + fn from_integer() { + type F = PrimeField<23>; + for x in -100..100 { + assert_eq!(F::from_integer(x), F { a: x }); + } + assert_eq!(F::from(0), F::ZERO); + assert_eq!(F::from(1), F::ONE); + } +} diff --git a/src/math/frizzy_number.rs b/src/math/frizzy_number.rs new file mode 100644 index 00000000000..22f154a412c --- /dev/null +++ b/src/math/frizzy_number.rs @@ -0,0 +1,61 @@ +/// This Rust program calculates the n-th Frizzy number for a given base. +/// A Frizzy number is defined as the n-th number that is a sum of powers +/// of the given base, with the powers corresponding to the binary representation +/// of n. + +/// The `get_nth_frizzy` function takes two arguments: +/// * `base` - The base whose n-th sum of powers is required. 
+/// * `n` - Index from ascending order of the sum of powers of the base. + +/// It returns the n-th sum of powers of the base. + +/// # Example +/// To find the Frizzy number with a base of 3 and n equal to 4: +/// - Ascending order of sums of powers of 3: 3^0 = 1, 3^1 = 3, 3^1 + 3^0 = 4, 3^2 = 9. +/// - The answer is 9. +/// +/// # Arguments +/// * `base` - The base whose n-th sum of powers is required. +/// * `n` - Index from ascending order of the sum of powers of the base. +/// +/// # Returns +/// The n-th sum of powers of the base. + +pub fn get_nth_frizzy(base: i32, mut n: i32) -> f64 { + let mut final1 = 0.0; + let mut i = 0; + while n > 0 { + final1 += (base.pow(i) as f64) * ((n % 2) as f64); + i += 1; + n /= 2; + } + final1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_nth_frizzy() { + // Test case 1: base = 3, n = 4 + // 3^2 = 9 + assert_eq!(get_nth_frizzy(3, 4), 9.0); + + // Test case 2: base = 2, n = 5 + // 2^2 + 2^0 = 5 + assert_eq!(get_nth_frizzy(2, 5), 5.0); + + // Test case 3: base = 4, n = 3 + // 4^1 + 4^0 = 5 + assert_eq!(get_nth_frizzy(4, 3), 5.0); + + // Test case 4: base = 5, n = 2 + // 5^1 = 5 + assert_eq!(get_nth_frizzy(5, 2), 5.0); + + // Test case 5: base = 6, n = 1 + // 6^0 = 1 + assert_eq!(get_nth_frizzy(6, 1), 1.0); + } +} diff --git a/src/math/gaussian_elimination.rs b/src/math/gaussian_elimination.rs new file mode 100644 index 00000000000..1370b15ddd7 --- /dev/null +++ b/src/math/gaussian_elimination.rs @@ -0,0 +1,79 @@ +// Gaussian Elimination of Quadratic Matrices +// Takes an augmented matrix as input, returns vector of results +// Wikipedia reference: augmented matrix: https://en.wikipedia.org/wiki/Augmented_matrix +// Wikipedia reference: algorithm: https://en.wikipedia.org/wiki/Gaussian_elimination + +pub fn gaussian_elimination(matrix: &mut [Vec]) -> Vec { + let size = matrix.len(); + assert_eq!(size, matrix[0].len() - 1); + + for i in 0..size - 1 { + for j in i..size - 1 { + 
echelon(matrix, i, j); + } + } + + for i in (1..size).rev() { + eliminate(matrix, i); + } + + // Disable cargo clippy warnings about needless range loops. + // Checking the diagonal like this is simpler than any alternative. + #[allow(clippy::needless_range_loop)] + for i in 0..size { + if matrix[i][i] == 0f32 { + println!("Infinitely many solutions"); + } + } + + let mut result: Vec = vec![0f32; size]; + for i in 0..size { + result[i] = matrix[i][size] / matrix[i][i]; + } + result +} + +fn echelon(matrix: &mut [Vec], i: usize, j: usize) { + let size = matrix.len(); + if matrix[i][i] == 0f32 { + } else { + let factor = matrix[j + 1][i] / matrix[i][i]; + (i..=size).for_each(|k| { + matrix[j + 1][k] -= factor * matrix[i][k]; + }); + } +} + +fn eliminate(matrix: &mut [Vec], i: usize) { + let size = matrix.len(); + if matrix[i][i] == 0f32 { + } else { + for j in (1..=i).rev() { + let factor = matrix[j - 1][i] / matrix[i][i]; + for k in (0..=size).rev() { + matrix[j - 1][k] -= factor * matrix[i][k]; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::gaussian_elimination; + + #[test] + fn test_gauss() { + let mut matrix: Vec> = vec![ + vec![1.5, 2.0, 1.0, -1.0, -2.0, 1.0, 1.0], + vec![3.0, 3.0, -1.0, 16.0, 18.0, 1.0, 1.0], + vec![1.0, 1.0, 3.0, -2.0, -6.0, 1.0, 1.0], + vec![1.0, 1.0, 99.0, 19.0, 2.0, 1.0, 1.0], + vec![1.0, -2.0, 16.0, 1.0, 9.0, 10.0, 1.0], + vec![1.0, 3.0, 1.0, -5.0, 1.0, 1.0, 95.0], + ]; + let result = vec![ + -264.05893, 159.63196, -6.156921, 35.310387, -18.806696, 81.67839, + ]; + assert_eq!(gaussian_elimination(&mut matrix), result); + } +} diff --git a/src/math/gaussian_error_linear_unit.rs b/src/math/gaussian_error_linear_unit.rs new file mode 100644 index 00000000000..c7e52cca6a4 --- /dev/null +++ b/src/math/gaussian_error_linear_unit.rs @@ -0,0 +1,58 @@ +//! # Gaussian Error Linear Unit (GELU) Function +//! +//! 
The `gaussian_error_linear_unit` function computes the Gaussian Error Linear Unit (GELU) value of every element of a vector of f64 numbers. +//! +//! GELU is an activation function used in neural networks that introduces a smooth approximation of the rectifier function (ReLU). +//! It is defined using the Gaussian cumulative distribution function and can help mitigate the vanishing gradient problem. +//! +//! ## Formula +//! +//! For a given input value `x`, the GELU function computes the output `y` as follows: +//! +//! `y = 0.5 * x * (1.0 + tanh(sqrt(2.0 / π) * (x + 0.044715 * x^3)))` +//! +//! Where `tanh` is the hyperbolic tangent function and `π` is the mathematical constant (approximately 3.14159). +//! +//! ## Gaussian Error Linear Unit (GELU) Function Implementation +//! +//! This implementation takes a reference to a vector of f64 values and returns a new vector with the GELU transformation applied to each element. The input values are not altered. +//! +use std::f64::consts::E; +use std::f64::consts::PI; + +fn tanh(vector: f64) -> f64 { + (2. / (1. + E.powf(-2. * vector.to_owned()))) - 1. +} + +pub fn gaussian_error_linear_unit(vector: &Vec) -> Vec { + let mut gelu_vec = vector.to_owned(); + for value in &mut gelu_vec { + *value = *value + * 0.5 + * (1. + tanh(f64::powf(2. 
/ PI, 0.5) * (*value + 0.044715 * value.powf(3.)))); + } + + gelu_vec +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gaussian_error_linear_unit() { + let test_vector = vec![-10., 2., -3., 4., -5., 10., 0.05]; + assert_eq!( + gaussian_error_linear_unit(&test_vector), + vec![ + -0.0, + 1.9545976940877752, + -0.0036373920817729943, + 3.9999297540518075, + -2.2917961972623857e-7, + 10.0, + 0.025996938238622008 + ] + ); + } +} diff --git a/src/math/gcd_of_n_numbers.rs b/src/math/gcd_of_n_numbers.rs new file mode 100644 index 00000000000..b06a4f9e126 --- /dev/null +++ b/src/math/gcd_of_n_numbers.rs @@ -0,0 +1,29 @@ +/// returns the greatest common divisor of n numbers +pub fn gcd(nums: &[usize]) -> usize { + if nums.len() == 1 { + return nums[0]; + } + let a = nums[0]; + let b = gcd(&nums[1..]); + gcd_of_two_numbers(a, b) +} + +fn gcd_of_two_numbers(a: usize, b: usize) -> usize { + if b == 0 { + return a; + } + gcd_of_two_numbers(b, a % b) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn it_works() { + assert_eq!(gcd(&[1, 2, 3, 4, 5]), 1); + assert_eq!(gcd(&[2, 4, 6, 8, 10]), 2); + assert_eq!(gcd(&[3, 6, 9, 12, 15]), 3); + assert_eq!(gcd(&[10]), 10); + assert_eq!(gcd(&[21, 110]), 1); + } +} diff --git a/src/math/geometric_series.rs b/src/math/geometric_series.rs new file mode 100644 index 00000000000..e9631f09ff5 --- /dev/null +++ b/src/math/geometric_series.rs @@ -0,0 +1,49 @@ +// Author : cyrixninja +// Wikipedia : https://en.wikipedia.org/wiki/Geometric_series +// Calculate a geometric series. 
+ +pub fn geometric_series(nth_term: f64, start_term_a: f64, common_ratio_r: f64) -> Vec { + let mut series = Vec::new(); + let mut multiple = 1.0; + + for _ in 0..(nth_term as i32) { + series.push(start_term_a * multiple); + multiple *= common_ratio_r; + } + + series +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_approx_eq(a: f64, b: f64) { + let epsilon = 1e-10; + assert!((a - b).abs() < epsilon, "Expected {a}, found {b}"); + } + + #[test] + fn test_geometric_series() { + let result = geometric_series(4.0, 2.0, 2.0); + assert_eq!(result.len(), 4); + assert_approx_eq(result[0], 2.0); + assert_approx_eq(result[1], 4.0); + assert_approx_eq(result[2], 8.0); + assert_approx_eq(result[3], 16.0); + + let result = geometric_series(4.1, 2.1, 2.1); + assert_eq!(result.len(), 4); + assert_approx_eq(result[0], 2.1); + assert_approx_eq(result[1], 4.41); + assert_approx_eq(result[2], 9.261); + assert_approx_eq(result[3], 19.4481); + + let result = geometric_series(4.0, -2.0, 2.0); + assert_eq!(result.len(), 4); + assert_approx_eq(result[0], -2.0); + assert_approx_eq(result[1], -4.0); + assert_approx_eq(result[2], -8.0); + assert_approx_eq(result[3], -16.0); + } +} diff --git a/src/math/greatest_common_divisor.rs b/src/math/greatest_common_divisor.rs new file mode 100644 index 00000000000..8c88434bade --- /dev/null +++ b/src/math/greatest_common_divisor.rs @@ -0,0 +1,116 @@ +/// Greatest Common Divisor. +/// +/// greatest_common_divisor(num1, num2) returns the greatest number that divides both num1 and num2.
+/// +/// Wikipedia reference: https://en.wikipedia.org/wiki/Greatest_common_divisor +/// gcd(a, b) = gcd(a, -b) = gcd(-a, b) = gcd(-a, -b) by definition of divisibility +use std::cmp::{max, min}; + +pub fn greatest_common_divisor_recursive(a: i64, b: i64) -> i64 { + if a == 0 { + b.abs() + } else { + greatest_common_divisor_recursive(b % a, a) + } +} + +pub fn greatest_common_divisor_iterative(mut a: i64, mut b: i64) -> i64 { + while a != 0 { + let remainder = b % a; + b = a; + a = remainder; + } + b.abs() +} + +pub fn greatest_common_divisor_stein(a: u64, b: u64) -> u64 { + match ((a, b), (a & 1, b & 1)) { + // gcd(x, x) = x + ((x, y), _) if x == y => y, + // gcd(x, 0) = gcd(0, x) = x + ((0, x), _) | ((x, 0), _) => x, + // gcd(x, y) = gcd(x / 2, y) if x is even and y is odd + // gcd(x, y) = gcd(x, y / 2) if y is even and x is odd + ((x, y), (0, 1)) | ((y, x), (1, 0)) => greatest_common_divisor_stein(x >> 1, y), + // gcd(x, y) = 2 * gcd(x / 2, y / 2) if x and y are both even + ((x, y), (0, 0)) => greatest_common_divisor_stein(x >> 1, y >> 1) << 1, + // if x and y are both odd + ((x, y), (1, 1)) => { + // then gcd(x, y) = gcd((x - y) / 2, y) if x >= y + // gcd(x, y) = gcd((y - x) / 2, x) otherwise + let (x, y) = (min(x, y), max(x, y)); + greatest_common_divisor_stein((y - x) >> 1, x) + } + _ => unreachable!(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn positive_number_recursive() { + assert_eq!(greatest_common_divisor_recursive(4, 16), 4); + assert_eq!(greatest_common_divisor_recursive(16, 4), 4); + assert_eq!(greatest_common_divisor_recursive(3, 5), 1); + assert_eq!(greatest_common_divisor_recursive(40, 40), 40); + assert_eq!(greatest_common_divisor_recursive(27, 12), 3); + } + + #[test] + fn positive_number_iterative() { + assert_eq!(greatest_common_divisor_iterative(4, 16), 4); + assert_eq!(greatest_common_divisor_iterative(16, 4), 4); + assert_eq!(greatest_common_divisor_iterative(3, 5), 1); + 
assert_eq!(greatest_common_divisor_iterative(40, 40), 40); + assert_eq!(greatest_common_divisor_iterative(27, 12), 3); + } + + #[test] + fn positive_number_stein() { + assert_eq!(greatest_common_divisor_stein(4, 16), 4); + assert_eq!(greatest_common_divisor_stein(16, 4), 4); + assert_eq!(greatest_common_divisor_stein(3, 5), 1); + assert_eq!(greatest_common_divisor_stein(40, 40), 40); + assert_eq!(greatest_common_divisor_stein(27, 12), 3); + } + + #[test] + fn negative_number_recursive() { + assert_eq!(greatest_common_divisor_recursive(-32, -8), 8); + assert_eq!(greatest_common_divisor_recursive(-8, -32), 8); + assert_eq!(greatest_common_divisor_recursive(-3, -5), 1); + assert_eq!(greatest_common_divisor_recursive(-40, -40), 40); + assert_eq!(greatest_common_divisor_recursive(-12, -27), 3); + } + + #[test] + fn negative_number_iterative() { + assert_eq!(greatest_common_divisor_iterative(-32, -8), 8); + assert_eq!(greatest_common_divisor_iterative(-8, -32), 8); + assert_eq!(greatest_common_divisor_iterative(-3, -5), 1); + assert_eq!(greatest_common_divisor_iterative(-40, -40), 40); + assert_eq!(greatest_common_divisor_iterative(-12, -27), 3); + } + + #[test] + fn mix_recursive() { + assert_eq!(greatest_common_divisor_recursive(0, -5), 5); + assert_eq!(greatest_common_divisor_recursive(-5, 0), 5); + assert_eq!(greatest_common_divisor_recursive(-64, 32), 32); + assert_eq!(greatest_common_divisor_recursive(-32, 64), 32); + assert_eq!(greatest_common_divisor_recursive(-40, 40), 40); + assert_eq!(greatest_common_divisor_recursive(12, -27), 3); + } + + #[test] + fn mix_iterative() { + assert_eq!(greatest_common_divisor_iterative(0, -5), 5); + assert_eq!(greatest_common_divisor_iterative(-5, 0), 5); + assert_eq!(greatest_common_divisor_iterative(-64, 32), 32); + assert_eq!(greatest_common_divisor_iterative(-32, 64), 32); + assert_eq!(greatest_common_divisor_iterative(-40, 40), 40); + assert_eq!(greatest_common_divisor_iterative(12, -27), 3); + } +} diff --git 
a/src/math/huber_loss.rs b/src/math/huber_loss.rs new file mode 100644 index 00000000000..7bb304b81b0 --- /dev/null +++ b/src/math/huber_loss.rs @@ -0,0 +1,43 @@ +//! # Huber Loss Function +//! +//! The `huber_loss` function calculates the Huber loss, which is a robust loss function used in machine learning, particularly in regression problems. +//! +//! Huber loss combines the benefits of mean squared error (MSE) and mean absolute error (MAE). It behaves like MSE when the difference between actual and predicted values is small (less than a specified `delta`), and like MAE when the difference is large. +//! +//! ## Formula +//! +//! For a pair of actual and predicted values, represented as vectors `actual` and `predicted`, and a specified `delta` value, the Huber loss is calculated as: +//! +//! - If the absolute difference between `actual[i]` and `predicted[i]` is less than or equal to `delta`, the loss is `0.5 * (actual[i] - predicted[i])^2`. +//! - If the absolute difference is greater than `delta`, the loss is `delta * (|actual[i] - predicted[i]| - 0.5 * delta)`. NOTE: the implementation below currently computes `delta * |actual[i] - predicted[i]| - 0.5 * delta`, which matches this definition only when `delta` is 1. +//! +//! The total loss is the sum of individual losses over all elements. +//! +//! ## Huber Loss Function Implementation +//! +//! This implementation takes two slices of f64 values, `actual` and `predicted`, and a `delta` value. It returns the Huber loss between them, providing a robust measure of dissimilarity between actual and predicted values. +//! 
+pub fn huber_loss(actual: &[f64], predicted: &[f64], delta: f64) -> f64 { + let mut loss: Vec = Vec::new(); + for (a, p) in actual.iter().zip(predicted.iter()) { + if (a - p).abs() <= delta { + loss.push(0.5 * (a - p).powf(2.)); + } else { + loss.push(delta * (a - p).abs() - (0.5 * delta)); + } + } + + loss.iter().sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_huber_loss() { + let test_vector_actual = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let test_vector = vec![5.0, 7.0, 9.0, 11.0, 13.0]; + assert_eq!(huber_loss(&test_vector_actual, &test_vector, 1.0), 27.5); + } +} diff --git a/src/math/infix_to_postfix.rs b/src/math/infix_to_postfix.rs new file mode 100644 index 00000000000..123c792779d --- /dev/null +++ b/src/math/infix_to_postfix.rs @@ -0,0 +1,94 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InfixToPostfixError { + UnknownCharacter(char), + UnmatchedParent, +} + +/// Function to convert [infix expression](https://en.wikipedia.org/wiki/Infix_notation) to [postfix expression](https://en.wikipedia.org/wiki/Reverse_Polish_notation) +pub fn infix_to_postfix(infix: &str) -> Result { + let mut postfix = String::new(); + let mut stack: Vec = Vec::new(); + + // Define the precedence of operators + let precedence = |op: char| -> u8 { + match op { + '+' | '-' => 1, + '*' | '/' => 2, + '^' => 3, + _ => 0, + } + }; + + for token in infix.chars() { + match token { + c if c.is_alphanumeric() => { + postfix.push(c); + } + '(' => { + stack.push('('); + } + ')' => { + while let Some(top) = stack.pop() { + if top == '(' { + break; + } + postfix.push(top); + } + } + '+' | '-' | '*' | '/' | '^' => { + while let Some(top) = stack.last() { + if *top == '(' || precedence(*top) < precedence(token) { + break; + } + postfix.push(stack.pop().unwrap()); + } + stack.push(token); + } + other => return Err(InfixToPostfixError::UnknownCharacter(other)), + } + } + + while let Some(top) = stack.pop() { + if top == '(' { + return 
Err(InfixToPostfixError::UnmatchedParent); + } + + postfix.push(top); + } + + Ok(postfix) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_infix_to_postfix { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (infix, expected) = $inputs; + assert_eq!(infix_to_postfix(infix), expected) + } + )* + } + } + + test_infix_to_postfix! { + single_symbol: ("x", Ok(String::from("x"))), + simple_sum: ("x+y", Ok(String::from("xy+"))), + multiply_sum_left: ("x*(y+z)", Ok(String::from("xyz+*"))), + multiply_sum_right: ("(x+y)*z", Ok(String::from("xy+z*"))), + multiply_two_sums: ("(a+b)*(c+d)", Ok(String::from("ab+cd+*"))), + product_and_power: ("a*b^c", Ok(String::from("abc^*"))), + power_and_product: ("a^b*c", Ok(String::from("ab^c*"))), + product_of_powers: ("(a*b)^c", Ok(String::from("ab*c^"))), + product_in_exponent: ("a^(b*c)", Ok(String::from("abc*^"))), + regular_0: ("a-b+c-d*e", Ok(String::from("ab-c+de*-"))), + regular_1: ("a*(b+c)+d/(e+f)", Ok(String::from("abc+*def+/+"))), + regular_2: ("(a-b+c)*(d+e*f)", Ok(String::from("ab-c+def*+*"))), + unknown_character: ("(a-b)*#", Err(InfixToPostfixError::UnknownCharacter('#'))), + unmatched_paren: ("((a-b)", Err(InfixToPostfixError::UnmatchedParent)), + } +} diff --git a/src/math/interest.rs b/src/math/interest.rs new file mode 100644 index 00000000000..6347f211abe --- /dev/null +++ b/src/math/interest.rs @@ -0,0 +1,55 @@ +// value of e +use std::f64::consts::E; + +// function to calculate simple interest +pub fn simple_interest(principal: f64, annual_rate: f64, years: f64) -> (f64, f64) { + let interest = principal * annual_rate * years; + let value = principal * (1.0 + (annual_rate * years)); + + println!("Interest earned: {interest}"); + println!("Future value: {value}"); + + (interest, value) +} + +// function to calculate compound interest compounded over periods or continuously +pub fn compound_interest(principal: f64, annual_rate: f64, years: f64, period: Option) -> f64 { + 
// checks if the period is None type, if so calculates continuous compounding interest + let value = if period.is_none() { + principal * E.powf(annual_rate * years) + } else { + // unwraps the option type or defaults to 0 if None type and assigns it to prim_period + let prim_period: f64 = period.unwrap_or(0.0); + // checks if the period is less than or equal to zero + if prim_period <= 0.0_f64 { + return f64::NAN; + } + principal * (1.0 + (annual_rate / prim_period).powf(prim_period * years)) + }; + println!("Future value: {value}"); + value +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_simple() { + let x = 385.65_f64 * 0.03_f64 * 5.0_f64; + let y = 385.65_f64 * (1.0 + (0.03_f64 * 5.0_f64)); + assert_eq!(simple_interest(385.65_f64, 0.03_f64, 5.0_f64), (x, y)); + } + #[test] + fn test_compounding() { + let x = 385.65_f64 * E.powf(0.03_f64 * 5.0_f64); + assert_eq!(compound_interest(385.65_f64, 0.03_f64, 5.0_f64, None), x); + + let y = 385.65_f64 * (1.0 + (0.03_f64 / 5.0_f64).powf(5.0_f64 * 5.0_f64)); + assert_eq!( + compound_interest(385.65_f64, 0.03_f64, 5.0_f64, Some(5.0_f64)), + y + ); + assert!(compound_interest(385.65_f64, 0.03_f64, 5.0_f64, Some(-5.0_f64)).is_nan()); + assert!(compound_interest(385.65_f64, 0.03_f64, 5.0_f64, Some(0.0_f64)).is_nan()); + } +} diff --git a/src/math/interpolation.rs b/src/math/interpolation.rs new file mode 100644 index 00000000000..b5bb39ce5d6 --- /dev/null +++ b/src/math/interpolation.rs @@ -0,0 +1,90 @@ +/// In mathematics, linear interpolation is a method of curve fitting +/// using linear polynomials to construct new data points within the range of a discrete set of known data points. 
+/// Formula: y = y0 + (x - x0) * (y1 - y0) / (x1 - x0) +/// Source: https://en.wikipedia.org/wiki/Linear_interpolation +/// point0 and point1 are a tuple containing x and y values we want to interpolate between +pub fn linear_interpolation(x: f64, point0: (f64, f64), point1: (f64, f64)) -> f64 { + point0.1 + (x - point0.0) * (point1.1 - point0.1) / (point1.0 - point0.0) +} + +/// In numerical analysis, the Lagrange interpolating polynomial +/// is the unique polynomial of lowest degree that interpolates a given set of data. +/// +/// Source: https://en.wikipedia.org/wiki/Lagrange_polynomial +/// Source: https://mathworld.wolfram.com/LagrangeInterpolatingPolynomial.html +/// x is the point we wish to interpolate +/// defined points are a vector of tuples containing known x and y values of our function +pub fn lagrange_polynomial_interpolation(x: f64, defined_points: &Vec<(f64, f64)>) -> f64 { + let mut defined_x_values: Vec = Vec::new(); + let mut defined_y_values: Vec = Vec::new(); + + for (x, y) in defined_points { + defined_x_values.push(*x); + defined_y_values.push(*y); + } + + let mut sum = 0.0; + + for y_index in 0..defined_y_values.len() { + let mut numerator = 1.0; + let mut denominator = 1.0; + for x_index in 0..defined_x_values.len() { + if y_index == x_index { + continue; + } + denominator *= defined_x_values[y_index] - defined_x_values[x_index]; + numerator *= x - defined_x_values[x_index]; + } + + sum += numerator / denominator * defined_y_values[y_index]; + } + sum +} + +#[cfg(test)] +mod tests { + + use std::assert_eq; + + use super::*; + #[test] + fn test_linear_intepolation() { + let point1 = (0.0, 0.0); + let point2 = (1.0, 1.0); + let point3 = (2.0, 2.0); + + let x1 = 0.5; + let x2 = 1.5; + + let y1 = linear_interpolation(x1, point1, point2); + let y2 = linear_interpolation(x2, point2, point3); + + assert_eq!(y1, x1); + assert_eq!(y2, x2); + assert_eq!( + linear_interpolation(x1, point1, point2), + linear_interpolation(x1, point2, point1) + ); + 
} + + #[test] + fn test_lagrange_polynomial_interpolation() { + // defined values for x^2 function + let defined_points = vec![(0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (3.0, 9.0)]; + + // check for equality + assert_eq!(lagrange_polynomial_interpolation(1.0, &defined_points), 1.0); + assert_eq!(lagrange_polynomial_interpolation(2.0, &defined_points), 4.0); + assert_eq!(lagrange_polynomial_interpolation(3.0, &defined_points), 9.0); + + //other + assert_eq!( + lagrange_polynomial_interpolation(0.5, &defined_points), + 0.25 + ); + assert_eq!( + lagrange_polynomial_interpolation(2.5, &defined_points), + 6.25 + ); + } +} diff --git a/src/math/interquartile_range.rs b/src/math/interquartile_range.rs new file mode 100644 index 00000000000..fed92b77709 --- /dev/null +++ b/src/math/interquartile_range.rs @@ -0,0 +1,85 @@ +// Author : cyrixninja +// Interquartile Range : An implementation of interquartile range (IQR) which is a measure of statistical +// dispersion, which is the spread of the data. +// Wikipedia Reference : https://en.wikipedia.org/wiki/Interquartile_range + +use std::cmp::Ordering; + +pub fn find_median(numbers: &[f64]) -> f64 { + let mut numbers = numbers.to_vec(); + numbers.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + + let length = numbers.len(); + let mid = length / 2; + + if length % 2 == 0 { + f64::midpoint(numbers[mid - 1], numbers[mid]) + } else { + numbers[mid] + } +} + +pub fn interquartile_range(numbers: &[f64]) -> f64 { + if numbers.is_empty() { + panic!("Error: The list is empty. 
Please provide a non-empty list."); + } + + let mut numbers = numbers.to_vec(); + numbers.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + + let length = numbers.len(); + let mid = length / 2; + let (q1, q3) = if length % 2 == 0 { + let first_half = &numbers[0..mid]; + let second_half = &numbers[mid..length]; + (find_median(first_half), find_median(second_half)) + } else { + let first_half = &numbers[0..mid]; + let second_half = &numbers[mid + 1..length]; + (find_median(first_half), find_median(second_half)) + }; + + q3 - q1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_median() { + let numbers1 = vec![1.0, 2.0, 2.0, 3.0, 4.0]; + assert_eq!(find_median(&numbers1), 2.0); + + let numbers2 = vec![1.0, 2.0, 2.0, 3.0, 4.0, 4.0]; + assert_eq!(find_median(&numbers2), 2.5); + + let numbers3 = vec![-1.0, 2.0, 0.0, 3.0, 4.0, -4.0]; + assert_eq!(find_median(&numbers3), 1.0); + + let numbers4 = vec![1.1, 2.2, 2.0, 3.3, 4.4, 4.0]; + assert_eq!(find_median(&numbers4), 2.75); + } + + #[test] + fn test_interquartile_range() { + let numbers1 = vec![4.0, 1.0, 2.0, 3.0, 2.0]; + assert_eq!(interquartile_range(&numbers1), 2.0); + + let numbers2 = vec![-2.0, -7.0, -10.0, 9.0, 8.0, 4.0, -67.0, 45.0]; + assert_eq!(interquartile_range(&numbers2), 17.0); + + let numbers3 = vec![-2.1, -7.1, -10.1, 9.1, 8.1, 4.1, -67.1, 45.1]; + assert_eq!(interquartile_range(&numbers3), 17.2); + + let numbers4 = vec![0.0, 0.0, 0.0, 0.0, 0.0]; + assert_eq!(interquartile_range(&numbers4), 0.0); + } + + #[test] + #[should_panic(expected = "Error: The list is empty. 
Please provide a non-empty list.")] + fn test_interquartile_range_empty_list() { + let numbers: Vec = vec![]; + interquartile_range(&numbers); + } +} diff --git a/src/math/karatsuba_multiplication.rs b/src/math/karatsuba_multiplication.rs new file mode 100644 index 00000000000..4547faf9119 --- /dev/null +++ b/src/math/karatsuba_multiplication.rs @@ -0,0 +1,69 @@ +/* +Finds the product of two numbers using Karatsuba Algorithm + */ +use std::cmp::max; +const TEN: i128 = 10; + +pub fn multiply(num1: i128, num2: i128) -> i128 { + _multiply(num1, num2) +} + +fn _multiply(num1: i128, num2: i128) -> i128 { + if num1 < 10 || num2 < 10 { + return num1 * num2; + } + let mut num1_str = num1.to_string(); + let mut num2_str = num2.to_string(); + + let n = max(num1_str.len(), num2_str.len()); + num1_str = normalize(num1_str, n); + num2_str = normalize(num2_str, n); + + let a = &num1_str[0..n / 2]; + let b = &num1_str[n / 2..]; + let c = &num2_str[0..n / 2]; + let d = &num2_str[n / 2..]; + + let ac = _multiply(a.parse().unwrap(), c.parse().unwrap()); + let bd = _multiply(b.parse().unwrap(), d.parse().unwrap()); + let a_b: i128 = a.parse::().unwrap() + b.parse::().unwrap(); + let c_d: i128 = c.parse::().unwrap() + d.parse::().unwrap(); + let ad_bc = _multiply(a_b, c_d) - (ac + bd); + + let m = n / 2 + n % 2; + (TEN.pow(2 * m as u32) * ac) + (TEN.pow(m as u32) * ad_bc) + (bd) +} + +fn normalize(mut a: String, n: usize) -> String { + let padding = n.saturating_sub(a.len()); + a.insert_str(0, &"0".repeat(padding)); + a +} +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_1() { + let n1: i128 = 314159265; + let n2: i128 = 314159265; + let ans = multiply(n1, n2); + assert_eq!(ans, n1 * n2); + } + + #[test] + fn test_2() { + let n1: i128 = 3141592653589793232; + let n2: i128 = 2718281828459045233; + let ans = multiply(n1, n2); + assert_eq!(ans, n1 * n2); + } + + #[test] + fn test_3() { + let n1: i128 = 123456789; + let n2: i128 = 101112131415; + let ans = multiply(n1, n2); 
+ assert_eq!(ans, n1 * n2); + } +} diff --git a/src/math/lcm_of_n_numbers.rs b/src/math/lcm_of_n_numbers.rs new file mode 100644 index 00000000000..811db77683f --- /dev/null +++ b/src/math/lcm_of_n_numbers.rs @@ -0,0 +1,30 @@ +// returns the least common multiple of n numbers + +pub fn lcm(nums: &[usize]) -> usize { + if nums.len() == 1 { + return nums[0]; + } + let a = nums[0]; + let b = lcm(&nums[1..]); + a * b / gcd_of_two_numbers(a, b) +} + +fn gcd_of_two_numbers(a: usize, b: usize) -> usize { + if b == 0 { + return a; + } + gcd_of_two_numbers(b, a % b) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn it_works() { + assert_eq!(lcm(&[1, 2, 3, 4, 5]), 60); + assert_eq!(lcm(&[2, 4, 6, 8, 10]), 120); + assert_eq!(lcm(&[3, 6, 9, 12, 15]), 180); + assert_eq!(lcm(&[10]), 10); + assert_eq!(lcm(&[21, 110]), 2310); + } +} diff --git a/src/math/leaky_relu.rs b/src/math/leaky_relu.rs new file mode 100644 index 00000000000..954a1023d6d --- /dev/null +++ b/src/math/leaky_relu.rs @@ -0,0 +1,47 @@ +//! # Leaky ReLU Function +//! +//! The `leaky_relu` function computes the Leaky Rectified Linear Unit (ReLU) values of a given vector +//! of f64 numbers with a specified alpha parameter. +//! +//! The Leaky ReLU activation function is commonly used in neural networks to introduce a small negative +//! slope (controlled by the alpha parameter) for the negative input values, preventing neurons from dying +//! during training. +//! +//! ## Formula +//! +//! For a given input vector `x` and an alpha parameter `alpha`, the Leaky ReLU function computes the output +//! `y` as follows: +//! +//! `y_i = { x_i if x_i >= 0, alpha * x_i if x_i < 0 }` +//! +//! ## Leaky ReLU Function Implementation +//! +//! This implementation takes a reference to a vector of f64 values and an alpha parameter, and returns a new +//! vector with the Leaky ReLU transformation applied to each element. The input vector is not altered. +//! 
+pub fn leaky_relu(vector: &Vec, alpha: f64) -> Vec { + let mut _vector = vector.to_owned(); + + for value in &mut _vector { + if value < &mut 0. { + *value *= alpha; + } + } + + _vector +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_leaky_relu() { + let test_vector = vec![-10., 2., -3., 4., -5., 10., 0.05]; + let alpha = 0.01; + assert_eq!( + leaky_relu(&test_vector, alpha), + vec![-0.1, 2.0, -0.03, 4.0, -0.05, 10.0, 0.05] + ); + } +} diff --git a/src/math/least_square_approx.rs b/src/math/least_square_approx.rs new file mode 100644 index 00000000000..bc12d8e766e --- /dev/null +++ b/src/math/least_square_approx.rs @@ -0,0 +1,115 @@ +/// Least Square Approximation

+/// Function that returns a polynomial which very closely passes through the given points (in 2D) +/// +/// The result is made of coeficients, in descending order (from x^degree to free term) +/// +/// Parameters: +/// +/// points -> coordinates of given points +/// +/// degree -> degree of the polynomial +/// +pub fn least_square_approx + Copy, U: Into + Copy>( + points: &[(T, U)], + degree: i32, +) -> Option> { + use nalgebra::{DMatrix, DVector}; + + /* Used for rounding floating numbers */ + fn round_to_decimals(value: f64, decimals: i32) -> f64 { + let multiplier = 10f64.powi(decimals); + (value * multiplier).round() / multiplier + } + + /* Casting the data parsed to this function to f64 (as some points can have decimals) */ + let vals: Vec<(f64, f64)> = points + .iter() + .map(|(x, y)| ((*x).into(), (*y).into())) + .collect(); + /* Because of collect we need the Copy Trait for T and U */ + + /* Computes the sums in the system of equations */ + let mut sums = Vec::::new(); + for i in 1..=(2 * degree + 1) { + sums.push(vals.iter().map(|(x, _)| x.powi(i - 1)).sum()); + } + + /* Compute the free terms column vector */ + let mut free_col = Vec::::new(); + for i in 1..=(degree + 1) { + free_col.push(vals.iter().map(|(x, y)| y * (x.powi(i - 1))).sum()); + } + let b = DVector::from_row_slice(&free_col); + + /* Create and fill the system's matrix */ + let size = (degree + 1) as usize; + let a = DMatrix::from_fn(size, size, |i, j| sums[degree as usize + i - j]); + + /* Solve the system of equations: A * x = b */ + match a.qr().solve(&b) { + Some(x) => { + let rez: Vec = x.iter().map(|x| round_to_decimals(*x, 5)).collect(); + Some(rez) + } + None => None, //<-- The system cannot be solved (badly conditioned system's matrix) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ten_points_1st_degree() { + let points = vec![ + (5.3, 7.8), + (4.9, 8.1), + (6.1, 6.9), + (4.7, 8.3), + (6.5, 7.7), + (5.6, 7.0), + (5.8, 8.2), + (4.5, 8.0), + (6.3, 7.2), + (5.1, 
8.4), + ]; + + assert_eq!( + least_square_approx(&points, 1), + Some(vec![-0.49069, 10.44898]) + ); + } + + #[test] + fn eight_points_5th_degree() { + let points = vec![ + (4f64, 8f64), + (8f64, 2f64), + (1f64, 7f64), + (10f64, 3f64), + (11.0, 0.0), + (7.0, 3.0), + (10.0, 1.0), + (13.0, 13.0), + ]; + + assert_eq!( + least_square_approx(&points, 5), + Some(vec![ + 0.00603, -0.21304, 2.79929, -16.53468, 40.29473, -19.35771 + ]) + ); + } + + #[test] + fn four_points_2nd_degree() { + let points = vec![ + (2.312, 8.345344), + (-2.312, 8.345344), + (-0.7051, 3.49716601), + (0.7051, 3.49716601), + ]; + + assert_eq!(least_square_approx(&points, 2), Some(vec![1.0, 0.0, 3.0])); + } +} diff --git a/src/math/linear_sieve.rs b/src/math/linear_sieve.rs new file mode 100644 index 00000000000..8fb49d26a75 --- /dev/null +++ b/src/math/linear_sieve.rs @@ -0,0 +1,134 @@ +/* +Linear Sieve algorithm: +Time complexity is indeed O(n) with O(n) memory, but the sieve generally +runs slower than a well implemented sieve of Eratosthenes. Some use cases are: +- factorizing any number k in the sieve in O(log(k)) +- calculating arbitrary multiplicative functions on sieve numbers + without increasing the time complexity +- As a by product, all prime numbers less than `max_number` are stored + in `primes` vector. 
+ */ +pub struct LinearSieve { + max_number: usize, + pub primes: Vec, + pub minimum_prime_factor: Vec, +} + +impl LinearSieve { + pub const fn new() -> Self { + LinearSieve { + max_number: 0, + primes: vec![], + minimum_prime_factor: vec![], + } + } + + pub fn prepare(&mut self, max_number: usize) -> Result<(), &'static str> { + if max_number <= 1 { + return Err("Sieve size should be more than 1"); + } + if self.max_number > 0 { + return Err("Sieve already initialized"); + } + self.max_number = max_number; + self.minimum_prime_factor.resize(max_number + 1, 0); + for i in 2..=max_number { + if self.minimum_prime_factor[i] == 0 { + self.minimum_prime_factor[i] = i; + self.primes.push(i); + /* + if needed, a multiplicative function can be + calculated for this prime number here: + function[i] = base_case(i); + */ + } + for p in self.primes.iter() { + let mlt = (*p) * i; + if *p > i || mlt > max_number { + break; + } + self.minimum_prime_factor[mlt] = *p; + /* + multiplicative function for mlt can be calculated here: + if i % p: + function[mlt] = add_to_prime_exponent(function[i], i, p); + else: + function[mlt] = function[i] * function[p] + */ + } + } + Ok(()) + } + + pub fn factorize(&self, mut number: usize) -> Result, &'static str> { + if number > self.max_number { + return Err("Number is too big, its minimum_prime_factor was not calculated"); + } + if number == 0 { + return Err("Number is zero"); + } + let mut result: Vec = Vec::new(); + while number > 1 { + result.push(self.minimum_prime_factor[number]); + number /= self.minimum_prime_factor[number]; + } + Ok(result) + } +} + +impl Default for LinearSieve { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::LinearSieve; + + #[test] + fn small_primes_list() { + let mut ls = LinearSieve::new(); + ls.prepare(25).unwrap(); + assert_eq!(ls.primes, vec![2, 3, 5, 7, 11, 13, 17, 19, 23]); + } + + #[test] + fn divisible_by_mpf() { + let mut ls = LinearSieve::new(); + 
ls.prepare(1000).unwrap(); + for i in 2..=1000 { + let div = i / ls.minimum_prime_factor[i]; + assert_eq!(i % ls.minimum_prime_factor[i], 0); + if div == 1 { + // Number must be prime + assert!(ls.primes.binary_search(&i).is_ok()); + } + } + } + + #[test] + fn check_factorization() { + let mut ls = LinearSieve::new(); + ls.prepare(1000).unwrap(); + for i in 1..=1000 { + let factorization = ls.factorize(i).unwrap(); + let mut product = 1usize; + for (idx, p) in factorization.iter().enumerate() { + assert!(ls.primes.binary_search(p).is_ok()); + product *= *p; + if idx > 0 { + assert!(*p >= factorization[idx - 1]); + } + } + assert_eq!(product, i); + } + } + + #[test] + fn check_number_of_primes() { + let mut ls = LinearSieve::new(); + ls.prepare(100_000).unwrap(); + assert_eq!(ls.primes.len(), 9592); + } +} diff --git a/src/math/logarithm.rs b/src/math/logarithm.rs new file mode 100644 index 00000000000..c94e8247d11 --- /dev/null +++ b/src/math/logarithm.rs @@ -0,0 +1,77 @@ +use std::f64::consts::E; + +/// Calculates the **logbase(x)** +/// +/// Parameters: +///

-> base: base of log +///

-> x: value for which log shall be evaluated +///

-> tol: tolerance; the precision of the approximation (submultiples of 10-1) +/// +/// Advisable to use **std::f64::consts::*** for specific bases (like 'e') +pub fn log, U: Into>(base: U, x: T, tol: f64) -> f64 { + let mut rez = 0f64; + let mut argument: f64 = x.into(); + let usable_base: f64 = base.into(); + + if argument <= 0f64 || usable_base <= 0f64 { + println!("Log does not support negative argument or negative base."); + f64::NAN + } else if argument < 1f64 && usable_base == E { + argument -= 1f64; + let mut prev_rez = 1f64; + let mut step: i32 = 1; + /* + For x in (0, 1) and base 'e', the function is using MacLaurin Series: + ln(|1 + x|) = Σ "(-1)^n-1 * x^n / n", for n = 1..inf + Substituting x with x-1 yields: + ln(|x|) = Σ "(-1)^n-1 * (x-1)^n / n" + */ + while (prev_rez - rez).abs() > tol { + prev_rez = rez; + rez += (-1f64).powi(step - 1) * argument.powi(step) / step as f64; + step += 1; + } + + rez + } else { + /* Using the basic change of base formula for log */ + let ln_x = argument.ln(); + let ln_base = usable_base.ln(); + + ln_x / ln_base + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn basic() { + assert_eq!(log(E, E, 0.0), 1.0); + assert_eq!(log(E, E.powi(100), 0.0), 100.0); + assert_eq!(log(10, 10000.0, 0.0), 4.0); + assert_eq!(log(234501.0, 1.0, 1.0), 0.0); + } + + #[test] + fn test_log_positive_base() { + assert_eq!(log(10.0, 100.0, 0.00001), 2.0); + assert_eq!(log(2.0, 8.0, 0.00001), 3.0); + } + + #[test] + fn test_log_zero_base() { + assert!(log(0.0, 100.0, 0.00001).is_nan()); + } + + #[test] + fn test_log_negative_base() { + assert!(log(-1.0, 100.0, 0.00001).is_nan()); + } + + #[test] + fn test_log_tolerance() { + assert_eq!(log(10.0, 100.0, 1e-10), 2.0); + } +} diff --git a/src/math/lucas_series.rs b/src/math/lucas_series.rs new file mode 100644 index 00000000000..cbc6f48fc40 --- /dev/null +++ b/src/math/lucas_series.rs @@ -0,0 +1,60 @@ +// Author : cyrixninja +// Lucas Series : Function to get the Nth Lucas Number +// 
Wikipedia Reference : https://en.wikipedia.org/wiki/Lucas_number +// Other References : https://the-algorithms.com/algorithm/lucas-series?lang=python + +pub fn recursive_lucas_number(n: u32) -> u32 { + match n { + 0 => 2, + 1 => 1, + _ => recursive_lucas_number(n - 1) + recursive_lucas_number(n - 2), + } +} + +pub fn dynamic_lucas_number(n: u32) -> u32 { + let mut a = 2; + let mut b = 1; + + for _ in 0..n { + let temp = a; + a = b; + b += temp; + } + + a +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_lucas_number { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected) = $inputs; + assert_eq!(recursive_lucas_number(n), expected); + assert_eq!(dynamic_lucas_number(n), expected); + } + )* + } + } + + test_lucas_number! { + input_0: (0, 2), + input_1: (1, 1), + input_2: (2, 3), + input_3: (3, 4), + input_4: (4, 7), + input_5: (5, 11), + input_6: (6, 18), + input_7: (7, 29), + input_8: (8, 47), + input_9: (9, 76), + input_10: (10, 123), + input_15: (15, 1364), + input_20: (20, 15127), + input_25: (25, 167761), + } +} diff --git a/src/math/matrix_ops.rs b/src/math/matrix_ops.rs new file mode 100644 index 00000000000..29fa722f46f --- /dev/null +++ b/src/math/matrix_ops.rs @@ -0,0 +1,569 @@ +// Basic matrix operations using a Matrix type with internally uses +// a vector representation to store matrix elements. +// Generic using the MatrixElement trait, which can be implemented with +// the matrix_element_type_def macro. +// Wikipedia reference: https://www.wikiwand.com/en/Matrix_(mathematics) +use std::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub}; + +// Define macro to build a matrix idiomatically +#[macro_export] +macro_rules! 
matrix { + [$([$($x:expr),* $(,)*]),* $(,)*] => {{ + Matrix::from(vec![$(vec![$($x,)*],)*]) + }}; +} + +// Define a trait "alias" for suitable matrix elements +pub trait MatrixElement: + Add + Sub + Mul + AddAssign + Copy + From +{ +} + +// Define a macro to implement the MatrixElement trait for desired types +#[macro_export] +macro_rules! matrix_element_type_def { + ($T: ty) => { + // Implement trait for type + impl MatrixElement for $T {} + + // Defining left-hand multiplication in this form + // prevents errors for uncovered types + impl Mul<&Matrix<$T>> for $T { + type Output = Matrix<$T>; + + fn mul(self, rhs: &Matrix<$T>) -> Self::Output { + rhs * self + } + } + }; + + ($T: ty, $($Ti: ty),+) => { + // Decompose type definitions recursively + matrix_element_type_def!($T); + matrix_element_type_def!($($Ti),+); + }; +} + +matrix_element_type_def!(i16, i32, i64, i128, u8, u16, u32, u128, f32, f64); + +#[derive(PartialEq, Eq, Debug)] +pub struct Matrix { + data: Vec, + rows: usize, + cols: usize, +} + +impl Matrix { + pub fn new(data: Vec, rows: usize, cols: usize) -> Self { + // Build a matrix from the internal vector representation + if data.len() != rows * cols { + panic!("Inconsistent data and dimensions combination for matrix") + } + Matrix { data, rows, cols } + } + + pub fn zero(rows: usize, cols: usize) -> Self { + // Build a matrix of zeros + Matrix { + data: vec![0.into(); rows * cols], + rows, + cols, + } + } + + pub fn identity(len: usize) -> Self { + // Build an identity matrix + let mut identity = Matrix::zero(len, len); + // Diagonal of ones + for i in 0..len { + identity[[i, i]] = 1.into(); + } + identity + } + + pub fn transpose(&self) -> Self { + // Transpose a matrix of any size + let mut result = Matrix::zero(self.cols, self.rows); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = self[[j, i]]; + } + } + result + } +} + +impl Index<[usize; 2]> for Matrix { + type Output = T; + + fn index(&self, index: [usize; 2]) -> 
&Self::Output { + let [i, j] = index; + if i >= self.rows || j >= self.cols { + panic!("Matrix index out of bounds"); + } + + &self.data[(self.cols * i) + j] + } +} + +impl IndexMut<[usize; 2]> for Matrix { + fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output { + let [i, j] = index; + if i >= self.rows || j >= self.cols { + panic!("Matrix index out of bounds"); + } + + &mut self.data[(self.cols * i) + j] + } +} + +impl Add<&Matrix> for &Matrix { + type Output = Matrix; + + fn add(self, rhs: &Matrix) -> Self::Output { + // Add two matrices. They need to have identical dimensions. + if self.rows != rhs.rows || self.cols != rhs.cols { + panic!("Matrix dimensions do not match"); + } + + let mut result = Matrix::zero(self.rows, self.cols); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = self[[i, j]] + rhs[[i, j]]; + } + } + result + } +} + +impl Sub for &Matrix { + type Output = Matrix; + + fn sub(self, rhs: Self) -> Self::Output { + // Subtract one matrix from another. They need to have identical dimensions. + if self.rows != rhs.rows || self.cols != rhs.cols { + panic!("Matrix dimensions do not match"); + } + + let mut result = Matrix::zero(self.rows, self.cols); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = self[[i, j]] - rhs[[i, j]]; + } + } + result + } +} + +impl Mul for &Matrix { + type Output = Matrix; + + fn mul(self, rhs: Self) -> Self::Output { + // Multiply two matrices. The multiplier needs to have the same amount + // of columns as the multiplicand has rows. 
+ if self.cols != rhs.rows { + panic!("Matrix dimensions do not match"); + } + + let mut result = Matrix::zero(self.rows, rhs.cols); + for i in 0..self.rows { + for j in 0..rhs.cols { + result[[i, j]] = { + let mut sum = 0.into(); + for k in 0..self.cols { + sum += self[[i, k]] * rhs[[k, j]]; + } + sum + }; + } + } + result + } +} + +impl Mul for &Matrix { + type Output = Matrix; + + fn mul(self, rhs: T) -> Self::Output { + // Multiply a matrix of any size with a scalar + let mut result = Matrix::zero(self.rows, self.cols); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = rhs * self[[i, j]]; + } + } + result + } +} + +impl From>> for Matrix { + fn from(v: Vec>) -> Self { + let rows = v.len(); + let cols = v.first().map_or(0, |row| row.len()); + + // Ensure consistent dimensions + for row in v.iter().skip(1) { + if row.len() != cols { + panic!("Invalid matrix dimensions. Columns must be consistent."); + } + } + if rows != 0 && cols == 0 { + panic!("Invalid matrix dimensions. Multiple empty rows"); + } + + let data = v.into_iter().flat_map(|row| row.into_iter()).collect(); + Self::new(data, rows, cols) + } +} + +#[cfg(test)] +// rustfmt skipped to prevent unformatting matrix definitions to a single line +#[rustfmt::skip] +mod tests { + use super::Matrix; + use std::panic; + + const DELTA: f64 = 1e-3; + + macro_rules! 
assert_f64_eq { + ($a:expr, $b:expr) => { + assert_eq!($a.data.len(), $b.data.len()); + if !$a + .data + .iter() + .zip($b.data.iter()) + .all(|(x, y)| (*x as f64 - *y as f64).abs() < DELTA) + { + panic!(); + } + }; + } + + #[test] + fn test_invalid_matrix() { + let result = panic::catch_unwind(|| matrix![ + [1, 0, 0, 4], + [0, 1, 0], + [0, 0, 1], + ]); + assert!(result.is_err()); + } + + #[test] + fn test_empty_matrix() { + let a: Matrix = matrix![]; + + let result = panic::catch_unwind(|| a[[0, 0]]); + assert!(result.is_err()); + } + + #[test] + fn test_zero_matrix() { + let a: Matrix = Matrix::zero(3, 5); + + let z = matrix![ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ]; + + assert_f64_eq!(a, z); + } + + #[test] + fn test_identity_matrix() { + let a: Matrix = Matrix::identity(5); + + let id = matrix![ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]; + + assert_f64_eq!(a, id); + } + + #[test] + fn test_invalid_add() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1] + ]; + + let err = matrix![ + [1, 2], + [2, 4], + ]; + + let result = panic::catch_unwind(|| &a + &err); + assert!(result.is_err()); + } + + #[test] + fn test_add_i32() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1] + ]; + + let b = matrix![ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ]; + + let add = matrix![ + [2, 0, 1], + [0, 3, 0], + [5, 0, 2], + ]; + + assert_eq!(&a + &b, add); + } + + #[test] + fn test_add_f64() { + let a = matrix![ + [1.0, 2.0, 1.0], + [3.0, 2.0, 0.0], + [5.0, 0.0, 1.0], + ]; + + let b = matrix![ + [1.0, 10.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ]; + + let add = matrix![ + [2.0, 12.0, 1.0], + [3.0, 3.0, 0.0], + [5.0, 0.0, 2.0], + ]; + + assert_f64_eq!(&a + &b, add); + } + + #[test] + fn test_invalid_sub() { + let a = matrix![ + [2, 3], + [10, 2], + ]; + + let err = matrix![ + [5, 6, 10], + [7, 2, 2], + [12, 
0, 1], + ]; + + let result = panic::catch_unwind(|| &a - &err); + assert!(result.is_err()); + } + + #[test] + fn test_subtract_i32() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1], + ]; + + let b = matrix![ + [1, 0, 0], + [0, 1, 3], + [0, 0, 1], + ]; + + let sub = matrix![ + [0, 0, 1], + [0, 1, -3], + [5, 0, 0], + ]; + + assert_eq!(&a - &b, sub); + } + + #[test] + fn test_subtract_f64() { + let a = matrix![ + [7.0, 2.0, 1.0], + [0.0, 3.0, 2.0], + [5.3, 8.8, std::f64::consts::PI], + ]; + + let b = matrix![ + [1.0, 0.0, 5.0], + [-2.0, 1.0, 3.0], + [0.0, 2.2, std::f64::consts::PI], + ]; + + let sub = matrix![ + [6.0, 2.0, -4.0], + [2.0, 2.0, -1.0], + [5.3, 6.6, 0.0], + ]; + + assert_f64_eq!(&a - &b, sub); + } + + #[test] + fn test_invalid_mul() { + let a = matrix![ + [1, 2, 3], + [4, 2, 6], + [3, 4, 1], + [2, 4, 8], + ]; + + let err = matrix![ + [1, 3, 3, 2], + [7, 6, 2, 1], + [3, 4, 2, 1], + [3, 4, 2, 1], + ]; + + let result = panic::catch_unwind(|| &a * &err); + assert!(result.is_err()); + } + + #[test] + fn test_mul_i32() { + let a = matrix![ + [1, 2, 3], + [4, 2, 6], + [3, 4, 1], + [2, 4, 8], + ]; + + let b = matrix![ + [1, 3, 3, 2], + [7, 6, 2, 1], + [3, 4, 2, 1], + ]; + + let mul = matrix![ + [24, 27, 13, 7], + [36, 48, 28, 16], + [34, 37, 19, 11], + [54, 62, 30, 16], + ]; + + assert_eq!(&a * &b, mul); + } + + #[test] + fn test_mul_f64() { + let a = matrix![ + [5.5, 2.9, 1.13, 9.0], + [0.0, 3.0, 11.0, 17.2], + [5.3, 8.8, 2.76, 3.3], + ]; + + let b = matrix![ + [1.0, 0.3, 5.0], + [-2.0, 1.0, 3.0], + [-3.6, 1.5, 3.0], + [0.0, 2.2, 2.0], + ]; + + let mul = matrix![ + [-4.368, 26.045, 57.59], + [-45.6, 57.34, 76.4], + [-22.236, 21.79, 67.78], + ]; + + assert_f64_eq!(&a * &b, mul); + } + + #[test] + fn test_transpose_i32() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1], + ]; + + let t = matrix![ + [1, 0, 5], + [0, 2, 0], + [1, 0, 1], + ]; + + assert_eq!(a.transpose(), t); + } + + #[test] + fn test_transpose_f64() { + let a = matrix![ + [3.0, 
4.0, 2.0], + [0.0, 1.0, 3.0], + [3.0, 1.0, 1.0], + ]; + + let t = matrix![ + [3.0, 0.0, 3.0], + [4.0, 1.0, 1.0], + [2.0, 3.0, 1.0], + ]; + + assert_eq!(a.transpose(), t); + } + + #[test] + fn test_matrix_scalar_zero_mul() { + let a = matrix![ + [3, 2, 2], + [0, 2, 0], + [5, 4, 1], + ]; + + let scalar = 0; + + let scalar_mul = Matrix::zero(3, 3); + + assert_eq!(scalar * &a, scalar_mul); + } + + #[test] + fn test_matrix_scalar_mul_i32() { + let a = matrix![ + [3, 2, 2], + [0, 2, 0], + [5, 4, 1], + ]; + + let scalar = 3; + + let scalar_mul = matrix![ + [9, 6, 6], + [0, 6, 0], + [15, 12, 3], + ]; + + assert_eq!(scalar * &a, scalar_mul); + } + + #[test] + fn test_matrix_scalar_mul_f64() { + let a = matrix![ + [3.2, 5.5, 9.2], + [1.1, 0.0, 2.3], + [0.3, 4.2, 0.0], + ]; + + let scalar = 1.5_f64; + + let scalar_mul = matrix![ + [4.8, 8.25, 13.8], + [1.65, 0.0, 3.45], + [0.45, 6.3, 0.0], + ]; + + assert_f64_eq!(scalar * &a, scalar_mul); + } +} diff --git a/src/math/mersenne_primes.rs b/src/math/mersenne_primes.rs new file mode 100644 index 00000000000..f66b33898c0 --- /dev/null +++ b/src/math/mersenne_primes.rs @@ -0,0 +1,39 @@ +// mersenne prime : https://en.wikipedia.org/wiki/Mersenne_prime +pub fn is_mersenne_prime(n: usize) -> bool { + if n == 2 { + return true; + } + let mut s = 4; + let m = 2_usize.pow(std::convert::TryInto::try_into(n).unwrap()) - 1; + for _ in 0..n - 2 { + s = ((s * s) - 2) % m; + } + s == 0 +} + +pub fn get_mersenne_primes(limit: usize) -> Vec { + let mut results: Vec = Vec::new(); + for num in 1..=limit { + if is_mersenne_prime(num) { + results.push(num); + } + } + results +} + +#[cfg(test)] +mod tests { + use super::{get_mersenne_primes, is_mersenne_prime}; + + #[test] + fn validity_check() { + assert!(is_mersenne_prime(3)); + assert!(is_mersenne_prime(13)); + assert!(!is_mersenne_prime(32)); + } + + #[allow(dead_code)] + fn generation_check() { + assert_eq!(get_mersenne_primes(30), [2, 3, 5, 7, 13, 17, 19]); + } +} diff --git 
a/src/math/miller_rabin.rs b/src/math/miller_rabin.rs new file mode 100644 index 00000000000..dbeeac5acbd --- /dev/null +++ b/src/math/miller_rabin.rs @@ -0,0 +1,237 @@ +use num_bigint::BigUint; +use num_traits::{One, ToPrimitive, Zero}; +use std::cmp::Ordering; + +fn modulo_power(mut base: u64, mut power: u64, modulo: u64) -> u64 { + base %= modulo; + if base == 0 { + return 0; // return zero if base is divisible by modulo + } + let mut ans: u128 = 1; + let mut bbase: u128 = base as u128; + while power > 0 { + if (power % 2) == 1 { + ans = (ans * bbase) % (modulo as u128); + } + bbase = (bbase * bbase) % (modulo as u128); + power /= 2; + } + ans as u64 +} + +fn check_prime_base(number: u64, base: u64, two_power: u64, odd_power: u64) -> bool { + // returns false if base is a witness + let mut x: u128 = modulo_power(base, odd_power, number) as u128; + let bnumber: u128 = number as u128; + if x == 1 || x == (bnumber - 1) { + return true; + } + for _ in 1..two_power { + x = (x * x) % bnumber; + if x == (bnumber - 1) { + return true; + } + } + false +} + +pub fn miller_rabin(number: u64, bases: &[u64]) -> u64 { + // returns zero on a probable prime, and a witness if number is not prime + // A base set for deterministic performance on 64 bit numbers is: + // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37] + // another one for 32 bits: + // [2, 3, 5, 7], with smallest number to fail 3'215'031'751 = 151 * 751 * 28351 + // note that all bases should be prime + if number <= 4 { + match number { + 0 => { + panic!("0 is invalid input for Miller-Rabin. 
0 is not prime by definition, but has no witness"); + } + 2 | 3 => return 0, + _ => return number, + } + } + if bases.contains(&number) { + return 0; + } + let two_power: u64 = (number - 1).trailing_zeros() as u64; + let odd_power = (number - 1) >> two_power; + for base in bases { + if !check_prime_base(number, *base, two_power, odd_power) { + return *base; + } + } + 0 +} + +pub fn big_miller_rabin(number_ref: &BigUint, bases: &[u64]) -> u64 { + let number = number_ref.clone(); + + if BigUint::from(5u32).cmp(&number) == Ordering::Greater { + if number.eq(&BigUint::zero()) { + panic!("0 is invalid input for Miller-Rabin. 0 is not prime by definition, but has no witness"); + } else if number.eq(&BigUint::from(2u32)) || number.eq(&BigUint::from(3u32)) { + return 0; + } else { + return number.to_u64().unwrap(); + } + } + + if let Some(num) = number.to_u64() { + if bases.contains(&num) { + return 0; + } + } + + let num_minus_one = &number - BigUint::one(); + + let two_power: u64 = num_minus_one.trailing_zeros().unwrap(); + let odd_power: BigUint = &num_minus_one >> two_power; + for base in bases { + let mut x = BigUint::from(*base).modpow(&odd_power, &number); + + if x.eq(&BigUint::one()) || x.eq(&num_minus_one) { + continue; + } + + let mut not_a_witness = false; + + for _ in 1..two_power { + x = (&x * &x) % &number; + if x.eq(&num_minus_one) { + not_a_witness = true; + break; + } + } + + if not_a_witness { + continue; + } + + return *base; + } + + 0 +} + +#[cfg(test)] +mod tests { + use super::*; + + static DEFAULT_BASES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; + + #[test] + fn basic() { + // these bases make miller rabin deterministic for any number < 2 ^ 64 + // can use smaller number of bases for deterministic performance for numbers < 2 ^ 32 + + assert_eq!(miller_rabin(3, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(7, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(11, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(2003, &DEFAULT_BASES), 0); + 
+ assert_ne!(miller_rabin(1, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(4, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(6, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(21, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(2004, &DEFAULT_BASES), 0); + + // bigger test cases. + // primes are generated using openssl + // non primes are randomly picked and checked using openssl + + // primes: + assert_eq!(miller_rabin(3629611793, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(871594686869, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(968236663804121, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(6920153791723773023, &DEFAULT_BASES), 0); + + // random non primes: + assert_ne!(miller_rabin(4546167556336341257, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(4363186415423517377, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(815479701131020226, &DEFAULT_BASES), 0); + // these two are made of two 31 bit prime factors: + // 1950202127 * 2058609037 = 4014703722618821699 + assert_ne!(miller_rabin(4014703722618821699, &DEFAULT_BASES), 0); + // 1679076769 * 2076341633 = 3486337000477823777 + assert_ne!(miller_rabin(3486337000477823777, &DEFAULT_BASES), 0); + } + + #[test] + fn big_basic() { + assert_eq!(big_miller_rabin(&BigUint::from(3u32), &DEFAULT_BASES), 0); + assert_eq!(big_miller_rabin(&BigUint::from(7u32), &DEFAULT_BASES), 0); + assert_eq!(big_miller_rabin(&BigUint::from(11u32), &DEFAULT_BASES), 0); + assert_eq!(big_miller_rabin(&BigUint::from(2003u32), &DEFAULT_BASES), 0); + + assert_ne!(big_miller_rabin(&BigUint::from(1u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(4u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(6u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(21u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(2004u32), &DEFAULT_BASES), 0); + + assert_eq!( + big_miller_rabin(&BigUint::from(3629611793u64), &DEFAULT_BASES), + 0 + ); + assert_eq!( + 
big_miller_rabin(&BigUint::from(871594686869u64), &DEFAULT_BASES), + 0 + ); + assert_eq!( + big_miller_rabin(&BigUint::from(968236663804121u64), &DEFAULT_BASES), + 0 + ); + assert_eq!( + big_miller_rabin(&BigUint::from(6920153791723773023u64), &DEFAULT_BASES), + 0 + ); + + assert_ne!( + big_miller_rabin(&BigUint::from(4546167556336341257u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(4363186415423517377u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(815479701131020226u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(4014703722618821699u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(3486337000477823777u64), &DEFAULT_BASES), + 0 + ); + } + + #[test] + #[ignore] + fn big_primes() { + let p1 = + BigUint::parse_bytes(b"4764862697132131451620315518348229845593592794669", 10).unwrap(); + assert_eq!(big_miller_rabin(&p1, &DEFAULT_BASES), 0); + + let p2 = BigUint::parse_bytes( + b"12550757946601963214089118080443488976766669415957018428703", + 10, + ) + .unwrap(); + assert_eq!(big_miller_rabin(&p2, &DEFAULT_BASES), 0); + + // An RSA-worthy prime + let p3 = BigUint::parse_bytes(b"157d6l5zkv45ve4azfw7nyyjt6rzir2gcjoytjev5iacnkaii8hlkyk3op7bx9qfqiie23vj9iw4qbp7zupydfq9ut6mq6m36etya6cshtqi1yi9q5xyiws92el79dqt8qk7l2pqmxaa0sxhmd2vpaibo9dkfd029j1rvkwlw4724ctgaqs5jzy0bqi5pqdjc2xerhn", 36).unwrap(); + assert_eq!(big_miller_rabin(&p3, &DEFAULT_BASES), 0); + + let n1 = BigUint::parse_bytes(b"coy6tkiaqswmce1r03ycdif3t796wzjwneewbe3cmncaplm85jxzcpdmvy0moic3lql70a81t5qdn2apac0dndhohewkspuk1wyndxsgxs3ux4a7730unru7dfmygh", 36).unwrap(); + assert_ne!(big_miller_rabin(&n1, &DEFAULT_BASES), 0); + + // RSA-2048 + let n2 = 
BigUint::parse_bytes(b"4l91lq4a2sgekpv8ukx1gxsk7mfeks46haggorlkazm0oufxwijid6q6v44u5me3kz3ne6yczp4fcvo62oej72oe7pjjtyxgid5b8xdz1e8daafspbzcy1hd8i4urjh9hm0tyylsgqsss3jn372d6fmykpw4bb9cr1ngxnncsbod3kg49o7owzqnsci5pwqt8bch0t60gq0st2gyx7ii3mzhb1pp1yvjyor35hwvok1sxj3ih46rpd27li8y5yli3mgdttcn65k3szfa6rbcnbgkojqjjq72gar6raslnh6sjd2fy7yj3bwo43obvbg3ws8y28kpol3okb5b3fld03sq1kgrj2fugiaxgplva6x5ssilqq4g0b21xy2kiou3sqsgonmqx55v", 36).unwrap(); + assert_ne!(big_miller_rabin(&n2, &DEFAULT_BASES), 0); + } +} diff --git a/src/math/mod.rs b/src/math/mod.rs new file mode 100644 index 00000000000..7407465c3b0 --- /dev/null +++ b/src/math/mod.rs @@ -0,0 +1,183 @@ +mod abs; +mod aliquot_sum; +mod amicable_numbers; +mod area_of_polygon; +mod area_under_curve; +mod armstrong_number; +mod average; +mod baby_step_giant_step; +mod bell_numbers; +mod binary_exponentiation; +mod binomial_coefficient; +mod catalan_numbers; +mod ceil; +mod chinese_remainder_theorem; +mod collatz_sequence; +mod combinations; +mod cross_entropy_loss; +mod decimal_to_fraction; +mod doomsday; +mod elliptic_curve; +mod euclidean_distance; +mod exponential_linear_unit; +mod extended_euclidean_algorithm; +pub mod factorial; +mod factors; +mod fast_fourier_transform; +mod fast_power; +mod faster_perfect_numbers; +mod field; +mod frizzy_number; +mod gaussian_elimination; +mod gaussian_error_linear_unit; +mod gcd_of_n_numbers; +mod geometric_series; +mod greatest_common_divisor; +mod huber_loss; +mod infix_to_postfix; +mod interest; +mod interpolation; +mod interquartile_range; +mod karatsuba_multiplication; +mod lcm_of_n_numbers; +mod leaky_relu; +mod least_square_approx; +mod linear_sieve; +mod logarithm; +mod lucas_series; +mod matrix_ops; +mod mersenne_primes; +mod miller_rabin; +mod modular_exponential; +mod newton_raphson; +mod nthprime; +mod pascal_triangle; +mod perfect_cube; +mod perfect_numbers; +mod perfect_square; +mod pollard_rho; +mod postfix_evaluation; +mod prime_check; +mod prime_factors; +mod 
prime_numbers; +mod quadratic_residue; +mod random; +mod relu; +mod sieve_of_eratosthenes; +mod sigmoid; +mod signum; +mod simpsons_integration; +mod softmax; +mod sprague_grundy_theorem; +mod square_pyramidal_numbers; +mod square_root; +mod sum_of_digits; +mod sum_of_geometric_progression; +mod sum_of_harmonic_series; +mod sylvester_sequence; +mod tanh; +mod trapezoidal_integration; +mod trial_division; +mod trig_functions; +mod vector_cross_product; +mod zellers_congruence_algorithm; + +pub use self::abs::abs; +pub use self::aliquot_sum::aliquot_sum; +pub use self::amicable_numbers::amicable_pairs_under_n; +pub use self::area_of_polygon::area_of_polygon; +pub use self::area_under_curve::area_under_curve; +pub use self::armstrong_number::is_armstrong_number; +pub use self::average::{mean, median, mode}; +pub use self::baby_step_giant_step::baby_step_giant_step; +pub use self::bell_numbers::bell_number; +pub use self::binary_exponentiation::binary_exponentiation; +pub use self::binomial_coefficient::binom; +pub use self::catalan_numbers::init_catalan; +pub use self::ceil::ceil; +pub use self::chinese_remainder_theorem::chinese_remainder_theorem; +pub use self::collatz_sequence::sequence; +pub use self::combinations::combinations; +pub use self::cross_entropy_loss::cross_entropy_loss; +pub use self::decimal_to_fraction::decimal_to_fraction; +pub use self::doomsday::get_week_day; +pub use self::elliptic_curve::EllipticCurve; +pub use self::euclidean_distance::euclidean_distance; +pub use self::exponential_linear_unit::exponential_linear_unit; +pub use self::extended_euclidean_algorithm::extended_euclidean_algorithm; +pub use self::factorial::{factorial, factorial_bigmath, factorial_recursive}; +pub use self::factors::factors; +pub use self::fast_fourier_transform::{ + fast_fourier_transform, fast_fourier_transform_input_permutation, + inverse_fast_fourier_transform, +}; +pub use self::fast_power::fast_power; +pub use 
self::faster_perfect_numbers::generate_perfect_numbers; +pub use self::field::{Field, PrimeField}; +pub use self::frizzy_number::get_nth_frizzy; +pub use self::gaussian_elimination::gaussian_elimination; +pub use self::gaussian_error_linear_unit::gaussian_error_linear_unit; +pub use self::gcd_of_n_numbers::gcd; +pub use self::geometric_series::geometric_series; +pub use self::greatest_common_divisor::{ + greatest_common_divisor_iterative, greatest_common_divisor_recursive, + greatest_common_divisor_stein, +}; +pub use self::huber_loss::huber_loss; +pub use self::infix_to_postfix::infix_to_postfix; +pub use self::interest::{compound_interest, simple_interest}; +pub use self::interpolation::{lagrange_polynomial_interpolation, linear_interpolation}; +pub use self::interquartile_range::interquartile_range; +pub use self::karatsuba_multiplication::multiply; +pub use self::lcm_of_n_numbers::lcm; +pub use self::leaky_relu::leaky_relu; +pub use self::least_square_approx::least_square_approx; +pub use self::linear_sieve::LinearSieve; +pub use self::logarithm::log; +pub use self::lucas_series::dynamic_lucas_number; +pub use self::lucas_series::recursive_lucas_number; +pub use self::matrix_ops::Matrix; +pub use self::mersenne_primes::{get_mersenne_primes, is_mersenne_prime}; +pub use self::miller_rabin::{big_miller_rabin, miller_rabin}; +pub use self::modular_exponential::{mod_inverse, modular_exponential}; +pub use self::newton_raphson::find_root; +pub use self::nthprime::nthprime; +pub use self::pascal_triangle::pascal_triangle; +pub use self::perfect_cube::perfect_cube_binary_search; +pub use self::perfect_numbers::perfect_numbers; +pub use self::perfect_square::perfect_square; +pub use self::perfect_square::perfect_square_binary_search; +pub use self::pollard_rho::{pollard_rho_factorize, pollard_rho_get_one_factor}; +pub use self::postfix_evaluation::evaluate_postfix; +pub use self::prime_check::prime_check; +pub use self::prime_factors::prime_factors; +pub use 
self::prime_numbers::prime_numbers; +pub use self::quadratic_residue::{cipolla, tonelli_shanks}; +pub use self::random::PCG32; +pub use self::relu::relu; +pub use self::sieve_of_eratosthenes::sieve_of_eratosthenes; +pub use self::sigmoid::sigmoid; +pub use self::signum::signum; +pub use self::simpsons_integration::simpsons_integration; +pub use self::softmax::softmax; +pub use self::sprague_grundy_theorem::calculate_grundy_number; +pub use self::square_pyramidal_numbers::square_pyramidal_number; +pub use self::square_root::{fast_inv_sqrt, square_root}; +pub use self::sum_of_digits::{sum_digits_iterative, sum_digits_recursive}; +pub use self::sum_of_geometric_progression::sum_of_geometric_progression; +pub use self::sum_of_harmonic_series::sum_of_harmonic_progression; +pub use self::sylvester_sequence::sylvester; +pub use self::tanh::tanh; +pub use self::trapezoidal_integration::trapezoidal_integral; +pub use self::trial_division::trial_division; +pub use self::trig_functions::cosine; +pub use self::trig_functions::cosine_no_radian_arg; +pub use self::trig_functions::cotan; +pub use self::trig_functions::cotan_no_radian_arg; +pub use self::trig_functions::sine; +pub use self::trig_functions::sine_no_radian_arg; +pub use self::trig_functions::tan; +pub use self::trig_functions::tan_no_radian_arg; +pub use self::vector_cross_product::cross_product; +pub use self::vector_cross_product::vector_magnitude; +pub use self::zellers_congruence_algorithm::zellers_congruence_algorithm; diff --git a/src/math/modular_exponential.rs b/src/math/modular_exponential.rs new file mode 100644 index 00000000000..1e9a9f41cac --- /dev/null +++ b/src/math/modular_exponential.rs @@ -0,0 +1,136 @@ +/// Calculate the greatest common divisor (GCD) of two numbers and the +/// coefficients of Bézout's identity using the Extended Euclidean Algorithm. 
+/// +/// # Arguments +/// +/// * `a` - One of the numbers to find the GCD of +/// * `m` - The other number to find the GCD of +/// +/// # Returns +/// +/// A tuple (gcd, x1, x2) such that: +/// gcd - the greatest common divisor of a and m. +/// x1, x2 - the coefficients such that `a * x1 + m * x2` is equivalent to `gcd` modulo `m`. +pub fn gcd_extended(a: i64, m: i64) -> (i64, i64, i64) { + if a == 0 { + (m, 0, 1) + } else { + let (gcd, x1, x2) = gcd_extended(m % a, a); + let x = x2 - (m / a) * x1; + (gcd, x, x1) + } +} + +/// Find the modular multiplicative inverse of a number modulo `m`. +/// +/// # Arguments +/// +/// * `b` - The number to find the modular inverse of +/// * `m` - The modulus +/// +/// # Returns +/// +/// The modular inverse of `b` modulo `m`. +/// +/// # Panics +/// +/// Panics if the inverse does not exist (i.e., `b` and `m` are not coprime). +pub fn mod_inverse(b: i64, m: i64) -> i64 { + let (gcd, x, _) = gcd_extended(b, m); + if gcd != 1 { + panic!("Inverse does not exist"); + } else { + // Ensure the modular inverse is positive + (x % m + m) % m + } +} + +/// Perform modular exponentiation of a number raised to a power modulo `m`. +/// This function handles both positive and negative exponents. +/// +/// # Arguments +/// +/// * `base` - The base number to be raised to the `power` +/// * `power` - The exponent to raise the `base` to +/// * `modulus` - The modulus to perform the operation under +/// +/// # Returns +/// +/// The result of `base` raised to `power` modulo `modulus`. 
+pub fn modular_exponential(base: i64, mut power: i64, modulus: i64) -> i64 { + if modulus == 1 { + return 0; // Base case: any number modulo 1 is 0 + } + + // Adjust if the exponent is negative by finding the modular inverse + let mut base = if power < 0 { + mod_inverse(base, modulus) + } else { + base % modulus + }; + + let mut result = 1; // Initialize result + power = power.abs(); // Work with the absolute value of the exponent + + // Perform the exponentiation + while power > 0 { + if power & 1 == 1 { + result = (result * base) % modulus; + } + power >>= 1; // Divide the power by 2 + base = (base * base) % modulus; // Square the base + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_modular_exponential_positive() { + assert_eq!(modular_exponential(2, 3, 5), 3); // 2^3 % 5 = 8 % 5 = 3 + assert_eq!(modular_exponential(7, 2, 13), 10); // 7^2 % 13 = 49 % 13 = 10 + assert_eq!(modular_exponential(5, 5, 31), 25); // 5^5 % 31 = 3125 % 31 = 25 + assert_eq!(modular_exponential(10, 8, 11), 1); // 10^8 % 11 = 100000000 % 11 = 1 + assert_eq!(modular_exponential(123, 45, 67), 62); // 123^45 % 67 + } + + #[test] + fn test_modular_inverse() { + assert_eq!(mod_inverse(7, 13), 2); // Inverse of 7 mod 13 is 2 + assert_eq!(mod_inverse(5, 31), 25); // Inverse of 5 mod 31 is 25 + assert_eq!(mod_inverse(10, 11), 10); // Inverse of 10 mod 11 is 10 + assert_eq!(mod_inverse(123, 67), 6); // Inverse of 123 mod 67 is 6 + assert_eq!(mod_inverse(9, 17), 2); // Inverse of 9 mod 17 is 2 + } + + #[test] + fn test_modular_exponential_negative() { + assert_eq!( + modular_exponential(7, -2, 13), + mod_inverse(7, 13).pow(2) % 13 + ); // Inverse of 7 mod 13 is 2, 2^2 % 13 = 4 % 13 = 4 + assert_eq!( + modular_exponential(5, -5, 31), + mod_inverse(5, 31).pow(5) % 31 + ); // Inverse of 5 mod 31 is 25, 25^5 % 31 = 5 + assert_eq!( + modular_exponential(10, -8, 11), + mod_inverse(10, 11).pow(8) % 11 + ); // Inverse of 10 mod 11 is 10, 10^8 % 11 = 1 + assert_eq!( +
modular_exponential(123, -5, 67), + mod_inverse(123, 67).pow(5) % 67 + ); // Inverse of 123 mod 67 is calculated via the function + } + + #[test] + fn test_modular_exponential_edge_cases() { + assert_eq!(modular_exponential(0, 0, 1), 0); // 0^0 % 1 should be 0 as the modulus is 1 + assert_eq!(modular_exponential(0, 10, 1), 0); // 0^n % 1 should be 0 for any n + assert_eq!(modular_exponential(10, 0, 1), 0); // n^0 % 1 should be 0 for any n + assert_eq!(modular_exponential(1, 1, 1), 0); // 1^1 % 1 should be 0 + assert_eq!(modular_exponential(-1, 2, 1), 0); // (-1)^2 % 1 should be 0 + } +} diff --git a/src/math/newton_raphson.rs b/src/math/newton_raphson.rs new file mode 100644 index 00000000000..5a21ef625a9 --- /dev/null +++ b/src/math/newton_raphson.rs @@ -0,0 +1,27 @@ +pub fn find_root(f: fn(f64) -> f64, fd: fn(f64) -> f64, guess: f64, iterations: i32) -> f64 { + let mut result = guess; + for _ in 0..iterations { + result = iteration(f, fd, result); + } + result +} + +pub fn iteration(f: fn(f64) -> f64, fd: fn(f64) -> f64, guess: f64) -> f64 { + guess - f(guess) / fd(guess) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn math_fn(x: f64) -> f64 { + x.cos() - (x * x * x) + } + fn math_fnd(x: f64) -> f64 { + -x.sin() - 3.0 * (x * x) + } + #[test] + fn basic() { + assert_eq!(find_root(math_fn, math_fnd, 0.5, 6), 0.8654740331016144); + } +} diff --git a/src/math/nthprime.rs b/src/math/nthprime.rs new file mode 100644 index 00000000000..1b0e93c855b --- /dev/null +++ b/src/math/nthprime.rs @@ -0,0 +1,58 @@ +// Generate the nth prime number. +// Algorithm is inspired by the optimized version of the Sieve of Eratosthenes. 
+pub fn nthprime(nth: u64) -> u64 { + let mut total_prime: u64 = 0; + let mut size_factor: u64 = 2; + + let mut s: u64 = nth * size_factor; + let mut primes: Vec = Vec::new(); + + let n: u64 = nth; + + while total_prime < n { + primes = get_primes(s).to_vec(); + + total_prime = primes[2..].iter().sum(); + size_factor += 1; + s = n * size_factor; + } + + count_prime(primes, n).unwrap() +} + +fn get_primes(s: u64) -> Vec { + let mut v: Vec = vec![1; s as usize]; + + for index in 2..s { + if v[index as usize] == 1 { + for j in index..s { + if index * j < s { + v[(index * j) as usize] = 0; + } else { + break; + } + } + } + } + v +} + +fn count_prime(primes: Vec, n: u64) -> Option { + let mut counter: u64 = 0; + for (i, prime) in primes.iter().enumerate().skip(2) { + counter += prime; + if counter == n { + return Some(i as u64); + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn my_test() { + assert_eq!(nthprime(100), 541u64); + } +} diff --git a/src/math/pascal_triangle.rs b/src/math/pascal_triangle.rs new file mode 100644 index 00000000000..34643029b6b --- /dev/null +++ b/src/math/pascal_triangle.rs @@ -0,0 +1,52 @@ +/// ## Pascal's triangle problem +/// +/// pascal_triangle(num_rows) returns the first num_rows of Pascal's triangle.\ +/// About Pascal's triangle: +/// +/// # Arguments: +/// * `num_rows`: number of rows of triangle +/// +/// # Complexity +/// - time complexity: O(n^2), +/// - space complexity: O(n^2), +pub fn pascal_triangle(num_rows: i32) -> Vec> { + let mut ans: Vec> = vec![]; + + for i in 1..=num_rows { + let mut vec: Vec = vec![1]; + + let mut res: i32 = 1; + for k in 1..i { + res *= i - k; + res /= k; + vec.push(res); + } + ans.push(vec); + } + + ans +} + +#[cfg(test)] +mod tests { + use super::pascal_triangle; + + #[test] + fn test() { + assert_eq!(pascal_triangle(3), vec![vec![1], vec![1, 1], vec![1, 2, 1]]); + assert_eq!( + pascal_triangle(4), + vec![vec![1], vec![1, 1], vec![1, 2, 1], vec![1, 3, 3, 1]] + ); + 
assert_eq!( + pascal_triangle(5), + vec![ + vec![1], + vec![1, 1], + vec![1, 2, 1], + vec![1, 3, 3, 1], + vec![1, 4, 6, 4, 1] + ] + ); + } +} diff --git a/src/math/perfect_cube.rs b/src/math/perfect_cube.rs new file mode 100644 index 00000000000..d4f2c7becef --- /dev/null +++ b/src/math/perfect_cube.rs @@ -0,0 +1,61 @@ +// Check if a number is a perfect cube using binary search. +pub fn perfect_cube_binary_search(n: i64) -> bool { + if n < 0 { + return perfect_cube_binary_search(-n); + } + + // Initialize left and right boundaries for binary search. + let mut left = 0; + let mut right = n.abs(); // Use the absolute value to handle negative numbers + + // Binary search loop to find the cube root. + while left <= right { + // Calculate the mid-point. + let mid = left + (right - left) / 2; + // Calculate the cube of the mid-point. + let cube = mid * mid * mid; + + // Check if the cube equals the original number. + match cube.cmp(&n) { + std::cmp::Ordering::Equal => return true, + std::cmp::Ordering::Less => left = mid + 1, + std::cmp::Ordering::Greater => right = mid - 1, + } + } + + // If no cube root is found, return false. + false +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_perfect_cube { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected) = $inputs; + assert_eq!(perfect_cube_binary_search(n), expected); + assert_eq!(perfect_cube_binary_search(-n), expected); + } + )* + } + } + + test_perfect_cube! 
{ + num_0_a_perfect_cube: (0, true), + num_1_is_a_perfect_cube: (1, true), + num_27_is_a_perfect_cube: (27, true), + num_64_is_a_perfect_cube: (64, true), + num_8_is_a_perfect_cube: (8, true), + num_2_is_not_a_perfect_cube: (2, false), + num_3_is_not_a_perfect_cube: (3, false), + num_4_is_not_a_perfect_cube: (4, false), + num_5_is_not_a_perfect_cube: (5, false), + num_999_is_not_a_perfect_cube: (999, false), + num_1000_is_a_perfect_cube: (1000, true), + num_1001_is_not_a_perfect_cube: (1001, false), + } +} diff --git a/src/math/perfect_numbers.rs b/src/math/perfect_numbers.rs new file mode 100644 index 00000000000..0d819d2b2f1 --- /dev/null +++ b/src/math/perfect_numbers.rs @@ -0,0 +1,47 @@ +pub fn is_perfect_number(num: usize) -> bool { + let mut sum = 0; + + for i in 1..num - 1 { + if num % i == 0 { + sum += i; + } + } + + num == sum +} + +pub fn perfect_numbers(max: usize) -> Vec { + let mut result: Vec = Vec::new(); + + // It is not known if there are any odd perfect numbers, so we go around all the numbers. 
+ for i in 1..=max { + if is_perfect_number(i) { + result.push(i); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + assert!(is_perfect_number(6)); + assert!(is_perfect_number(28)); + assert!(is_perfect_number(496)); + assert!(is_perfect_number(8128)); + + assert!(!is_perfect_number(5)); + assert!(!is_perfect_number(86)); + assert!(!is_perfect_number(497)); + assert!(!is_perfect_number(8120)); + + assert_eq!(perfect_numbers(10), vec![6]); + assert_eq!(perfect_numbers(100), vec![6, 28]); + assert_eq!(perfect_numbers(496), vec![6, 28, 496]); + assert_eq!(perfect_numbers(1000), vec![6, 28, 496]); + } +} diff --git a/src/math/perfect_square.rs b/src/math/perfect_square.rs new file mode 100644 index 00000000000..7b0f69976c4 --- /dev/null +++ b/src/math/perfect_square.rs @@ -0,0 +1,57 @@ +// Author : cyrixninja +// Perfect Square : Checks if a number is perfect square number or not +// https://en.wikipedia.org/wiki/Perfect_square +pub fn perfect_square(num: i32) -> bool { + if num < 0 { + return false; + } + let sqrt_num = (num as f64).sqrt() as i32; + sqrt_num * sqrt_num == num +} + +pub fn perfect_square_binary_search(n: i32) -> bool { + if n < 0 { + return false; + } + + let mut left = 0; + let mut right = n; + + while left <= right { + let mid = i32::midpoint(left, right); + let mid_squared = mid * mid; + + match mid_squared.cmp(&n) { + std::cmp::Ordering::Equal => return true, + std::cmp::Ordering::Greater => right = mid - 1, + std::cmp::Ordering::Less => left = mid + 1, + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_perfect_square() { + assert!(perfect_square(9)); + assert!(perfect_square(81)); + assert!(perfect_square(4)); + assert!(perfect_square(0)); + assert!(!perfect_square(3)); + assert!(!perfect_square(-19)); + } + + #[test] + fn test_perfect_square_binary_search() { + assert!(perfect_square_binary_search(9)); + assert!(perfect_square_binary_search(81)); + 
assert!(perfect_square_binary_search(4)); + assert!(perfect_square_binary_search(0)); + assert!(!perfect_square_binary_search(3)); + assert!(!perfect_square_binary_search(-19)); + } +} diff --git a/src/math/pollard_rho.rs b/src/math/pollard_rho.rs new file mode 100644 index 00000000000..1ba7481e989 --- /dev/null +++ b/src/math/pollard_rho.rs @@ -0,0 +1,279 @@ +use super::miller_rabin; + +struct LinearCongruenceGenerator { + // modulus as 2 ^ 32 + multiplier: u32, + increment: u32, + state: u32, +} + +impl LinearCongruenceGenerator { + fn new(multiplier: u32, increment: u32, state: u32) -> Self { + Self { + multiplier, + increment, + state, + } + } + fn next(&mut self) -> u32 { + self.state = (self.multiplier as u64 * self.state as u64 + self.increment as u64) as u32; + self.state + } + fn get_64bits(&mut self) -> u64 { + ((self.next() as u64) << 32) | (self.next() as u64) + } +} + +fn gcd(mut a: u64, mut b: u64) -> u64 { + while a != 0 { + let tmp = b % a; + b = a; + a = tmp; + } + b +} + +#[inline] +fn advance(x: u128, c: u64, number: u64) -> u128 { + ((x * x) + c as u128) % number as u128 +} + +fn pollard_rho_customizable( + number: u64, + x0: u64, + c: u64, + iterations_before_check: u32, + iterations_cutoff: u32, +) -> u64 { + /* + Here we are using Brent's method for finding cycle. + It is generally faster because we will not use `advance` function as often + as Floyd's method. + We also wait to do a few iterations before calculating the GCD, because + it is an expensive function. We will correct for overshooting later. 
+ This function may return either 1, `number` or a proper divisor of `number` + */ + let mut x = x0 as u128; // tortoise + let mut x_start = 0_u128; // to save the starting tortoise if we overshoot + let mut y = 0_u128; // hare + let mut remainder = 1_u128; + let mut current_gcd = 1_u64; + let mut max_iterations = 1_u32; + while current_gcd == 1 { + y = x; + for _ in 1..max_iterations { + x = advance(x, c, number); + } + let mut big_iteration = 0_u32; + while big_iteration < max_iterations && current_gcd == 1 { + x_start = x; + let mut small_iteration = 0_u32; + while small_iteration < iterations_before_check + && small_iteration < (max_iterations - big_iteration) + { + small_iteration += 1; + x = advance(x, c, number); + let diff = x.abs_diff(y); + remainder = (remainder * diff) % number as u128; + } + current_gcd = gcd(remainder as u64, number); + big_iteration += iterations_before_check; + } + max_iterations *= 2; + if max_iterations > iterations_cutoff { + break; + } + } + if current_gcd == number { + while current_gcd == 1 { + x_start = advance(x_start, c, number); + current_gcd = gcd(x_start.abs_diff(y) as u64, number); + } + } + current_gcd +} + +/* +Note: using this function with `check_is_prime` = false +and a prime number will result in an infinite loop. + +RNG's internal state is represented as `seed`. It is +advisable (but not mandatory) to reuse the saved seed value +In subsequent calls to this function. 
+ */ +pub fn pollard_rho_get_one_factor(number: u64, seed: &mut u32, check_is_prime: bool) -> u64 { + // LCG parameters from wikipedia + let mut rng = LinearCongruenceGenerator::new(1103515245, 12345, *seed); + if number <= 1 { + return number; + } + if check_is_prime { + let mut bases = vec![2u64, 3, 5, 7]; + if number > 3_215_031_000 { + bases.append(&mut vec![11, 13, 17, 19, 23, 29, 31, 37]); + } + if miller_rabin(number, &bases) == 0 { + return number; + } + } + let mut factor = 1u64; + while factor == 1 || factor == number { + let x = rng.get_64bits(); + let c = rng.get_64bits(); + factor = pollard_rho_customizable( + number, + (x % (number - 3)) + 2, + (c % (number - 2)) + 1, + 32, + 1 << 18, // This shouldn't take much longer than number ^ 0.25 + ); + // These numbers were selected based on local testing. + // For specific applications there maybe better choices. + } + *seed = rng.state; + factor +} + +fn get_small_factors(mut number: u64, primes: &[usize]) -> (u64, Vec) { + let mut result: Vec = Vec::new(); + for p in primes { + while (number % *p as u64) == 0 { + number /= *p as u64; + result.push(*p as u64); + } + } + (number, result) +} + +fn factor_using_mpf(mut number: usize, mpf: &[usize]) -> Vec { + let mut result = Vec::new(); + while number > 1 { + result.push(mpf[number] as u64); + number /= mpf[number]; + } + result +} + +/* +`primes` and `minimum_prime_factors` use usize because so does +LinearSieve implementation in this repository + */ +pub fn pollard_rho_factorize( + mut number: u64, + seed: &mut u32, + primes: &[usize], + minimum_prime_factors: &[usize], +) -> Vec { + if number <= 1 { + return vec![]; + } + let mut result: Vec = Vec::new(); + { + // Create a new scope to keep the outer scope clean + let (rem, mut res) = get_small_factors(number, primes); + number = rem; + result.append(&mut res); + } + if number == 1 { + return result; + } + let mut to_be_factored = vec![number]; + while let Some(last) = to_be_factored.pop() { + if last < 
minimum_prime_factors.len() as u64 { + result.append(&mut factor_using_mpf(last as usize, minimum_prime_factors)); + continue; + } + let fact = pollard_rho_get_one_factor(last, seed, true); + if fact == last { + result.push(last); + continue; + } + to_be_factored.push(fact); + to_be_factored.push(last / fact); + } + result.sort_unstable(); + result +} + +#[cfg(test)] +mod test { + use super::super::LinearSieve; + use super::*; + + fn check_is_proper_factor(number: u64, factor: u64) -> bool { + factor > 1 && factor < number && ((number % factor) == 0) + } + + fn check_factorization(number: u64, factors: &[u64]) -> bool { + let bases = vec![2u64, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; + let mut prod = 1_u64; + let mut prime_check = 0_u64; + for p in factors { + prod *= *p; + prime_check |= miller_rabin(*p, &bases); + } + prime_check == 0 && prod == number + } + + #[test] + fn one_factor() { + // a few small cases + let mut sieve = LinearSieve::new(); + sieve.prepare(1e5 as usize).unwrap(); + let numbers = vec![1235, 239874233, 4353234, 456456, 120983]; + let mut seed = 314159_u32; // first digits of pi; nothing up my sleeve + for num in numbers { + let factor = pollard_rho_get_one_factor(num, &mut seed, true); + assert!(check_is_proper_factor(num, factor)); + let factor = pollard_rho_get_one_factor(num, &mut seed, false); + assert!(check_is_proper_factor(num, factor)); + assert!(check_factorization( + num, + &pollard_rho_factorize(num, &mut seed, &sieve.primes, &sieve.minimum_prime_factor) + )); + } + // check if it goes into infinite loop if `number` is prime + let numbers = vec![ + 2, 3, 5, 7, 11, 13, 101, 998244353, 1000000007, 1000000009, 1671398671, 1652465729, + 1894404511, 1683402997, 1661963047, 1946039987, 2071566551, 1867816303, 1952199377, + 1622379469, 1739317499, 1775433631, 1994828917, 1818930719, 1672996277, + ]; + for num in numbers { + assert_eq!(pollard_rho_get_one_factor(num, &mut seed, true), num); + assert!(check_factorization( + num, + 
&pollard_rho_factorize(num, &mut seed, &sieve.primes, &sieve.minimum_prime_factor) + )); + } + } + #[test] + fn big_numbers() { + // Bigger cases: + // Each of these numbers is a product of two 31 bit primes + // This shouldn't take more than a 10ms per number on a modern PC + let mut seed = 314159_u32; // first digits of pi; nothing up my sleeve + let numbers: Vec = vec![ + 2761929023323646159, + 3189046231347719467, + 3234246546378360389, + 3869305776707280953, + 3167208188639390813, + 3088042782711408869, + 3628455596280801323, + 2953787574901819241, + 3909561575378030219, + 4357328471891213977, + 2824368080144930999, + 3348680054093203003, + 2704267100962222513, + 2916169237307181179, + 3669851121098875703, + ]; + for num in numbers { + assert!(check_factorization( + num, + &pollard_rho_factorize(num, &mut seed, &[], &[]) + )); + } + } +} diff --git a/src/math/postfix_evaluation.rs b/src/math/postfix_evaluation.rs new file mode 100644 index 00000000000..27bf4e3eacc --- /dev/null +++ b/src/math/postfix_evaluation.rs @@ -0,0 +1,105 @@ +//! This module provides a function to evaluate postfix (Reverse Polish Notation) expressions. +//! Postfix notation is a mathematical notation in which every operator follows all of its operands. +//! +//! The evaluator supports the four basic arithmetic operations: addition, subtraction, multiplication, and division. +//! It handles errors such as division by zero, invalid operators, insufficient operands, and invalid postfix expressions. + +/// Enumeration of errors that can occur when evaluating a postfix expression. +#[derive(Debug, PartialEq)] +pub enum PostfixError { + DivisionByZero, + InvalidOperator, + InsufficientOperands, + InvalidExpression, +} + +/// Evaluates a postfix expression and returns the result or an error. +/// +/// # Arguments +/// +/// * `expression` - A string slice that contains the postfix expression to be evaluated. +/// The tokens (numbers and operators) should be separated by whitespace. 
+/// +/// # Returns +/// +/// * `Ok(isize)` if the expression is valid and evaluates to an integer. +/// * `Err(PostfixError)` if the expression is invalid or encounters errors during evaluation. +/// +/// # Errors +/// +/// * `PostfixError::DivisionByZero` - If a division by zero is attempted. +/// * `PostfixError::InvalidOperator` - If an unknown operator is encountered. +/// * `PostfixError::InsufficientOperands` - If there are not enough operands for an operator. +/// * `PostfixError::InvalidExpression` - If the expression is malformed (e.g., multiple values are left on the stack). +pub fn evaluate_postfix(expression: &str) -> Result { + let mut stack: Vec = Vec::new(); + + for token in expression.split_whitespace() { + if let Ok(number) = token.parse::() { + // If the token is a number, push it onto the stack. + stack.push(number); + } else { + // If the token is an operator, pop the top two values from the stack, + // apply the operator, and push the result back onto the stack. + if let (Some(b), Some(a)) = (stack.pop(), stack.pop()) { + match token { + "+" => stack.push(a + b), + "-" => stack.push(a - b), + "*" => stack.push(a * b), + "/" => { + if b == 0 { + return Err(PostfixError::DivisionByZero); + } + stack.push(a / b); + } + _ => return Err(PostfixError::InvalidOperator), + } + } else { + return Err(PostfixError::InsufficientOperands); + } + } + } + // The final result should be the only element on the stack. + if stack.len() == 1 { + Ok(stack[0]) + } else { + Err(PostfixError::InvalidExpression) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! postfix_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(evaluate_postfix(input), expected); + } + )* + } + } + + postfix_tests! 
{ + test_addition_of_two_numbers: ("2 3 +", Ok(5)), + test_multiplication_and_addition: ("5 2 * 4 +", Ok(14)), + test_simple_division: ("10 2 /", Ok(5)), + test_operator_without_operands: ("+", Err(PostfixError::InsufficientOperands)), + test_division_by_zero_error: ("5 0 /", Err(PostfixError::DivisionByZero)), + test_invalid_operator_in_expression: ("2 3 #", Err(PostfixError::InvalidOperator)), + test_missing_operator_for_expression: ("2 3", Err(PostfixError::InvalidExpression)), + test_extra_operands_in_expression: ("2 3 4 +", Err(PostfixError::InvalidExpression)), + test_empty_expression_error: ("", Err(PostfixError::InvalidExpression)), + test_single_number_expression: ("42", Ok(42)), + test_addition_of_negative_numbers: ("-3 -2 +", Ok(-5)), + test_complex_expression_with_multiplication_and_addition: ("3 5 8 * 7 + *", Ok(141)), + test_expression_with_extra_whitespace: (" 3 4 + ", Ok(7)), + test_valid_then_invalid_operator: ("5 2 + 1 #", Err(PostfixError::InvalidOperator)), + test_first_division_by_zero: ("5 0 / 6 0 /", Err(PostfixError::DivisionByZero)), + test_complex_expression_with_multiple_operators: ("5 1 2 + 4 * + 3 -", Ok(14)), + test_expression_with_only_whitespace: (" ", Err(PostfixError::InvalidExpression)), + } +} diff --git a/src/math/prime_check.rs b/src/math/prime_check.rs new file mode 100644 index 00000000000..4902a65dbf7 --- /dev/null +++ b/src/math/prime_check.rs @@ -0,0 +1,33 @@ +pub fn prime_check(num: usize) -> bool { + if (num > 1) & (num < 4) { + return true; + } else if (num < 2) || (num % 2 == 0) { + return false; + } + + let stop: usize = (num as f64).sqrt() as usize + 1; + for i in (3..stop).step_by(2) { + if num % i == 0 { + return false; + } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + assert!(prime_check(3)); + assert!(prime_check(7)); + assert!(prime_check(11)); + assert!(prime_check(2003)); + + assert!(!prime_check(4)); + assert!(!prime_check(6)); + assert!(!prime_check(21)); + 
assert!(!prime_check(2004)); + } +} diff --git a/src/math/prime_factors.rs b/src/math/prime_factors.rs new file mode 100644 index 00000000000..7b89b09c9b8 --- /dev/null +++ b/src/math/prime_factors.rs @@ -0,0 +1,36 @@ +// Finds the prime factors of a number in increasing order, with repetition. + +pub fn prime_factors(n: u64) -> Vec { + let mut i = 2; + let mut n = n; + let mut factors = Vec::new(); + while i * i <= n { + if n % i != 0 { + if i != 2 { + i += 1; + } + i += 1; + } else { + n /= i; + factors.push(i); + } + } + if n > 1 { + factors.push(n); + } + factors +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn it_works() { + assert_eq!(prime_factors(0), vec![]); + assert_eq!(prime_factors(1), vec![]); + assert_eq!(prime_factors(11), vec![11]); + assert_eq!(prime_factors(25), vec![5, 5]); + assert_eq!(prime_factors(33), vec![3, 11]); + assert_eq!(prime_factors(2560), vec![2, 2, 2, 2, 2, 2, 2, 2, 2, 5]); + } +} diff --git a/src/math/prime_numbers.rs b/src/math/prime_numbers.rs new file mode 100644 index 00000000000..f045133a168 --- /dev/null +++ b/src/math/prime_numbers.rs @@ -0,0 +1,39 @@ +pub fn prime_numbers(max: usize) -> Vec { + let mut result: Vec = Vec::new(); + + if max >= 2 { + result.push(2) + } + for i in (3..=max).step_by(2) { + let stop: usize = (i as f64).sqrt() as usize + 1; + let mut status = true; + + for j in (3..stop).step_by(2) { + if i % j == 0 { + status = false; + break; + } + } + if status { + result.push(i) + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + assert_eq!(prime_numbers(0), vec![]); + assert_eq!(prime_numbers(11), vec![2, 3, 5, 7, 11]); + assert_eq!(prime_numbers(25), vec![2, 3, 5, 7, 11, 13, 17, 19, 23]); + assert_eq!( + prime_numbers(33), + vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] + ); + } +} diff --git a/src/math/quadratic_residue.rs b/src/math/quadratic_residue.rs new file mode 100644 index 00000000000..e3f2e6b819b --- /dev/null +++ 
b/src/math/quadratic_residue.rs @@ -0,0 +1,268 @@ +/// Cipolla algorithm +/// +/// Solving quadratic residue problem: +/// x^2 = a (mod p) , p is an odd prime +/// with O(M*log(n)) time complexity, M depends on the complexity of complex numbers multiplication. +/// +/// Wikipedia reference: https://en.wikipedia.org/wiki/Cipolla%27s_algorithm +/// When a is the primitive root modulo n, the answer is unique. +/// Otherwise it will return the smallest positive solution +use std::rc::Rc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use rand::Rng; + +use super::{fast_power, PCG32}; + +#[derive(Debug)] +struct CustomFiniteField { + modulus: u64, + i_square: u64, +} + +impl CustomFiniteField { + pub fn new(modulus: u64, i_square: u64) -> Self { + Self { modulus, i_square } + } +} + +#[derive(Clone, Debug)] +struct CustomComplexNumber { + real: u64, + imag: u64, + f: Rc, +} + +impl CustomComplexNumber { + pub fn new(real: u64, imag: u64, f: Rc) -> Self { + Self { real, imag, f } + } + + pub fn mult_other(&mut self, rhs: &Self) { + let tmp = (self.imag * rhs.real + self.real * rhs.imag) % self.f.modulus; + self.real = (self.real * rhs.real + + ((self.imag * rhs.imag) % self.f.modulus) * self.f.i_square) + % self.f.modulus; + self.imag = tmp; + } + + pub fn mult_self(&mut self) { + let tmp = (self.imag * self.real + self.real * self.imag) % self.f.modulus; + self.real = (self.real * self.real + + ((self.imag * self.imag) % self.f.modulus) * self.f.i_square) + % self.f.modulus; + self.imag = tmp; + } + + pub fn fast_power(mut base: Self, mut power: u64) -> Self { + let mut result = CustomComplexNumber::new(1, 0, base.f.clone()); + while power != 0 { + if (power & 1) != 0 { + result.mult_other(&base); // result *= base; + } + base.mult_self(); // base *= base; + power >>= 1; + } + result + } +} + +fn is_residue(x: u64, modulus: u64) -> bool { + let power = (modulus - 1) >> 1; + x != 0 && fast_power(x as usize, power as usize, modulus as usize) == 1 +} + +/// The Legendre 
symbol `(a | p)` +/// +/// Returns 0 if a = 0 mod p, 1 if a is a square mod p, -1 if it not a square mod p. +/// +/// +pub fn legendre_symbol(a: u64, odd_prime: u64) -> i64 { + debug_assert!(odd_prime % 2 != 0, "prime must be odd"); + if a == 0 { + 0 + } else if is_residue(a, odd_prime) { + 1 + } else { + -1 + } +} + +// return two solutions (x1, x2) for Quadratic Residue problem x^2 = a (mod p), where p is an odd prime +// if a is Quadratic Nonresidues, return None +pub fn cipolla(a: u32, p: u32, seed: Option) -> Option<(u32, u32)> { + // The params should be kept in u32 range for multiplication overflow issue + // But inside we use u64 for convenience + let a = a as u64; + let p = p as u64; + if a == 0 { + return Some((0, 0)); + } + if !is_residue(a, p) { + return None; + } + let seed = match seed { + Some(seed) => seed, + None => SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + let mut rng = PCG32::new_default(seed); + let r = loop { + let r = rng.get_u64() % p; + if r == 0 || !is_residue((p + r * r - a) % p, p) { + break r; + } + }; + let filed = Rc::new(CustomFiniteField::new(p, (p + r * r - a) % p)); + let comp = CustomComplexNumber::new(r, 1, filed); + let power = (p + 1) >> 1; + let x0 = CustomComplexNumber::fast_power(comp, power).real as u32; + let x1 = p as u32 - x0; + if x0 < x1 { + Some((x0, x1)) + } else { + Some((x1, x0)) + } +} + +/// Returns one of the two possible solutions of _x² = a mod p_, if any. +/// +/// The other solution is _-x mod p_. If there is no solution, returns `None`. +/// +/// Reference: H. Cohen, _A course in computational algebraic number theory_, Algorithm 1.4.3 +/// +/// ## Implementation details +/// +/// To avoid multiplication overflows, internally the algorithm uses the `128`-bit arithmetic. +/// +/// Also see [`cipolla`]. 
+pub fn tonelli_shanks(a: i64, odd_prime: u64) -> Option { + let p: u128 = odd_prime as u128; + let e = (p - 1).trailing_zeros(); + let q = (p - 1) >> e; // p = 2^e * q, with q odd + + let a = if a < 0 { + a.rem_euclid(p as i64) as u128 + } else { + a as u128 + }; + + let power_mod_p = |b, e| fast_power(b as usize, e as usize, p as usize) as u128; + + // find generator: choose a random non-residue n mod p + let mut rng = rand::rng(); + let n = loop { + let n = rng.random_range(0..p); + if legendre_symbol(n as u64, p as u64) == -1 { + break n; + } + }; + let z = power_mod_p(n, q); + + // init + let mut y = z; + let mut r = e; + let mut x = power_mod_p(a, (q - 1) / 2) % p; + let mut b = (a * x * x) % p; + x = (a * x) % p; + + while b % p != 1 { + // find exponent + let m = (1..r) + .scan(b, |prev, m| { + *prev = (*prev * *prev) % p; + Some((m, *prev == 1)) + }) + .find_map(|(m, cond)| cond.then_some(m)); + let Some(m) = m else { + return None; // non-residue + }; + + // reduce exponent + let t = power_mod_p(y as u128, 2_u128.pow(r - m - 1)); + y = (t * t) % p; + r = m; + x = (x * t) % p; + b = (b * y) % p; + } + + Some(x as u64) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn tonelli_shanks_residues(x: u64, odd_prime: u64) -> Option<(u64, u64)> { + let x = tonelli_shanks(x as i64, odd_prime)?; + let x2 = (-(x as i64)).rem_euclid(odd_prime as i64) as u64; + Some(if x < x2 { (x, x2) } else { (x2, x) }) + } + + #[test] + fn cipolla_small_numbers() { + assert_eq!(cipolla(1, 43, None), Some((1, 42))); + assert_eq!(cipolla(2, 23, None), Some((5, 18))); + assert_eq!(cipolla(17, 83, Some(42)), Some((10, 73))); + } + + #[test] + fn tonelli_shanks_small_numbers() { + assert_eq!(tonelli_shanks_residues(1, 43).unwrap(), (1, 42)); + assert_eq!(tonelli_shanks_residues(2, 23).unwrap(), (5, 18)); + assert_eq!(tonelli_shanks_residues(17, 83).unwrap(), (10, 73)); + } + + #[test] + fn cipolla_random_numbers() { + assert_eq!(cipolla(392203, 852167, None), Some((413252, 438915))); 
+ assert_eq!( + cipolla(379606557, 425172197, None), + Some((143417827, 281754370)) + ); + assert_eq!( + cipolla(585251669, 892950901, None), + Some((192354555, 700596346)) + ); + assert_eq!( + cipolla(404690348, 430183399, Some(19260817)), + Some((57227138, 372956261)) + ); + assert_eq!( + cipolla(210205747, 625380647, Some(998244353)), + Some((76810367, 548570280)) + ); + } + + #[test] + fn tonelli_shanks_random_numbers() { + assert_eq!( + tonelli_shanks_residues(392203, 852167), + Some((413252, 438915)) + ); + assert_eq!( + tonelli_shanks_residues(379606557, 425172197), + Some((143417827, 281754370)) + ); + assert_eq!( + tonelli_shanks_residues(585251669, 892950901), + Some((192354555, 700596346)) + ); + assert_eq!( + tonelli_shanks_residues(404690348, 430183399), + Some((57227138, 372956261)) + ); + assert_eq!( + tonelli_shanks_residues(210205747, 625380647), + Some((76810367, 548570280)) + ); + } + + #[test] + fn no_answer() { + assert_eq!(cipolla(650927, 852167, None), None); + assert_eq!(tonelli_shanks(650927, 852167), None); + } +} diff --git a/src/math/random.rs b/src/math/random.rs new file mode 100644 index 00000000000..de218035484 --- /dev/null +++ b/src/math/random.rs @@ -0,0 +1,143 @@ +/* +Permuted Congruential Generator +https://en.wikipedia.org/wiki/Permuted_congruential_generator + +Note that this is _NOT_ intended for serious applications. Use this generator +at your own risk and only use your own values instead of the default ones if +you really know what you are doing. 
+ */ +pub struct PCG32 { + state: u64, + multiplier: u64, + increment: u64, +} + +pub const PCG32_MULTIPLIER: u64 = 6364136223846793005_u64; +pub const PCG32_INCREMENT: u64 = 1442695040888963407_u64; + +pub struct IterMut<'a> { + pcg: &'a mut PCG32, +} + +impl PCG32 { + /// `stream` should be less than 1 << 63 + pub fn new(seed: u64, multiplier: u64, stream: u64) -> Self { + // We should make sure that increment is odd + let increment = (stream << 1) | 1; + let mut pcg = PCG32 { + state: seed.wrapping_add(increment), + multiplier, + increment, + }; + pcg.next(); + pcg + } + pub fn new_default(seed: u64) -> Self { + let multiplier = PCG32_MULTIPLIER; + let increment = PCG32_INCREMENT; + let mut pcg = PCG32 { + state: seed.wrapping_add(increment), + multiplier, + increment, + }; + pcg.next(); + pcg + } + #[inline] + pub fn next(&mut self) { + self.state = self + .state + .wrapping_mul(self.multiplier) + .wrapping_add(self.increment); + } + #[inline] + /// Advance the PCG by `delta` steps in O(lg(`delta`)) time. By passing + /// a negative i64 as u64, it can go back too. 
+ pub fn advance(&mut self, mut delta: u64) { + let mut acc_mult = 1u64; + let mut acc_incr = 0u64; + let mut curr_mlt = self.multiplier; + let mut curr_inc = self.increment; + while delta > 0 { + if delta & 1 != 0 { + acc_mult = acc_mult.wrapping_mul(curr_mlt); + acc_incr = acc_incr.wrapping_mul(curr_mlt).wrapping_add(curr_inc); + } + curr_inc = curr_mlt.wrapping_add(1).wrapping_mul(curr_inc); + curr_mlt = curr_mlt.wrapping_mul(curr_mlt); + delta >>= 1; + } + self.state = acc_mult.wrapping_mul(self.state).wrapping_add(acc_incr); + } + #[inline] + pub fn get_u32(&mut self) -> u32 { + let mut x = self.state; + let count = (x >> 59) as u32; + + self.next(); + + x ^= x >> 18; + ((x >> 27) as u32).rotate_right(count) + } + #[inline] + pub fn get_u64(&mut self) -> u64 { + self.get_u32() as u64 ^ ((self.get_u32() as u64) << 32) + } + #[inline] + pub fn get_u16(&mut self) -> (u16, u16) { + let res = self.get_u32(); + (res as u16, (res >> 16) as u16) + } + #[inline] + pub fn get_u8(&mut self) -> (u8, u8, u8, u8) { + let res = self.get_u32(); + ( + res as u8, + (res >> 8) as u8, + (res >> 16) as u8, + (res >> 24) as u8, + ) + } + #[inline] + pub fn get_state(&self) -> u64 { + self.state + } + pub fn iter_mut(&mut self) -> IterMut { + IterMut { pcg: self } + } +} + +impl Iterator for IterMut<'_> { + type Item = u32; + fn next(&mut self) -> Option { + Some(self.pcg.get_u32()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn no_birthday() { + // If the distribution is not almost uniform, the probability of + // birthday paradox increases. For n=2^32 and k=1e5, the probability + // of not having a collision is about (1 - (k+1)/n) ^ (k/2) which is + // 0.3121 for this (n, k). + // So this test is a (dumb) test for distribution, and for speed. This + // is only basic sanity checking, as the actual algorithm was + // rigorously tested by others before. 
+ let numbers = 1e5 as usize; + let mut pcg = PCG32::new_default(314159); + let mut pcg2 = PCG32::new_default(314159); + assert_eq!(pcg.get_u32(), pcg2.get_u32()); + let mut randoms: Vec = pcg.iter_mut().take(numbers).collect::>(); + pcg2.advance(1000); + assert_eq!(pcg2.get_u32(), randoms[1000]); + pcg2.advance((-1001_i64) as u64); + assert_eq!(pcg2.get_u32(), randoms[0]); + randoms.sort_unstable(); + randoms.dedup(); + assert_eq!(randoms.len(), numbers); + } +} diff --git a/src/math/relu.rs b/src/math/relu.rs new file mode 100644 index 00000000000..dacc2fe8289 --- /dev/null +++ b/src/math/relu.rs @@ -0,0 +1,34 @@ +//Rust implementation of the ReLU (rectified linear unit) activation function. +//The formula for ReLU is quite simple really: (if x>0 -> x, else -> 0) +//More information on the concepts of ReLU can be found here: +//https://en.wikipedia.org/wiki/Rectifier_(neural_networks) + +//The function below takes a reference to a mutable Vector as an argument +//and returns the vector with 'ReLU' applied to all values. +//Of course, these functions can be changed by the developer so that the input vector isn't manipulated. +//This is simply an implemenation of the formula. + +pub fn relu(array: &mut Vec) -> &mut Vec { + //note that these calculations are assuming the Vector values consists of real numbers in radians + for value in &mut *array { + if value <= &mut 0. { + *value = 0.; + } + } + + array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_relu() { + let mut test: Vec = Vec::from([1.0, 0.5, -1.0, 0.0, 0.3]); + assert_eq!( + relu(&mut test), + &mut Vec::::from([1.0, 0.5, 0.0, 0.0, 0.3]) + ); + } +} diff --git a/src/math/sieve_of_eratosthenes.rs b/src/math/sieve_of_eratosthenes.rs new file mode 100644 index 00000000000..ed331845317 --- /dev/null +++ b/src/math/sieve_of_eratosthenes.rs @@ -0,0 +1,120 @@ +/// Implements the Sieve of Eratosthenes algorithm to find all prime numbers up to a given limit. 
+/// +/// # Arguments +/// +/// * `num` - The upper limit up to which to find prime numbers (inclusive). +/// +/// # Returns +/// +/// A vector containing all prime numbers up to the specified limit. +pub fn sieve_of_eratosthenes(num: usize) -> Vec { + let mut result: Vec = Vec::new(); + if num >= 2 { + let mut sieve: Vec = vec![true; num + 1]; + + // 0 and 1 are not prime numbers + sieve[0] = false; + sieve[1] = false; + + let end: usize = (num as f64).sqrt() as usize; + + // Mark non-prime numbers in the sieve and collect primes up to `end` + update_sieve(&mut sieve, end, num, &mut result); + + // Collect remaining primes beyond `end` + result.extend(extract_remaining_primes(&sieve, end + 1)); + } + result +} + +/// Marks non-prime numbers in the sieve and collects prime numbers up to `end`. +/// +/// # Arguments +/// +/// * `sieve` - A mutable slice of booleans representing the sieve. +/// * `end` - The square root of the upper limit, used to optimize the algorithm. +/// * `num` - The upper limit up to which to mark non-prime numbers. +/// * `result` - A mutable vector to store the prime numbers. +fn update_sieve(sieve: &mut [bool], end: usize, num: usize, result: &mut Vec) { + for start in 2..=end { + if sieve[start] { + result.push(start); // Collect prime numbers up to `end` + for i in (start * start..=num).step_by(start) { + sieve[i] = false; + } + } + } +} + +/// Extracts remaining prime numbers from the sieve beyond the given start index. +/// +/// # Arguments +/// +/// * `sieve` - A slice of booleans representing the sieve with non-prime numbers marked as false. +/// * `start` - The index to start checking for primes (inclusive). +/// +/// # Returns +/// +/// A vector containing all remaining prime numbers extracted from the sieve. +fn extract_remaining_primes(sieve: &[bool], start: usize) -> Vec { + sieve[start..] 
+ .iter() + .enumerate() + .filter_map(|(i, &is_prime)| if is_prime { Some(start + i) } else { None }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + const PRIMES_UP_TO_997: [usize; 168] = [ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, + 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, + 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, + 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, + 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, + 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, + 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, + 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, + ]; + + macro_rules! sieve_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let input: usize = $test_case; + let expected: Vec = PRIMES_UP_TO_997.iter().cloned().filter(|&x| x <= input).collect(); + assert_eq!(sieve_of_eratosthenes(input), expected); + } + )* + } + } + + sieve_tests! { + test_0: 0, + test_1: 1, + test_2: 2, + test_3: 3, + test_4: 4, + test_5: 5, + test_6: 6, + test_7: 7, + test_11: 11, + test_23: 23, + test_24: 24, + test_25: 25, + test_26: 26, + test_27: 27, + test_28: 28, + test_29: 29, + test_33: 33, + test_100: 100, + test_997: 997, + test_998: 998, + test_999: 999, + test_1000: 1000, + } +} diff --git a/src/math/sigmoid.rs b/src/math/sigmoid.rs new file mode 100644 index 00000000000..bee6c0c6cb7 --- /dev/null +++ b/src/math/sigmoid.rs @@ -0,0 +1,34 @@ +//Rust implementation of the Sigmoid activation function. 
+//The formula for Sigmoid: 1 / (1 + e^(-x)) +//More information on the concepts of Sigmoid can be found here: +//https://en.wikipedia.org/wiki/Sigmoid_function + +//The function below takes a reference to a mutable Vector as an argument +//and returns the vector with 'Sigmoid' applied to all values. +//Of course, these functions can be changed by the developer so that the input vector isn't manipulated. +//This is simply an implemenation of the formula. + +use std::f32::consts::E; + +pub fn sigmoid(array: &mut Vec) -> &mut Vec { + //note that these calculations are assuming the Vector values consists of real numbers in radians + for value in &mut *array { + *value = 1. / (1. + E.powf(-1. * *value)); + } + + array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sigmoid() { + let mut test = Vec::from([1.0, 0.5, -1.0, 0.0, 0.3]); + assert_eq!( + sigmoid(&mut test), + &mut Vec::::from([0.7310586, 0.62245935, 0.26894143, 0.5, 0.5744425,]) + ); + } +} diff --git a/src/math/signum.rs b/src/math/signum.rs new file mode 100644 index 00000000000..83b7e7b5538 --- /dev/null +++ b/src/math/signum.rs @@ -0,0 +1,36 @@ +/// Signum function is a mathematical function that extracts +/// the sign of a real number. It is also known as the sign function, +/// and it is an odd piecewise function. +/// If a number is negative, i.e. it is less than zero, then sgn(x) = -1 +/// If a number is zero, then sgn(0) = 0 +/// If a number is positive, i.e. 
it is greater than zero, then sgn(x) = 1 + +pub fn signum(number: f64) -> i8 { + if number == 0.0 { + return 0; + } else if number > 0.0 { + return 1; + } + + -1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn positive_integer() { + assert_eq!(signum(15.0), 1); + } + + #[test] + fn negative_integer() { + assert_eq!(signum(-30.0), -1); + } + + #[test] + fn zero() { + assert_eq!(signum(0.0), 0); + } +} diff --git a/src/math/simpsons_integration.rs b/src/math/simpsons_integration.rs new file mode 100644 index 00000000000..57b173a136b --- /dev/null +++ b/src/math/simpsons_integration.rs @@ -0,0 +1,127 @@ +pub fn simpsons_integration(f: F, a: f64, b: f64, n: usize) -> f64 +where + F: Fn(f64) -> f64, +{ + let h = (b - a) / n as f64; + (0..n) + .map(|i| { + let x0 = a + i as f64 * h; + let x1 = x0 + h / 2.0; + let x2 = x0 + h; + (h / 6.0) * (f(x0) + 4.0 * f(x1) + f(x2)) + }) + .sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simpsons_integration() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 1.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_error() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + let error = (1.0 / 3.0 - result).abs(); + assert!(error < 1e-6); + } + + #[test] + fn test_convergence() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result1 = simpsons_integration(f, a, b, n); + let result2 = simpsons_integration(f, a, b, 2 * n); + let result3 = simpsons_integration(f, a, b, 4 * n); + let result4 = simpsons_integration(f, a, b, 8 * n); + assert!((result1 - result2).abs() < 1e-6); + assert!((result2 - result3).abs() < 1e-6); + assert!((result3 - result4).abs() < 1e-6); + } + + #[test] + fn test_negative() { + let f = |x: f64| -x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result = 
simpsons_integration(f, a, b, n); + assert!((result + 1.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_lower_bound() { + let f = |x: f64| x.powi(2); + let a = 1.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 7.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_upper_bound() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 8.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_lower_and_upper_bound() { + let f = |x: f64| x.powi(2); + let a = 1.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 7.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_lower_and_upper_bound_negative() { + let f = |x: f64| -x.powi(2); + let a = 1.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result + 7.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn parabola_curve_length() { + // Calculate the length of the curve f(x) = x^2 for -5 <= x <= 5 + // We should integrate sqrt(1 + (f'(x))^2) + let function = |x: f64| -> f64 { (1.0 + 4.0 * x * x).sqrt() }; + let result = simpsons_integration(function, -5.0, 5.0, 1_000); + let integrated = |x: f64| -> f64 { (x * function(x) / 2.0) + ((2.0 * x).asinh() / 4.0) }; + let expected = integrated(5.0) - integrated(-5.0); + assert!((result - expected).abs() < 1e-9); + } + + #[test] + fn area_under_cosine() { + use std::f64::consts::PI; + // Calculate area under f(x) = cos(x) + 5 for -pi <= x <= pi + // cosine should cancel out and the answer should be 2pi * 5 + let function = |x: f64| -> f64 { x.cos() + 5.0 }; + let result = simpsons_integration(function, -PI, PI, 1_000); + let expected = 2.0 * PI * 5.0; + assert!((result - expected).abs() < 1e-9); + } +} diff --git a/src/math/softmax.rs b/src/math/softmax.rs new file mode 100644 index 00000000000..582bf452ef5 --- /dev/null +++ 
b/src/math/softmax.rs @@ -0,0 +1,56 @@ +//! # Softmax Function +//! +//! The `softmax` function computes the softmax values of a given array of f32 numbers. +//! +//! The softmax operation is often used in machine learning for converting a vector of real numbers into a +//! probability distribution. It exponentiates each element in the input array, and then normalizes the +//! results so that they sum to 1. +//! +//! ## Formula +//! +//! For a given input array `x`, the softmax function computes the output `y` as follows: +//! +//! `y_i = e^(x_i) / sum(e^(x_j) for all j)` +//! +//! ## Softmax Function Implementation +//! +//! This implementation uses the `std::f32::consts::E` constant for the base of the exponential function. and +//! f32 vectors to compute the values. The function creates a new vector and not altering the input vector. +//! +use std::f32::consts::E; + +pub fn softmax(array: Vec) -> Vec { + let mut softmax_array = array; + + for value in &mut softmax_array { + *value = E.powf(*value); + } + + let sum: f32 = softmax_array.iter().sum(); + + for value in &mut softmax_array { + *value /= sum; + } + + softmax_array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_softmax() { + let test = vec![9.0, 0.5, -3.0, 0.0, 3.0]; + assert_eq!( + softmax(test), + vec![ + 0.9971961, + 0.00020289792, + 6.126987e-6, + 0.00012306382, + 0.0024718025 + ] + ); + } +} diff --git a/src/math/sprague_grundy_theorem.rs b/src/math/sprague_grundy_theorem.rs new file mode 100644 index 00000000000..4006d047ba6 --- /dev/null +++ b/src/math/sprague_grundy_theorem.rs @@ -0,0 +1,70 @@ +/** + * Sprague Grundy Theorem for combinatorial games like Nim + * + * The Sprague Grundy Theorem is a fundamental concept in combinatorial game theory, commonly used to analyze + * games like Nim. It calculates the Grundy number (also known as the nimber) for a position in a game. + * The Grundy number represents the game's position, and it helps determine the winning strategy. 
+ * + * The Grundy number of a terminal state is 0; otherwise, it is recursively defined as the minimum + * excludant (mex) of the Grundy values of possible next states. + * + * For more details on Sprague Grundy Theorem, you can visit:(https://en.wikipedia.org/wiki/Sprague%E2%80%93Grundy_theorem) + * + * Author : [Gyandeep](https://github.com/Gyan172004) + */ + +pub fn calculate_grundy_number( + position: i64, + grundy_numbers: &mut [i64], + possible_moves: &[i64], +) -> i64 { + // Check if we've already calculated the Grundy number for this position. + if grundy_numbers[position as usize] != -1 { + return grundy_numbers[position as usize]; + } + + // Base case: terminal state + if position == 0 { + grundy_numbers[0] = 0; + return 0; + } + + // Calculate Grundy values for possible next states. + let mut next_state_grundy_values: Vec = vec![]; + for move_size in possible_moves.iter() { + if position - move_size >= 0 { + next_state_grundy_values.push(calculate_grundy_number( + position - move_size, + grundy_numbers, + possible_moves, + )); + } + } + + // Sort the Grundy values and find the minimum excludant. + next_state_grundy_values.sort_unstable(); + let mut mex: i64 = 0; + for grundy_value in next_state_grundy_values.iter() { + if *grundy_value != mex { + break; + } + mex += 1; + } + + // Store the calculated Grundy number and return it. 
+ grundy_numbers[position as usize] = mex; + mex +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn calculate_grundy_number_test() { + let mut grundy_numbers: Vec = vec![-1; 7]; + let possible_moves: Vec = vec![1, 4]; + calculate_grundy_number(6, &mut grundy_numbers, &possible_moves); + assert_eq!(grundy_numbers, [0, 1, 0, 1, 2, 0, 1]); + } +} diff --git a/src/math/square_pyramidal_numbers.rs b/src/math/square_pyramidal_numbers.rs new file mode 100644 index 00000000000..2a6659e09c8 --- /dev/null +++ b/src/math/square_pyramidal_numbers.rs @@ -0,0 +1,20 @@ +// https://en.wikipedia.org/wiki/Square_pyramidal_number +// 1² + 2² + ... = ... (total) + +pub fn square_pyramidal_number(n: u64) -> u64 { + n * (n + 1) * (2 * n + 1) / 6 +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test0() { + assert_eq!(0, square_pyramidal_number(0)); + assert_eq!(1, square_pyramidal_number(1)); + assert_eq!(5, square_pyramidal_number(2)); + assert_eq!(14, square_pyramidal_number(3)); + } +} diff --git a/src/math/square_root.rs b/src/math/square_root.rs new file mode 100644 index 00000000000..4f858ad90b5 --- /dev/null +++ b/src/math/square_root.rs @@ -0,0 +1,57 @@ +/// squre_root returns the square root +/// of a f64 number using Newton's method +pub fn square_root(num: f64) -> f64 { + if num < 0.0_f64 { + return f64::NAN; + } + + let mut root = 1.0_f64; + + while (root * root - num).abs() > 1e-10_f64 { + root -= (root * root - num) / (2.0_f64 * root); + } + + root +} + +// fast_inv_sqrt returns an approximation of the inverse square root +// This algorithm was first used in Quake and has been reimplemented in a few other languages +// This crate implements it more thoroughly: https://docs.rs/quake-inverse-sqrt/latest/quake_inverse_sqrt/ +pub fn fast_inv_sqrt(num: f32) -> f32 { + // If you are confident in your input this can be removed for speed + if num < 0.0f32 { + return f32::NAN; + } + + let i = num.to_bits(); + let i = 0x5f3759df - (i >> 1); + let y = 
f32::from_bits(i); + + println!("num: {:?}, out: {:?}", num, y * (1.5 - 0.5 * num * y * y)); + // First iteration of Newton's approximation + y * (1.5 - 0.5 * num * y * y) + // The above can be repeated for more precision +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fast_inv_sqrt() { + // Negatives don't have square roots: + assert!(fast_inv_sqrt(-1.0f32).is_nan()); + + // Test a few cases, expect less than 1% error: + let test_pairs = [(4.0, 0.5), (16.0, 0.25), (25.0, 0.2)]; + for pair in test_pairs { + assert!((fast_inv_sqrt(pair.0) - pair.1).abs() <= (0.01 * pair.0)); + } + } + + #[test] + fn test_sqare_root() { + assert!((square_root(4.0_f64) - 2.0_f64).abs() <= 1e-10_f64); + assert!(square_root(-4.0_f64).is_nan()); + } +} diff --git a/src/math/sum_of_digits.rs b/src/math/sum_of_digits.rs new file mode 100644 index 00000000000..1da42ff20d9 --- /dev/null +++ b/src/math/sum_of_digits.rs @@ -0,0 +1,116 @@ +/// Iteratively sums the digits of a signed integer +/// +/// ## Arguments +/// +/// * `num` - The number to sum the digits of +/// +/// ## Examples +/// +/// ``` +/// use the_algorithms_rust::math::sum_digits_iterative; +/// +/// assert_eq!(10, sum_digits_iterative(1234)); +/// assert_eq!(12, sum_digits_iterative(-246)); +/// ``` +pub fn sum_digits_iterative(num: i32) -> u32 { + // convert to unsigned integer + let mut num = num.unsigned_abs(); + // initialize sum + let mut result: u32 = 0; + + // iterate through digits + while num > 0 { + // extract next digit and add to sum + result += num % 10; + num /= 10; // chop off last digit + } + result +} + +/// Recursively sums the digits of a signed integer +/// +/// ## Arguments +/// +/// * `num` - The number to sum the digits of +/// +/// ## Examples +/// +/// ``` +/// use the_algorithms_rust::math::sum_digits_recursive; +/// +/// assert_eq!(10, sum_digits_recursive(1234)); +/// assert_eq!(12, sum_digits_recursive(-246)); +/// ``` +pub fn sum_digits_recursive(num: i32) -> u32 { + // 
convert to unsigned integer + let num = num.unsigned_abs(); + // base case + if num < 10 { + return num; + } + // recursive case: add last digit to sum of remaining digits + num % 10 + sum_digits_recursive((num / 10) as i32) +} + +#[cfg(test)] +mod tests { + mod iterative { + // import relevant sum_digits function + use super::super::sum_digits_iterative as sum_digits; + + #[test] + fn zero() { + assert_eq!(0, sum_digits(0)); + } + #[test] + fn positive_number() { + assert_eq!(1, sum_digits(1)); + assert_eq!(10, sum_digits(1234)); + assert_eq!(14, sum_digits(42161)); + assert_eq!(6, sum_digits(500010)); + } + #[test] + fn negative_number() { + assert_eq!(1, sum_digits(-1)); + assert_eq!(12, sum_digits(-246)); + assert_eq!(2, sum_digits(-11)); + assert_eq!(14, sum_digits(-42161)); + assert_eq!(6, sum_digits(-500010)); + } + #[test] + fn trailing_zeros() { + assert_eq!(1, sum_digits(1000000000)); + assert_eq!(3, sum_digits(300)); + } + } + + mod recursive { + // import relevant sum_digits function + use super::super::sum_digits_recursive as sum_digits; + + #[test] + fn zero() { + assert_eq!(0, sum_digits(0)); + } + #[test] + fn positive_number() { + assert_eq!(1, sum_digits(1)); + assert_eq!(10, sum_digits(1234)); + assert_eq!(14, sum_digits(42161)); + assert_eq!(6, sum_digits(500010)); + } + #[test] + fn negative_number() { + assert_eq!(1, sum_digits(-1)); + assert_eq!(12, sum_digits(-246)); + assert_eq!(2, sum_digits(-11)); + assert_eq!(14, sum_digits(-42161)); + assert_eq!(6, sum_digits(-500010)); + } + #[test] + fn trailing_zeros() { + assert_eq!(1, sum_digits(1000000000)); + assert_eq!(3, sum_digits(300)); + } + } +} diff --git a/src/math/sum_of_geometric_progression.rs b/src/math/sum_of_geometric_progression.rs new file mode 100644 index 00000000000..30401a3c2d2 --- /dev/null +++ b/src/math/sum_of_geometric_progression.rs @@ -0,0 +1,37 @@ +// Author : cyrixninja +// Find the Sum of Geometric Progression +// Wikipedia: 
https://en.wikipedia.org/wiki/Geometric_progression + +pub fn sum_of_geometric_progression(first_term: f64, common_ratio: f64, num_of_terms: i32) -> f64 { + if common_ratio == 1.0 { + // Formula for sum if the common ratio is 1 + return (num_of_terms as f64) * first_term; + } + + // Formula for finding the sum of n terms of a Geometric Progression + (first_term / (1.0 - common_ratio)) * (1.0 - common_ratio.powi(num_of_terms)) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_sum_of_geometric_progression { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (first_term, common_ratio, num_of_terms, expected) = $inputs; + assert_eq!(sum_of_geometric_progression(first_term, common_ratio, num_of_terms), expected); + } + )* + } + } + + test_sum_of_geometric_progression! { + regular_input_0: (1.0, 2.0, 10, 1023.0), + regular_input_1: (1.0, 10.0, 5, 11111.0), + regular_input_2: (9.0, 2.5, 5, 579.9375), + common_ratio_one: (10.0, 1.0, 3, 30.0), + } +} diff --git a/src/math/sum_of_harmonic_series.rs b/src/math/sum_of_harmonic_series.rs new file mode 100644 index 00000000000..553ba23e9c6 --- /dev/null +++ b/src/math/sum_of_harmonic_series.rs @@ -0,0 +1,40 @@ +// Author : cyrixninja +// Sum of Harmonic Series : Find the sum of n terms in an harmonic progression. The calculation starts with the +// first_term and loops adding the common difference of Arithmetic Progression by which +// the given Harmonic Progression is linked. 
+// Wikipedia Reference : https://en.wikipedia.org/wiki/Interquartile_range +// Other References : https://the-algorithms.com/algorithm/sum-of-harmonic-series?lang=python + +pub fn sum_of_harmonic_progression( + first_term: f64, + common_difference: f64, + number_of_terms: i32, +) -> f64 { + let mut arithmetic_progression = vec![1.0 / first_term]; + let mut current_term = 1.0 / first_term; + + for _ in 0..(number_of_terms - 1) { + current_term += common_difference; + arithmetic_progression.push(current_term); + } + + let harmonic_series: Vec = arithmetic_progression + .into_iter() + .map(|step| 1.0 / step) + .collect(); + harmonic_series.iter().sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sum_of_harmonic_progression() { + assert_eq!(sum_of_harmonic_progression(1.0 / 2.0, 2.0, 2), 0.75); + assert_eq!( + sum_of_harmonic_progression(1.0 / 5.0, 5.0, 5), + 0.45666666666666667 + ); + } +} diff --git a/src/math/sylvester_sequence.rs b/src/math/sylvester_sequence.rs new file mode 100644 index 00000000000..7ce7e4534d0 --- /dev/null +++ b/src/math/sylvester_sequence.rs @@ -0,0 +1,33 @@ +// Author : cyrixninja +// Sylvester Series : Calculates the nth number in Sylvester's sequence. 
+// Wikipedia Reference : https://en.wikipedia.org/wiki/Sylvester%27s_sequence +// Other References : https://the-algorithms.com/algorithm/sylvester-sequence?lang=python + +pub fn sylvester(number: i32) -> i128 { + assert!(number > 0, "The input value of [n={number}] has to be > 0"); + + if number == 1 { + 2 + } else { + let num = sylvester(number - 1); + let lower = num - 1; + let upper = num; + lower * upper + 1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sylvester() { + assert_eq!(sylvester(8), 113423713055421844361000443_i128); + } + + #[test] + #[should_panic(expected = "The input value of [n=-1] has to be > 0")] + fn test_sylvester_negative() { + sylvester(-1); + } +} diff --git a/src/math/tanh.rs b/src/math/tanh.rs new file mode 100644 index 00000000000..e7a9a785366 --- /dev/null +++ b/src/math/tanh.rs @@ -0,0 +1,34 @@ +//Rust implementation of the Tanh (hyperbolic tangent) activation function. +//The formula for Tanh: (e^x - e^(-x))/(e^x + e^(-x)) OR (2/(1+e^(-2x))-1 +//More information on the concepts of Sigmoid can be found here: +//https://en.wikipedia.org/wiki/Hyperbolic_functions + +//The function below takes a reference to a mutable Vector as an argument +//and returns the vector with 'Tanh' applied to all values. +//Of course, these functions can be changed by the developer so that the input vector isn't manipulated. +//This is simply an implemenation of the formula. + +use std::f32::consts::E; + +pub fn tanh(array: &mut Vec) -> &mut Vec { + //note that these calculations are assuming the Vector values consists of real numbers in radians + for value in &mut *array { + *value = (2. / (1. + E.powf(-2. 
* *value))) - 1.; + } + + array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tanh() { + let mut test = Vec::from([1.0, 0.5, -1.0, 0.0, 0.3]); + assert_eq!( + tanh(&mut test), + &mut Vec::::from([0.76159406, 0.4621172, -0.7615941, 0.0, 0.29131258,]) + ); + } +} diff --git a/src/math/trapezoidal_integration.rs b/src/math/trapezoidal_integration.rs new file mode 100644 index 00000000000..f9cda7088c5 --- /dev/null +++ b/src/math/trapezoidal_integration.rs @@ -0,0 +1,42 @@ +pub fn trapezoidal_integral(a: f64, b: f64, f: F, precision: u32) -> f64 +where + F: Fn(f64) -> f64, +{ + let delta = (b - a) / precision as f64; + + (0..precision) + .map(|trapezoid| { + let left_side = a + (delta * trapezoid as f64); + let right_side = left_side + delta; + + 0.5 * (f(left_side) + f(right_side)) * delta + }) + .sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_trapezoidal_integral { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (a, b, f, prec, expected, eps) = $inputs; + let actual = trapezoidal_integral(a, b, f, prec); + assert!((actual - expected).abs() < eps); + } + )* + } + } + + test_trapezoidal_integral! 
{ + basic_0: (0.0, 1.0, |x: f64| x.powi(2), 1000, 1.0/3.0, 0.0001), + basic_0_higher_prec: (0.0, 1.0, |x: f64| x.powi(2), 10000, 1.0/3.0, 0.00001), + basic_1: (-1.0, 1.0, |x: f64| x.powi(2), 10000, 2.0/3.0, 0.00001), + basic_1_higher_prec: (-1.0, 1.0, |x: f64| x.powi(2), 100000, 2.0/3.0, 0.000001), + flipped_limits: (1.0, 0.0, |x: f64| x.powi(2), 10000, -1.0/3.0, 0.00001), + empty_range: (0.5, 0.5, |x: f64| x.powi(2), 100, 0.0, 0.0000001), + } +} diff --git a/src/math/trial_division.rs b/src/math/trial_division.rs new file mode 100644 index 00000000000..882e6f72262 --- /dev/null +++ b/src/math/trial_division.rs @@ -0,0 +1,59 @@ +fn floor(value: f64, scale: u8) -> f64 { + let multiplier = 10i64.pow(scale as u32) as f64; + (value * multiplier).floor() +} + +fn double_to_int(amount: f64) -> i128 { + amount.round() as i128 +} + +pub fn trial_division(mut num: i128) -> Vec { + if num < 0 { + return trial_division(-num); + } + let mut result: Vec = vec![]; + + if num == 0 { + return result; + } + + while num % 2 == 0 { + result.push(2); + num /= 2; + num = double_to_int(floor(num as f64, 0)) + } + let mut f: i128 = 3; + + while f.pow(2) <= num { + if num % f == 0 { + result.push(f); + num /= f; + num = double_to_int(floor(num as f64, 0)) + } else { + f += 2 + } + } + + if num != 1 { + result.push(num) + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + assert_eq!(trial_division(0), vec![]); + assert_eq!(trial_division(1), vec![]); + assert_eq!(trial_division(9), vec!(3, 3)); + assert_eq!(trial_division(-9), vec!(3, 3)); + assert_eq!(trial_division(10), vec!(2, 5)); + assert_eq!(trial_division(11), vec!(11)); + assert_eq!(trial_division(33), vec!(3, 11)); + assert_eq!(trial_division(2003), vec!(2003)); + assert_eq!(trial_division(100001), vec!(11, 9091)); + } +} diff --git a/src/math/trig_functions.rs b/src/math/trig_functions.rs new file mode 100644 index 00000000000..e18b5154029 --- /dev/null +++ b/src/math/trig_functions.rs @@ -0,0 
+1,263 @@ +/// Function that contains the similarities of the sine and cosine implementations +/// +/// Both of them are calculated using their MacLaurin Series +/// +/// Because there is just a '+1' that differs in their formula, this function has been +/// created for not repeating +fn template>(x: T, tol: f64, kind: i32) -> f64 { + use std::f64::consts::PI; + const PERIOD: f64 = 2.0 * PI; + /* Sometimes, this function is called for a big 'n'(when tol is very small) */ + fn factorial(n: i128) -> i128 { + (1..=n).product() + } + + /* Function to round up to the 'decimal'th decimal of the number 'x' */ + fn round_up_to_decimal(x: f64, decimal: i32) -> f64 { + let multiplier = 10f64.powi(decimal); + (x * multiplier).round() / multiplier + } + + let mut value: f64 = x.into(); //<-- This is the line for which the trait 'Into' is required + + /* Check for invalid arguments */ + if !value.is_finite() || value.is_nan() { + eprintln!("This function does not accept invalid arguments."); + return f64::NAN; + } + + /* + The argument to sine could be bigger than the sine's PERIOD + To prevent overflowing, strip the value off relative to the PERIOD + */ + while value >= PERIOD { + value -= PERIOD; + } + /* For cases when the value is smaller than the -PERIOD (e.g. 
sin(-3π) <=> sin(-π)) */ + while value <= -PERIOD { + value += PERIOD; + } + + let mut rez = 0f64; + let mut prev_rez = 1f64; + let mut step: i32 = 0; + /* + This while instruction is the MacLaurin Series for sine / cosine + sin(x) = Σ (-1)^n * x^2n+1 / (2n+1)!, for n >= 0 and x a Real number + cos(x) = Σ (-1)^n * x^2n / (2n)!, for n >= 0 and x a Real number + + '+1' in sine's formula is replaced with 'kind', which values are: + -> kind = 0, for cosine + -> kind = 1, for sine + */ + while (prev_rez - rez).abs() > tol { + prev_rez = rez; + rez += (-1f64).powi(step) * value.powi(2 * step + kind) + / factorial((2 * step + kind) as i128) as f64; + step += 1; + } + + /* Round up to the 6th decimal */ + round_up_to_decimal(rez, 6) +} + +/// Returns the value of sin(x), approximated with the given tolerance +/// +/// This function supposes the argument is in radians +/// +/// ### Example +/// +/// sin(1) == sin(1 rad) == sin(π/180) +pub fn sine>(x: T, tol: f64) -> f64 { + template(x, tol, 1) +} + +/// Returns the value of cos, approximated with the given tolerance, for +/// an angle 'x' in radians +pub fn cosine>(x: T, tol: f64) -> f64 { + template(x, tol, 0) +} + +/// Cosine of 'x' in degrees, with the given tolerance +pub fn cosine_no_radian_arg>(x: T, tol: f64) -> f64 { + use std::f64::consts::PI; + let val: f64 = x.into(); + cosine(val * PI / 180., tol) +} + +/// Sine function for non radian angle +/// +/// Interprets the argument in degrees, not in radians +/// +/// ### Example +/// +/// sin(1o) != \[ sin(1 rad) == sin(π/180) \] +pub fn sine_no_radian_arg>(x: T, tol: f64) -> f64 { + use std::f64::consts::PI; + let val: f64 = x.into(); + sine(val * PI / 180f64, tol) +} + +/// Tangent of angle 'x' in radians, calculated with the given tolerance +pub fn tan + Copy>(x: T, tol: f64) -> f64 { + let cos_val = cosine(x, tol); + + /* Cover special cases for division */ + if cos_val != 0f64 { + let sin_val = sine(x, tol); + sin_val / cos_val + } else { + f64::NAN + } +} + +/// 
Cotangent of angle 'x' in radians, calculated with the given tolerance +pub fn cotan + Copy>(x: T, tol: f64) -> f64 { + let sin_val = sine(x, tol); + + /* Cover special cases for division */ + if sin_val != 0f64 { + let cos_val = cosine(x, tol); + cos_val / sin_val + } else { + f64::NAN + } +} + +/// Tangent of 'x' in degrees, approximated with the given tolerance +pub fn tan_no_radian_arg + Copy>(x: T, tol: f64) -> f64 { + let angle: f64 = x.into(); + + use std::f64::consts::PI; + tan(angle * PI / 180., tol) +} + +/// Cotangent of 'x' in degrees, approximated with the given tolerance +pub fn cotan_no_radian_arg + Copy>(x: T, tol: f64) -> f64 { + let angle: f64 = x.into(); + + use std::f64::consts::PI; + cotan(angle * PI / 180., tol) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::f64::consts::PI; + + enum TrigFuncType { + Sine, + Cosine, + Tan, + Cotan, + } + + const TOL: f64 = 1e-10; + + impl TrigFuncType { + fn verify + Copy>(&self, angle: T, expected_result: f64, is_radian: bool) { + let value = match self { + TrigFuncType::Sine => { + if is_radian { + sine(angle, TOL) + } else { + sine_no_radian_arg(angle, TOL) + } + } + TrigFuncType::Cosine => { + if is_radian { + cosine(angle, TOL) + } else { + cosine_no_radian_arg(angle, TOL) + } + } + TrigFuncType::Tan => { + if is_radian { + tan(angle, TOL) + } else { + tan_no_radian_arg(angle, TOL) + } + } + TrigFuncType::Cotan => { + if is_radian { + cotan(angle, TOL) + } else { + cotan_no_radian_arg(angle, TOL) + } + } + }; + + assert_eq!(format!("{value:.5}"), format!("{:.5}", expected_result)); + } + } + + #[test] + fn test_sine() { + let sine_id = TrigFuncType::Sine; + sine_id.verify(0.0, 0.0, true); + sine_id.verify(-PI, 0.0, true); + sine_id.verify(-PI / 2.0, -1.0, true); + sine_id.verify(0.5, 0.4794255386, true); + /* Same tests, but angle is now in degrees */ + sine_id.verify(0, 0.0, false); + sine_id.verify(-180, 0.0, false); + sine_id.verify(-180 / 2, -1.0, false); + sine_id.verify(0.5, 0.00872654, 
false); + } + + #[test] + fn test_sine_bad_arg() { + assert!(sine(f64::NEG_INFINITY, 1e-1).is_nan()); + assert!(sine_no_radian_arg(f64::NAN, 1e-1).is_nan()); + } + + #[test] + fn test_cosine_bad_arg() { + assert!(cosine(f64::INFINITY, 1e-1).is_nan()); + assert!(cosine_no_radian_arg(f64::NAN, 1e-1).is_nan()); + } + + #[test] + fn test_cosine() { + let cosine_id = TrigFuncType::Cosine; + cosine_id.verify(0, 1., true); + cosine_id.verify(0, 1., false); + cosine_id.verify(45, 1. / f64::sqrt(2.), false); + cosine_id.verify(PI / 4., 1. / f64::sqrt(2.), true); + cosine_id.verify(360, 1., false); + cosine_id.verify(2. * PI, 1., true); + cosine_id.verify(15. * PI / 2., 0.0, true); + cosine_id.verify(-855, -1. / f64::sqrt(2.), false); + } + + #[test] + fn test_tan_bad_arg() { + assert!(tan(PI / 2., TOL).is_nan()); + assert!(tan(3. * PI / 2., TOL).is_nan()); + } + + #[test] + fn test_tan() { + let tan_id = TrigFuncType::Tan; + tan_id.verify(PI / 4., 1f64, true); + tan_id.verify(45, 1f64, false); + tan_id.verify(PI, 0f64, true); + tan_id.verify(180 + 45, 1f64, false); + tan_id.verify(60 - 2 * 180, 1.7320508075, false); + tan_id.verify(30 + 180 - 180, 0.57735026919, false); + } + + #[test] + fn test_cotan_bad_arg() { + assert!(cotan(tan(PI / 2., TOL), TOL).is_nan()); + assert!(!cotan(0, TOL).is_finite()); + } + + #[test] + fn test_cotan() { + let cotan_id = TrigFuncType::Cotan; + cotan_id.verify(PI / 4., 1f64, true); + cotan_id.verify(90 + 10 * 180, 0f64, false); + cotan_id.verify(30 - 5 * 180, f64::sqrt(3.), false); + } +} diff --git a/src/math/vector_cross_product.rs b/src/math/vector_cross_product.rs new file mode 100644 index 00000000000..582470d2bc7 --- /dev/null +++ b/src/math/vector_cross_product.rs @@ -0,0 +1,98 @@ +/// Cross Product and Magnitude Calculation +/// +/// This program defines functions to calculate the cross product of two 3D vectors +/// and the magnitude of a vector from its direction ratios. 
The main purpose is +/// to demonstrate the mathematical concepts and provide test cases for the functions. +/// +/// Time Complexity: +/// - Calculating the cross product and magnitude of a vector each takes O(1) time +/// since we are working with fixed-size arrays and performing a fixed number of +/// mathematical operations. + +/// Function to calculate the cross product of two vectors +pub fn cross_product(vec1: [f64; 3], vec2: [f64; 3]) -> [f64; 3] { + let x = vec1[1] * vec2[2] - vec1[2] * vec2[1]; + let y = -(vec1[0] * vec2[2] - vec1[2] * vec2[0]); + let z = vec1[0] * vec2[1] - vec1[1] * vec2[0]; + [x, y, z] +} + +/// Function to calculate the magnitude of a vector +pub fn vector_magnitude(vec: [f64; 3]) -> f64 { + (vec[0].powi(2) + vec[1].powi(2) + vec[2].powi(2)).sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cross_product_and_magnitude_1() { + // Test case with non-trivial vectors + let vec1 = [1.0, 2.0, 3.0]; + let vec2 = [4.0, 5.0, 6.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results with a tolerance for floating-point comparisons + assert_eq!(cross_product, [-3.0, 6.0, -3.0]); + assert!((magnitude - 7.34847).abs() < 1e-5); + } + + #[test] + fn test_cross_product_and_magnitude_2() { + // Test case with orthogonal vectors + let vec1 = [1.0, 0.0, 0.0]; + let vec2 = [0.0, 1.0, 0.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 1.0]); + assert_eq!(magnitude, 1.0); + } + + #[test] + fn test_cross_product_and_magnitude_3() { + // Test case with vectors along the axes + let vec1 = [2.0, 0.0, 0.0]; + let vec2 = [0.0, 3.0, 0.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 6.0]); + 
assert_eq!(magnitude, 6.0); + } + + #[test] + fn test_cross_product_and_magnitude_4() { + // Test case with parallel vectors + let vec1 = [1.0, 2.0, 3.0]; + let vec2 = [2.0, 4.0, 6.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 0.0]); + assert_eq!(magnitude, 0.0); + } + + #[test] + fn test_cross_product_and_magnitude_5() { + // Test case with zero vectors + let vec1 = [0.0, 0.0, 0.0]; + let vec2 = [0.0, 0.0, 0.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 0.0]); + assert_eq!(magnitude, 0.0); + } +} diff --git a/src/math/zellers_congruence_algorithm.rs b/src/math/zellers_congruence_algorithm.rs new file mode 100644 index 00000000000..43bf49e732f --- /dev/null +++ b/src/math/zellers_congruence_algorithm.rs @@ -0,0 +1,54 @@ +// returns the day of the week from the Gregorian Date + +pub fn zellers_congruence_algorithm(date: i32, month: i32, year: i32, as_string: bool) -> String { + let q = date; + let (m, y) = if month < 3 { + (month + 12, year - 1) + } else { + (month, year) + }; + let day: i32 = + (q + (26 * (m + 1) / 10) + (y % 100) + ((y % 100) / 4) + ((y / 100) / 4) + (5 * (y / 100))) + % 7; + if as_string { + number_to_day(day) + } else { + day.to_string() + } + /* Note that the day follows the following guidelines: + 0 = Saturday + 1 = Sunday + 2 = Monday + 3 = Tuesday + 4 = Wednesday + 5 = Thursday + 6 = Friday + */ +} + +fn number_to_day(number: i32) -> String { + let days = [ + "Saturday", + "Sunday", + "Monday", + "Tuesday", + "Wednesday", + "Thursday", + "Friday", + ]; + String::from(days[number as usize]) +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn it_works() { + assert_eq!(zellers_congruence_algorithm(25, 1, 2013, false), "6"); + assert_eq!(zellers_congruence_algorithm(25, 1, 
2013, true), "Friday"); + assert_eq!(zellers_congruence_algorithm(16, 4, 2022, false), "0"); + assert_eq!(zellers_congruence_algorithm(16, 4, 2022, true), "Saturday"); + assert_eq!(zellers_congruence_algorithm(14, 12, 1978, false), "5"); + assert_eq!(zellers_congruence_algorithm(15, 6, 2021, false), "3"); + } +} diff --git a/src/navigation/bearing.rs b/src/navigation/bearing.rs new file mode 100644 index 00000000000..6efec578b26 --- /dev/null +++ b/src/navigation/bearing.rs @@ -0,0 +1,40 @@ +use std::f64::consts::PI; + +pub fn bearing(lat1: f64, lng1: f64, lat2: f64, lng2: f64) -> f64 { + let lat1 = lat1 * PI / 180.0; + let lng1 = lng1 * PI / 180.0; + + let lat2 = lat2 * PI / 180.0; + let lng2 = lng2 * PI / 180.0; + + let delta_longitude = lng2 - lng1; + + let y = delta_longitude.sin() * lat2.cos(); + let x = lat1.cos() * lat2.sin() - lat1.sin() * lat2.cos() * delta_longitude.cos(); + + let mut brng = y.atan2(x); + brng = brng.to_degrees(); + + (brng + 360.0) % 360.0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn testing() { + assert_eq!( + format!( + "{:.0}º", + bearing( + -27.2020447088982, + -49.631891179172555, + -3.106362, + -60.025826, + ) + ), + "336º" + ); + } +} diff --git a/src/navigation/haversine.rs b/src/navigation/haversine.rs new file mode 100644 index 00000000000..27e61eb535c --- /dev/null +++ b/src/navigation/haversine.rs @@ -0,0 +1,35 @@ +use std::f64::consts::PI; + +const EARTH_RADIUS: f64 = 6371000.00; + +pub fn haversine(lat1: f64, lng1: f64, lat2: f64, lng2: f64) -> f64 { + let delta_dist_lat = (lat2 - lat1) * PI / 180.0; + let delta_dist_lng = (lng2 - lng1) * PI / 180.0; + + let cos1 = lat1 * PI / 180.0; + let cos2 = lat2 * PI / 180.0; + + let delta_lat = (delta_dist_lat / 2.0).sin().powf(2.0); + let delta_lng = (delta_dist_lng / 2.0).sin().powf(2.0); + + let a = delta_lat + delta_lng * cos1.cos() * cos2.cos(); + let result = 2.0 * a.asin().sqrt(); + + result * EARTH_RADIUS +} + +#[cfg(test)] +mod tests { + use super::*; + + 
#[test] + fn testing() { + assert_eq!( + format!( + "{:.2}km", + haversine(52.375603, 4.903206, 52.366059, 4.926692) / 1000.0 + ), + "1.92km" + ); + } +} diff --git a/src/navigation/mod.rs b/src/navigation/mod.rs new file mode 100644 index 00000000000..e62be90acbc --- /dev/null +++ b/src/navigation/mod.rs @@ -0,0 +1,5 @@ +mod bearing; +mod haversine; + +pub use self::bearing::bearing; +pub use self::haversine::haversine; diff --git a/src/number_theory/compute_totient.rs b/src/number_theory/compute_totient.rs new file mode 100644 index 00000000000..88af0649fcd --- /dev/null +++ b/src/number_theory/compute_totient.rs @@ -0,0 +1,60 @@ +// Totient function for +// all numbers smaller than +// or equal to n. + +// Computes and prints +// totient of all numbers +// smaller than or equal to n + +use std::vec; + +pub fn compute_totient(n: i32) -> vec::Vec { + let mut phi: Vec = Vec::new(); + + // initialize phi[i] = i + for i in 0..=n { + phi.push(i); + } + + // Compute other Phi values + for p in 2..=n { + // If phi[p] is not computed already, + // then number p is prime + if phi[(p) as usize] == p { + // Phi of a prime number p is + // always equal to p-1. 
+ phi[(p) as usize] = p - 1; + + // Update phi values of all + // multiples of p + for i in ((2 * p)..=n).step_by(p as usize) { + phi[(i) as usize] = (phi[i as usize] / p) * (p - 1); + } + } + } + + phi[1..].to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_1() { + assert_eq!( + compute_totient(12), + vec![1, 1, 2, 2, 4, 2, 6, 4, 6, 4, 10, 4] + ); + } + + #[test] + fn test_2() { + assert_eq!(compute_totient(7), vec![1, 1, 2, 2, 4, 2, 6]); + } + + #[test] + fn test_3() { + assert_eq!(compute_totient(4), vec![1, 1, 2, 2]); + } +} diff --git a/src/number_theory/euler_totient.rs b/src/number_theory/euler_totient.rs new file mode 100644 index 00000000000..69c0694a335 --- /dev/null +++ b/src/number_theory/euler_totient.rs @@ -0,0 +1,74 @@ +pub fn euler_totient(n: u64) -> u64 { + let mut result = n; + let mut num = n; + let mut p = 2; + + // Find all prime factors and apply formula + while p * p <= num { + // Check if p is a divisor of n + if num % p == 0 { + // If yes, then it is a prime factor + // Apply the formula: result = result * (1 - 1/p) + while num % p == 0 { + num /= p; + } + result -= result / p; + } + p += 1; + } + + // If num > 1, then it is a prime factor + if num > 1 { + result -= result / num; + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + macro_rules! test_euler_totient { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(euler_totient(input), expected) + } + )* + }; + } + + test_euler_totient! 
{ + prime_2: (2, 1), + prime_3: (3, 2), + prime_5: (5, 4), + prime_7: (7, 6), + prime_11: (11, 10), + prime_13: (13, 12), + prime_17: (17, 16), + prime_19: (19, 18), + + composite_6: (6, 2), // 2 * 3 + composite_10: (10, 4), // 2 * 5 + composite_15: (15, 8), // 3 * 5 + composite_12: (12, 4), // 2^2 * 3 + composite_18: (18, 6), // 2 * 3^2 + composite_20: (20, 8), // 2^2 * 5 + composite_30: (30, 8), // 2 * 3 * 5 + + prime_power_2_to_2: (4, 2), + prime_power_2_to_3: (8, 4), + prime_power_3_to_2: (9, 6), + prime_power_2_to_4: (16, 8), + prime_power_5_to_2: (25, 20), + prime_power_3_to_3: (27, 18), + prime_power_2_to_5: (32, 16), + + // Large numbers + large_50: (50, 20), // 2 * 5^2 + large_100: (100, 40), // 2^2 * 5^2 + large_1000: (1000, 400), // 2^3 * 5^3 + } +} diff --git a/src/number_theory/kth_factor.rs b/src/number_theory/kth_factor.rs new file mode 100644 index 00000000000..abfece86bb9 --- /dev/null +++ b/src/number_theory/kth_factor.rs @@ -0,0 +1,41 @@ +// Kth Factor of N +// The idea is to check for each number in the range [N, 1], and print the Kth number that divides N completely. 
+ +pub fn kth_factor(n: i32, k: i32) -> i32 { + let mut factors: Vec = Vec::new(); + let k = (k as usize) - 1; + for i in 1..=n { + if n % i == 0 { + factors.push(i); + } + if let Some(number) = factors.get(k) { + return *number; + } + } + -1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_1() { + assert_eq!(kth_factor(12, 3), 3); + } + + #[test] + fn test_2() { + assert_eq!(kth_factor(7, 2), 7); + } + + #[test] + fn test_3() { + assert_eq!(kth_factor(4, 4), -1); + } + + #[test] + fn test_4() { + assert_eq!(kth_factor(950, 5), 19); + } +} diff --git a/src/number_theory/mod.rs b/src/number_theory/mod.rs new file mode 100644 index 00000000000..0500ad775d1 --- /dev/null +++ b/src/number_theory/mod.rs @@ -0,0 +1,7 @@ +mod compute_totient; +mod euler_totient; +mod kth_factor; + +pub use self::compute_totient::compute_totient; +pub use self::euler_totient::euler_totient; +pub use self::kth_factor::kth_factor; diff --git a/src/searching/README.md b/src/searching/README.md index be5ff9d0f2c..cd5f6f07d1c 100644 --- a/src/searching/README.md +++ b/src/searching/README.md @@ -23,8 +23,48 @@ __Properties__ * Average case performance O(log n) * Worst case space complexity O(1) +### [Exponential](./exponential_search.rs) +![alt text][exponential-image] + +From [Wikipedia][exponential-wiki]: Exponential search allows for searching through a sorted, unbounded list for a specified input value (the search "key"). The algorithm consists of two stages. The first stage determines a range in which the search key would reside if it were in the list. In the second stage, a binary search is performed on this range. In the first stage, assuming that the list is sorted in ascending order, the algorithm looks for the first exponent, j, where the value 2^j is greater than the search key. This value, 2^j becomes the upper bound for the binary search with the previous power of 2, 2^(j - 1), being the lower bound for the binary search. 
+ +__Properties__ +* Worst case performance O(log i) +* Best case performance O(1) +* Average case performance O(log i) +* Worst case space complexity O(1) + +### [Jump](./jump_search.rs) +![alt text][jump-image] + +From [Wikipedia][jump-wiki]: In computer science, a jump search or block search refers to a search algorithm for ordered lists. It works by first checking all items L(km), where k ∈ N and m is the block size, until an item is found that is larger than the search key. To find the exact position of the search key in the list a linear search is performed on the sublist L[(k-1)m, km]. + +__Properties__ +* Worst case performance O(√n) +* Best case performance O(1) +* Average case performance O(√n) +* Worst case space complexity O(1) + +### [Fibonacci](./fibonacci_search.rs) + +From [Wikipedia][fibonacci-wiki]: In computer science, the Fibonacci search technique is a method of searching a sorted array using a divide and conquer algorithm that narrows down possible locations with the aid of Fibonacci numbers. Compared to binary search where the sorted array is divided into two equal-sized parts, one of which is examined further, Fibonacci search divides the array into two parts that have sizes that are consecutive Fibonacci numbers. 
+ +__Properties__ +* Worst case performance O(log n) +* Best case performance O(1) +* Average case performance O(log n) +* Worst case space complexity O(1) + [linear-wiki]: https://en.wikipedia.org/wiki/Linear_search [linear-image]: http://www.tutorialspoint.com/data_structures_algorithms/images/linear_search.gif [binary-wiki]: https://en.wikipedia.org/wiki/Binary_search_algorithm [binary-image]: https://upload.wikimedia.org/wikipedia/commons/f/f7/Binary_search_into_array.png + +[exponential-wiki]: https://en.wikipedia.org/wiki/Exponential_search +[exponential-image]: https://upload.wikimedia.org/wikipedia/commons/4/45/Exponential_search.svg + +[jump-wiki]: https://en.wikipedia.org/wiki/Jump_search +[jump-image]: https://static.studytonight.com/data-structures/images/Jump%20Search%20technique.PNG + +[fibonacci-wiki]: https://en.wikipedia.org/wiki/Fibonacci_search_technique diff --git a/src/searching/binary_search.rs b/src/searching/binary_search.rs index 80abcc9d652..4c64c58217c 100644 --- a/src/searching/binary_search.rs +++ b/src/searching/binary_search.rs @@ -1,61 +1,153 @@ +//! This module provides an implementation of a binary search algorithm that +//! works for both ascending and descending ordered arrays. The binary search +//! function returns the index of the target element if it is found, or `None` +//! if the target is not present in the array. + use std::cmp::Ordering; +/// Performs a binary search for a specified item within a sorted array. +/// +/// This function can handle both ascending and descending ordered arrays. It +/// takes a reference to the item to search for and a slice of the array. If +/// the item is found, it returns the index of the item within the array. If +/// the item is not found, it returns `None`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the sorted array in which to search. 
+/// +/// # Returns +/// +/// An `Option` which is: +/// - `Some(index)` if the item is found at the given index. +/// - `None` if the item is not found in the array. pub fn binary_search(item: &T, arr: &[T]) -> Option { + let is_asc = is_asc_arr(arr); + let mut left = 0; let mut right = arr.len(); while left < right { - let mid = left + (right - left) / 2; - - match item.cmp(&arr[mid]) { - Ordering::Less => right = mid, - Ordering::Equal => return Some(mid), - Ordering::Greater => left = mid + 1, + if match_compare(item, arr, &mut left, &mut right, is_asc) { + return Some(left); } } + None } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn empty() { - let index = binary_search(&"a", &vec![]); - assert_eq!(index, None); - } +/// Compares the item with the middle element of the current search range and +/// updates the search bounds accordingly. This function handles both ascending +/// and descending ordered arrays. It calculates the middle index of the +/// current search range and compares the item with the element at +/// this index. It then updates the search bounds (`left` and `right`) based on +/// the result of this comparison. If the item is found, it updates `left` to +/// the index of the found item and returns `true`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the array in which to search. +/// - `left`: A mutable reference to the left bound of the search range. +/// - `right`: A mutable reference to the right bound of the search range. +/// - `is_asc`: A boolean indicating whether the array is sorted in ascending order. +/// +/// # Returns +/// +/// A `bool` indicating whether the item was found. 
+fn match_compare( + item: &T, + arr: &[T], + left: &mut usize, + right: &mut usize, + is_asc: bool, +) -> bool { + let mid = *left + (*right - *left) / 2; + let cmp_result = item.cmp(&arr[mid]); - #[test] - fn one_item() { - let index = binary_search(&"a", &vec!["a"]); - assert_eq!(index, Some(0)); - } - - #[test] - fn search_strings() { - let index = binary_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(0)); + match (is_asc, cmp_result) { + (true, Ordering::Less) | (false, Ordering::Greater) => { + *right = mid; + } + (true, Ordering::Greater) | (false, Ordering::Less) => { + *left = mid + 1; + } + (_, Ordering::Equal) => { + *left = mid; + return true; + } } - #[test] - fn search_ints() { - let index = binary_search(&4, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(3)); + false +} - let index = binary_search(&3, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(2)); +/// Determines if the given array is sorted in ascending order. +/// +/// This helper function checks if the first element of the array is less than the +/// last element, indicating an ascending order. It returns `false` if the array +/// has fewer than two elements. +/// +/// # Parameters +/// +/// - `arr`: A slice of the array to check. +/// +/// # Returns +/// +/// A `bool` indicating whether the array is sorted in ascending order. +fn is_asc_arr(arr: &[T]) -> bool { + arr.len() > 1 && arr[0] < arr[arr.len() - 1] +} - let index = binary_search(&2, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(1)); +#[cfg(test)] +mod tests { + use super::*; - let index = binary_search(&1, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(0)); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $test_case; + assert_eq!(binary_search(&item, arr), expected); + } + )* + }; } - #[test] - fn not_found() { - let index = binary_search(&5, &vec![1, 2, 3, 4]); - assert_eq!(index, None); + test_cases! 
{ + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/searching/binary_search_recursive.rs b/src/searching/binary_search_recursive.rs new file mode 100644 index 
00000000000..e83fa2f48d5 --- /dev/null +++ b/src/searching/binary_search_recursive.rs @@ -0,0 +1,94 @@ +use std::cmp::Ordering; + +/// Recursively performs a binary search for a specified item within a sorted array. +/// +/// This function can handle both ascending and descending ordered arrays. It +/// takes a reference to the item to search for and a slice of the array. If +/// the item is found, it returns the index of the item within the array. If +/// the item is not found, it returns `None`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the sorted array in which to search. +/// - `left`: The left bound of the current search range. +/// - `right`: The right bound of the current search range. +/// - `is_asc`: A boolean indicating whether the array is sorted in ascending order. +/// +/// # Returns +/// +/// An `Option` which is: +/// - `Some(index)` if the item is found at the given index. +/// - `None` if the item is not found in the array. +pub fn binary_search_rec(item: &T, arr: &[T], left: usize, right: usize) -> Option { + if left >= right { + return None; + } + + let is_asc = arr.len() > 1 && arr[0] < arr[arr.len() - 1]; + let mid = left + (right - left) / 2; + let cmp_result = item.cmp(&arr[mid]); + + match (is_asc, cmp_result) { + (true, Ordering::Less) | (false, Ordering::Greater) => { + binary_search_rec(item, arr, left, mid) + } + (true, Ordering::Greater) | (false, Ordering::Less) => { + binary_search_rec(item, arr, mid + 1, right) + } + (_, Ordering::Equal) => Some(mid), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $test_case; + assert_eq!(binary_search_rec(&item, arr, 0, arr.len()), expected); + } + )* + }; + } + + test_cases! 
{ + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), + } +} diff --git a/src/searching/exponential_search.rs b/src/searching/exponential_search.rs new file mode 100644 index 
00000000000..be700956149 --- /dev/null +++ b/src/searching/exponential_search.rs @@ -0,0 +1,72 @@ +use std::cmp::Ordering; + +pub fn exponential_search(item: &T, arr: &[T]) -> Option { + let len = arr.len(); + if len == 0 { + return None; + } + let mut upper = 1; + while (upper < len) && (&arr[upper] <= item) { + upper *= 2; + } + if upper > len { + upper = len + } + + // binary search + let mut lower = upper / 2; + while lower < upper { + let mid = lower + (upper - lower) / 2; + + match item.cmp(&arr[mid]) { + Ordering::Less => upper = mid, + Ordering::Equal => return Some(mid), + Ordering::Greater => lower = mid + 1, + } + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let index = exponential_search(&"a", &[]); + assert_eq!(index, None); + } + + #[test] + fn one_item() { + let index = exponential_search(&"a", &["a"]); + assert_eq!(index, Some(0)); + } + + #[test] + fn search_strings() { + let index = exponential_search(&"a", &["a", "b", "c", "d", "google", "zoo"]); + assert_eq!(index, Some(0)); + } + + #[test] + fn search_ints() { + let index = exponential_search(&4, &[1, 2, 3, 4]); + assert_eq!(index, Some(3)); + + let index = exponential_search(&3, &[1, 2, 3, 4]); + assert_eq!(index, Some(2)); + + let index = exponential_search(&2, &[1, 2, 3, 4]); + assert_eq!(index, Some(1)); + + let index = exponential_search(&1, &[1, 2, 3, 4]); + assert_eq!(index, Some(0)); + } + + #[test] + fn not_found() { + let index = exponential_search(&5, &[1, 2, 3, 4]); + assert_eq!(index, None); + } +} diff --git a/src/searching/fibonacci_search.rs b/src/searching/fibonacci_search.rs new file mode 100644 index 00000000000..dc33fbba884 --- /dev/null +++ b/src/searching/fibonacci_search.rs @@ -0,0 +1,84 @@ +use std::cmp::min; +use std::cmp::Ordering; + +pub fn fibonacci_search(item: &T, arr: &[T]) -> Option { + let len = arr.len(); + if len == 0 { + return None; + } + let mut start = -1; + + let mut f0 = 0; + let mut f1 = 1; + let mut f2 = 1; + 
while f2 < len { + f0 = f1; + f1 = f2; + f2 = f0 + f1; + } + while f2 > 1 { + let index = min((f0 as isize + start) as usize, len - 1); + match item.cmp(&arr[index]) { + Ordering::Less => { + f2 = f0; + f1 -= f0; + f0 = f2 - f1; + } + Ordering::Equal => return Some(index), + Ordering::Greater => { + f2 = f1; + f1 = f0; + f0 = f2 - f1; + start = index as isize; + } + } + } + if (f1 != 0) && (&arr[len - 1] == item) { + return Some(len - 1); + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let index = fibonacci_search(&"a", &[]); + assert_eq!(index, None); + } + + #[test] + fn one_item() { + let index = fibonacci_search(&"a", &["a"]); + assert_eq!(index, Some(0)); + } + + #[test] + fn search_strings() { + let index = fibonacci_search(&"a", &["a", "b", "c", "d", "google", "zoo"]); + assert_eq!(index, Some(0)); + } + + #[test] + fn search_ints() { + let index = fibonacci_search(&4, &[1, 2, 3, 4]); + assert_eq!(index, Some(3)); + + let index = fibonacci_search(&3, &[1, 2, 3, 4]); + assert_eq!(index, Some(2)); + + let index = fibonacci_search(&2, &[1, 2, 3, 4]); + assert_eq!(index, Some(1)); + + let index = fibonacci_search(&1, &[1, 2, 3, 4]); + assert_eq!(index, Some(0)); + } + + #[test] + fn not_found() { + let index = fibonacci_search(&5, &[1, 2, 3, 4]); + assert_eq!(index, None); + } +} diff --git a/src/searching/interpolation_search.rs b/src/searching/interpolation_search.rs new file mode 100644 index 00000000000..4ecb3229892 --- /dev/null +++ b/src/searching/interpolation_search.rs @@ -0,0 +1,57 @@ +pub fn interpolation_search(nums: &[i32], item: &i32) -> Result { + // early check + if nums.is_empty() { + return Err(0); + } + let mut low: usize = 0; + let mut high: usize = nums.len() - 1; + while low <= high { + if *item < nums[low] || *item > nums[high] { + break; + } + let offset: usize = low + + (((high - low) / (nums[high] - nums[low]) as usize) * (*item - nums[low]) as usize); + match nums[offset].cmp(item) { + 
std::cmp::Ordering::Equal => return Ok(offset), + std::cmp::Ordering::Less => low = offset + 1, + std::cmp::Ordering::Greater => high = offset - 1, + } + } + Err(0) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::cmp::Ordering; + + #[test] + fn returns_err_if_empty_slice() { + let nums = []; + assert_eq!(interpolation_search::(&nums, &3), Err(0)); + } + + #[test] + fn returns_err_if_target_not_found() { + let nums = [1, 2, 3, 4, 5, 6]; + assert_eq!(interpolation_search::(&nums, &10), Err(0)); + } + + #[test] + fn returns_first_index() { + let index: Result = interpolation_search::(&[1, 2, 3, 4, 5], &1); + assert_eq!(index, Ok(0)); + } + + #[test] + fn returns_last_index() { + let index: Result = interpolation_search::(&[1, 2, 3, 4, 5], &5); + assert_eq!(index, Ok(4)); + } + + #[test] + fn returns_middle_index() { + let index: Result = interpolation_search::(&[1, 2, 3, 4, 5], &3); + assert_eq!(index, Ok(2)); + } +} diff --git a/src/searching/jump_search.rs b/src/searching/jump_search.rs new file mode 100644 index 00000000000..64d49331a30 --- /dev/null +++ b/src/searching/jump_search.rs @@ -0,0 +1,65 @@ +use std::cmp::min; + +pub fn jump_search(item: &T, arr: &[T]) -> Option { + let len = arr.len(); + if len == 0 { + return None; + } + let mut step = (len as f64).sqrt() as usize; + let mut prev = 0; + + while &arr[min(len, step) - 1] < item { + prev = step; + step += (len as f64).sqrt() as usize; + if prev >= len { + return None; + } + } + while &arr[prev] < item { + prev += 1; + } + if &arr[prev] == item { + return Some(prev); + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + assert!(jump_search(&"a", &[]).is_none()); + } + + #[test] + fn one_item() { + assert_eq!(jump_search(&"a", &["a"]).unwrap(), 0); + } + + #[test] + fn search_strings() { + assert_eq!( + jump_search(&"a", &["a", "b", "c", "d", "google", "zoo"]).unwrap(), + 0 + ); + } + + #[test] + fn search_ints() { + let arr = [1, 2, 3, 4]; + 
assert_eq!(jump_search(&4, &arr).unwrap(), 3); + assert_eq!(jump_search(&3, &arr).unwrap(), 2); + assert_eq!(jump_search(&2, &arr).unwrap(), 1); + assert_eq!(jump_search(&1, &arr).unwrap(), 0); + } + + #[test] + fn not_found() { + let arr = [1, 2, 3, 4]; + + assert!(jump_search(&5, &arr).is_none()); + assert!(jump_search(&0, &arr).is_none()); + } +} diff --git a/src/searching/kth_smallest.rs b/src/searching/kth_smallest.rs new file mode 100644 index 00000000000..39c77fd7412 --- /dev/null +++ b/src/searching/kth_smallest.rs @@ -0,0 +1,72 @@ +use crate::sorting::partition; +use std::cmp::{Ordering, PartialOrd}; + +/// Returns k-th smallest element of an array, i.e. its order statistics. +/// Time complexity is O(n^2) in the worst case, but only O(n) on average. +/// It mutates the input, and therefore does not require additional space. +pub fn kth_smallest(input: &mut [T], k: usize) -> Option +where + T: PartialOrd + Copy, +{ + if input.is_empty() { + return None; + } + + let kth = _kth_smallest(input, k, 0, input.len() - 1); + Some(kth) +} + +fn _kth_smallest(input: &mut [T], k: usize, lo: usize, hi: usize) -> T +where + T: PartialOrd + Copy, +{ + if lo == hi { + return input[lo]; + } + + let pivot = partition(input, lo, hi); + let i = pivot - lo + 1; + + match k.cmp(&i) { + Ordering::Equal => input[pivot], + Ordering::Less => _kth_smallest(input, k, lo, pivot - 1), + Ordering::Greater => _kth_smallest(input, k - i, pivot + 1, hi), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let mut zero: [u8; 0] = []; + let first = kth_smallest(&mut zero, 1); + + assert_eq!(None, first); + } + + #[test] + fn one_element() { + let mut one = [1]; + let first = kth_smallest(&mut one, 1); + + assert_eq!(1, first.unwrap()); + } + + #[test] + fn many_elements() { + // 0 1 3 4 5 7 8 9 9 10 12 13 16 17 + let mut many = [9, 17, 3, 16, 13, 10, 1, 5, 7, 12, 4, 8, 9, 0]; + + let first = kth_smallest(&mut many, 1); + let third = kth_smallest(&mut many, 3); + 
let sixth = kth_smallest(&mut many, 6); + let fourteenth = kth_smallest(&mut many, 14); + + assert_eq!(0, first.unwrap()); + assert_eq!(3, third.unwrap()); + assert_eq!(7, sixth.unwrap()); + assert_eq!(17, fourteenth.unwrap()); + } +} diff --git a/src/searching/kth_smallest_heap.rs b/src/searching/kth_smallest_heap.rs new file mode 100644 index 00000000000..fe2a3c15a5f --- /dev/null +++ b/src/searching/kth_smallest_heap.rs @@ -0,0 +1,89 @@ +use crate::data_structures::Heap; +use std::cmp::{Ord, Ordering}; + +/// Returns k-th smallest element of an array. +/// Time complexity is stably O(nlog(k)) in all cases +/// Extra space is required to maintain the heap, and it doesn't +/// mutate the input list. +/// +/// It is preferrable to the partition-based algorithm in cases when +/// we want to maintain the kth smallest element dynamically against +/// a stream of elements. In that case, once the heap is built, further +/// operation's complexity is O(log(k)). +pub fn kth_smallest_heap(input: &[T], k: usize) -> Option +where + T: Ord + Copy, +{ + if input.len() < k { + return None; + } + + // heap will maintain the kth smallest elements + // seen so far, when new elements, E_new arrives, + // it is compared with the largest element of the + // current Heap E_large, which is the current kth + // smallest elements. 
+ // if E_new > E_large, then E_new cannot be the kth + // smallest because there are already k elements smaller + // than it + // otherwise, E_large cannot be the kth smallest, and should + // be removed from the heap and E_new should be added + let mut heap = Heap::new_max(); + + // first k elements goes to the heap as the baseline + for &val in input.iter().take(k) { + heap.add(val); + } + + for &val in input.iter().skip(k) { + // compare new value to the current kth smallest value + let cur_big = heap.pop().unwrap(); // heap.pop() can't be None + match val.cmp(&cur_big) { + Ordering::Greater => { + heap.add(cur_big); + } + _ => { + heap.add(val); + } + } + } + + heap.pop() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let zero: [u8; 0] = []; + let first = kth_smallest_heap(&zero, 1); + + assert_eq!(None, first); + } + + #[test] + fn one_element() { + let one = [1]; + let first = kth_smallest_heap(&one, 1); + + assert_eq!(1, first.unwrap()); + } + + #[test] + fn many_elements() { + // 0 1 3 4 5 7 8 9 9 10 12 13 16 17 + let many = [9, 17, 3, 16, 13, 10, 1, 5, 7, 12, 4, 8, 9, 0]; + + let first = kth_smallest_heap(&many, 1); + let third = kth_smallest_heap(&many, 3); + let sixth = kth_smallest_heap(&many, 6); + let fourteenth = kth_smallest_heap(&many, 14); + + assert_eq!(0, first.unwrap()); + assert_eq!(3, third.unwrap()); + assert_eq!(7, sixth.unwrap()); + assert_eq!(17, fourteenth.unwrap()); + } +} diff --git a/src/searching/linear_search.rs b/src/searching/linear_search.rs index d3a0be48042..d38b224d0a6 100644 --- a/src/searching/linear_search.rs +++ b/src/searching/linear_search.rs @@ -1,6 +1,15 @@ -use std::cmp::PartialEq; - -pub fn linear_search(item: &T, arr: &[T]) -> Option { +/// Performs a linear search on the given array, returning the index of the first occurrence of the item. +/// +/// # Arguments +/// +/// * `item` - A reference to the item to search for in the array. +/// * `arr` - A slice of items to search within. 
+/// +/// # Returns +/// +/// * `Some(usize)` - The index of the first occurrence of the item, if found. +/// * `None` - If the item is not found in the array. +pub fn linear_search(item: &T, arr: &[T]) -> Option { for (i, data) in arr.iter().enumerate() { if item == data { return Some(i); @@ -14,36 +23,54 @@ pub fn linear_search(item: &T, arr: &[T]) -> Option { mod tests { use super::*; - #[test] - fn search_strings() { - let index = linear_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(0)); - } - - #[test] - fn search_ints() { - let index = linear_search(&4, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(3)); - - let index = linear_search(&3, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(2)); - - let index = linear_search(&2, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(1)); - - let index = linear_search(&1, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(0)); - } - - #[test] - fn not_found() { - let index = linear_search(&5, &vec![1, 2, 3, 4]); - assert_eq!(index, None); + macro_rules! test_cases { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $tc; + if let Some(expected_index) = expected { + assert_eq!(arr[expected_index], item); + } + assert_eq!(linear_search(&item, arr), expected); + } + )* + } } - #[test] - fn empty() { - let index = linear_search(&1, &vec![]); - assert_eq!(index, None); + test_cases! 
{ + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/searching/mod.rs b/src/searching/mod.rs index 2c53d4c13ff..94f65988195 100644 --- a/src/searching/mod.rs +++ 
b/src/searching/mod.rs @@ -1,5 +1,35 @@ mod binary_search; +mod binary_search_recursive; +mod exponential_search; +mod fibonacci_search; +mod interpolation_search; +mod jump_search; +mod kth_smallest; +mod kth_smallest_heap; mod linear_search; +mod moore_voting; +mod quick_select; +mod saddleback_search; +mod ternary_search; +mod ternary_search_min_max; +mod ternary_search_min_max_recursive; +mod ternary_search_recursive; pub use self::binary_search::binary_search; +pub use self::binary_search_recursive::binary_search_rec; +pub use self::exponential_search::exponential_search; +pub use self::fibonacci_search::fibonacci_search; +pub use self::interpolation_search::interpolation_search; +pub use self::jump_search::jump_search; +pub use self::kth_smallest::kth_smallest; +pub use self::kth_smallest_heap::kth_smallest_heap; pub use self::linear_search::linear_search; +pub use self::moore_voting::moore_voting; +pub use self::quick_select::quick_select; +pub use self::saddleback_search::saddleback_search; +pub use self::ternary_search::ternary_search; +pub use self::ternary_search_min_max::ternary_search_max; +pub use self::ternary_search_min_max::ternary_search_min; +pub use self::ternary_search_min_max_recursive::ternary_search_max_rec; +pub use self::ternary_search_min_max_recursive::ternary_search_min_rec; +pub use self::ternary_search_recursive::ternary_search_rec; diff --git a/src/searching/moore_voting.rs b/src/searching/moore_voting.rs new file mode 100644 index 00000000000..8acf0dd8b36 --- /dev/null +++ b/src/searching/moore_voting.rs @@ -0,0 +1,80 @@ +/* + + Moore's voting algorithm finds out the strictly majority-occurring element + without using extra space + and O(n) + O(n) time complexity + + It is built on the intuition that a strictly major element will always have a net occurrence as 1. 
+ Say, array given: 9 1 8 1 1 + Here, the algorithm will work as: + + (for finding element present >(n/2) times) + (assumed: all elements are >0) + + Initialisation: ele=0, cnt=0 + Loop beings. + + loop 1: arr[0]=9 + ele = 9 + cnt=1 (since cnt = 0, cnt increments to 1 and ele = 9) + + loop 2: arr[1]=1 + ele = 9 + cnt= 0 (since in this turn of the loop, the array[i] != ele, cnt decrements by 1) + + loop 3: arr[2]=8 + ele = 8 + cnt=1 (since cnt = 0, cnt increments to 1 and ele = 8) + + loop 4: arr[3]=1 + ele = 8 + cnt= 0 (since in this turn of the loop, the array[i] != ele, cnt decrements by 1) + + loop 5: arr[4]=1 + ele = 9 + cnt=1 (since cnt = 0, cnt increments to 1 and ele = 1) + + Now, this ele should be the majority element if there's any + To check, a quick O(n) loop is run to check if the count of ele is >(n/2), n being the length of the array + + -1 is returned when no such element is found. + +*/ + +pub fn moore_voting(arr: &[i32]) -> i32 { + let n = arr.len(); + let mut cnt = 0; // initializing cnt + let mut ele = 0; // initializing ele + + for &item in arr.iter() { + if cnt == 0 { + cnt = 1; + ele = item; + } else if item == ele { + cnt += 1; + } else { + cnt -= 1; + } + } + + let cnt_check = arr.iter().filter(|&&x| x == ele).count(); + + if cnt_check > (n / 2) { + ele + } else { + -1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_moore_voting() { + let arr1: Vec = vec![9, 1, 8, 1, 1]; + assert!(moore_voting(&arr1) == 1); + let arr2: Vec = vec![1, 2, 3, 4]; + assert!(moore_voting(&arr2) == -1); + } +} diff --git a/src/searching/quick_select.rs b/src/searching/quick_select.rs new file mode 100644 index 00000000000..592c4ceb2f4 --- /dev/null +++ b/src/searching/quick_select.rs @@ -0,0 +1,44 @@ +// https://en.wikipedia.org/wiki/Quickselect + +fn partition(list: &mut [i32], left: usize, right: usize, pivot_index: usize) -> usize { + let pivot_value = list[pivot_index]; + list.swap(pivot_index, right); // Move pivot to end + let mut 
store_index = left; + for i in left..right { + if list[i] < pivot_value { + list.swap(store_index, i); + store_index += 1; + } + } + list.swap(right, store_index); // Move pivot to its final place + store_index +} + +pub fn quick_select(list: &mut [i32], left: usize, right: usize, index: usize) -> i32 { + if left == right { + // If the list contains only one element, + return list[left]; + } // return that element + let mut pivot_index = left + (right - left) / 2; // select a pivotIndex between left and right + pivot_index = partition(list, left, right, pivot_index); + // The pivot is in its final sorted position + match index { + x if x == pivot_index => list[index], + x if x < pivot_index => quick_select(list, left, pivot_index - 1, index), + _ => quick_select(list, pivot_index + 1, right, index), + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn it_works() { + let mut arr1 = [2, 3, 4, 5]; + assert_eq!(quick_select(&mut arr1, 0, 3, 1), 3); + let mut arr2 = [2, 5, 9, 12, 16]; + assert_eq!(quick_select(&mut arr2, 1, 3, 2), 9); + let mut arr2 = [0, 3, 8]; + assert_eq!(quick_select(&mut arr2, 0, 0, 0), 0); + } +} diff --git a/src/searching/saddleback_search.rs b/src/searching/saddleback_search.rs new file mode 100644 index 00000000000..1a722a29b1c --- /dev/null +++ b/src/searching/saddleback_search.rs @@ -0,0 +1,84 @@ +// Saddleback search is a technique used to find an element in a sorted 2D matrix in O(m + n) time, +// where m is the number of rows, and n is the number of columns. It works by starting from the +// top-right corner of the matrix and moving left or down based on the comparison of the current +// element with the target element. 
+use std::cmp::Ordering; + +pub fn saddleback_search(matrix: &[Vec], element: i32) -> (usize, usize) { + // Initialize left and right indices + let mut left_index = 0; + let mut right_index = matrix[0].len() - 1; + + // Start searching + while left_index < matrix.len() { + match element.cmp(&matrix[left_index][right_index]) { + // If the current element matches the target element, return its position (indices are 1-based) + Ordering::Equal => return (left_index + 1, right_index + 1), + Ordering::Greater => { + // If the target element is greater, move to the next row (downwards) + left_index += 1; + } + Ordering::Less => { + // If the target element is smaller, move to the previous column (leftwards) + if right_index == 0 { + break; // If we reach the left-most column, exit the loop + } + right_index -= 1; + } + } + } + + // If the element is not found, return (0, 0) + (0, 0) +} + +#[cfg(test)] +mod tests { + use super::*; + + // Test when the element is not present in the matrix + #[test] + fn test_element_not_found() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 123), (0, 0)); + } + + // Test when the element is at the top-left corner of the matrix + #[test] + fn test_element_at_top_left() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 1), (1, 1)); + } + + // Test when the element is at the bottom-right corner of the matrix + #[test] + fn test_element_at_bottom_right() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 300), (3, 3)); + } + + // Test when the element is at the top-right corner of the matrix + #[test] + fn test_element_at_top_right() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 100), (1, 3)); + } + + // Test when the element is at the bottom-left corner of the matrix + 
#[test] + fn test_element_at_bottom_left() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 3), (3, 1)); + } + + // Additional test case: Element in the middle of the matrix + #[test] + fn test_element_in_middle() { + let matrix = vec![ + vec![1, 10, 100, 1000], + vec![2, 20, 200, 2000], + vec![3, 30, 300, 3000], + ]; + assert_eq!(saddleback_search(&matrix, 200), (2, 3)); + } +} diff --git a/src/searching/ternary_search.rs b/src/searching/ternary_search.rs new file mode 100644 index 00000000000..cb9b5bee477 --- /dev/null +++ b/src/searching/ternary_search.rs @@ -0,0 +1,195 @@ +//! This module provides an implementation of a ternary search algorithm that +//! works for both ascending and descending ordered arrays. The ternary search +//! function returns the index of the target element if it is found, or `None` +//! if the target is not present in the array. + +use std::cmp::Ordering; + +/// Performs a ternary search for a specified item within a sorted array. +/// +/// This function can handle both ascending and descending ordered arrays. It +/// takes a reference to the item to search for and a slice of the array. If +/// the item is found, it returns the index of the item within the array. If +/// the item is not found, it returns `None`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the sorted array in which to search. +/// +/// # Returns +/// +/// An `Option` which is: +/// - `Some(index)` if the item is found at the given index. +/// - `None` if the item is not found in the array. 
+pub fn ternary_search(item: &T, arr: &[T]) -> Option { + if arr.is_empty() { + return None; + } + + let is_asc = is_asc_arr(arr); + let mut left = 0; + let mut right = arr.len() - 1; + + while left <= right { + if match_compare(item, arr, &mut left, &mut right, is_asc) { + return Some(left); + } + } + + None +} + +/// Compares the item with two middle elements of the current search range and +/// updates the search bounds accordingly. This function handles both ascending +/// and descending ordered arrays. It calculates two middle indices of the +/// current search range and compares the item with the elements at these +/// indices. It then updates the search bounds (`left` and `right`) based on +/// the result of these comparisons. If the item is found, it returns `true`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the array in which to search. +/// - `left`: A mutable reference to the left bound of the search range. +/// - `right`: A mutable reference to the right bound of the search range. +/// - `is_asc`: A boolean indicating whether the array is sorted in ascending order. +/// +/// # Returns +/// +/// A `bool` indicating: +/// - `true` if the item was found in the array. +/// - `false` if the item was not found in the array. 
+fn match_compare( + item: &T, + arr: &[T], + left: &mut usize, + right: &mut usize, + is_asc: bool, +) -> bool { + let first_mid = *left + (*right - *left) / 3; + let second_mid = *right - (*right - *left) / 3; + + // Handling the edge case where the search narrows down to a single element + if first_mid == second_mid && first_mid == *left { + return match &arr[*left] { + x if x == item => true, + _ => { + *left += 1; + false + } + }; + } + + let cmp_first_mid = item.cmp(&arr[first_mid]); + let cmp_second_mid = item.cmp(&arr[second_mid]); + + match (is_asc, cmp_first_mid, cmp_second_mid) { + // If the item matches either midpoint, it returns the index + (_, Ordering::Equal, _) => { + *left = first_mid; + return true; + } + (_, _, Ordering::Equal) => { + *left = second_mid; + return true; + } + // If the item is smaller than the element at first_mid (in ascending order) + // or greater than it (in descending order), it narrows the search to the first third. + (true, Ordering::Less, _) | (false, Ordering::Greater, _) => { + *right = first_mid.saturating_sub(1) + } + // If the item is greater than the element at second_mid (in ascending order) + // or smaller than it (in descending order), it narrows the search to the last third. + (true, _, Ordering::Greater) | (false, _, Ordering::Less) => *left = second_mid + 1, + // Otherwise, it searches the middle third. + (_, _, _) => { + *left = first_mid + 1; + *right = second_mid - 1; + } + } + + false +} + +/// Determines if the given array is sorted in ascending order. +/// +/// This helper function checks if the first element of the array is less than the +/// last element, indicating an ascending order. It returns `false` if the array +/// has fewer than two elements. +/// +/// # Parameters +/// +/// - `arr`: A slice of the array to check. +/// +/// # Returns +/// +/// A `bool` indicating whether the array is sorted in ascending order. 
+fn is_asc_arr(arr: &[T]) -> bool { + arr.len() > 1 && arr[0] < arr[arr.len() - 1] +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $test_case; + if let Some(expected_index) = expected { + assert_eq!(arr[expected_index], item); + } + assert_eq!(ternary_search(&item, arr), expected); + } + )* + }; + } + + test_cases! { + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_two_elements_found_at_start: (1, &[1, 2], Some(0)), + search_two_elements_found_at_end: (2, &[1, 2], Some(1)), + search_two_elements_not_found_start: (0, &[1, 2], None), + search_two_elements_not_found_end: (3, &[1, 2], None), + search_three_elements_found_start: (1, &[1, 2, 3], Some(0)), + search_three_elements_found_middle: (2, &[1, 2, 3], Some(1)), + search_three_elements_found_end: (3, &[1, 2, 3], Some(2)), + search_three_elements_not_found_start: (0, &[1, 2, 3], None), + search_three_elements_not_found_end: (4, &[1, 2, 3], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + 
/// Ternary search algorithm for finding maximum of unimodal function
///
/// Repeatedly evaluates `f` at two interior points of the interval between
/// `start` and `end` and drops the third of the interval that cannot contain the
/// maximum, until the interval is narrower than `absolute_precision`.
///
/// The search also stops as soon as an iteration can no longer move either
/// bound. This happens when `absolute_precision` is smaller than the spacing of
/// `f32` values around the bounds: the interior points then round back onto the
/// bounds themselves, and without this check the loop would never terminate.
///
/// # Returns
///
/// The value of `f` at the final lower bound of the interval (`f(start)`).
pub fn ternary_search_max(
    f: fn(f32) -> f32,
    mut start: f32,
    mut end: f32,
    absolute_precision: f32,
) -> f32 {
    while (start - end).abs() >= absolute_precision {
        let third = (end - start) / 3.0;
        let mid1 = start + third;
        let mid2 = end - third;

        let r1 = f(mid1);
        let r2 = f(mid2);

        let (next_start, next_end) = if r1 < r2 {
            // The maximum cannot be in the first third.
            (mid1, end)
        } else if r1 > r2 {
            // The maximum cannot be in the last third.
            (start, mid2)
        } else {
            (mid1, mid2)
        };

        // No representable progress is possible any more: stop instead of spinning.
        if next_start == start && next_end == end {
            break;
        }

        start = next_start;
        end = next_end;
    }
    f(start)
}
{ + end = mid2; + } else if r1 > r2 { + start = mid1; + } else { + start = mid1; + end = mid2; + } + } + f(start) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn finds_max_value() { + let expected = 4.0; + let f = |x: f32| -x * x - 2.0 * x + 3.0; + + let start: f32 = -10000000000.0; + let end: f32 = 10000000000.0; + let absolute_precision = 0.0000001; + + let result = ternary_search_max(f, start, end, absolute_precision); + + assert_eq!(result, expected); + } + + #[test] + fn finds_min_value() { + let expected = 2.0; + let f = |x: f32| x * x - 2.0 * x + 3.0; + + let start: f32 = -10000000000.0; + let end: f32 = 10000000000.0; + let absolute_precision = 0.0000001; + + let result = ternary_search_min(f, start, end, absolute_precision); + + assert_eq!(result, expected); + } + + #[test] + fn finds_max_value_2() { + let expected = 7.25; + let f = |x: f32| -x.powi(2) + 3.0 * x + 5.0; + + let start: f32 = -10000000000.0; + let end: f32 = 10000000000.0; + let absolute_precision = 0.000001; + + let result = ternary_search_max(f, start, end, absolute_precision); + + assert_eq!(result, expected); + } + + #[test] + fn finds_min_value_2() { + let expected = 2.75; + let f = |x: f32| x.powi(2) + 3.0 * x + 5.0; + + let start: f32 = -10000000000.0; + let end: f32 = 10000000000.0; + let absolute_precision = 0.000001; + + let result = ternary_search_min(f, start, end, absolute_precision); + + assert_eq!(result, expected); + } +} diff --git a/src/searching/ternary_search_min_max_recursive.rs b/src/searching/ternary_search_min_max_recursive.rs new file mode 100644 index 00000000000..88d3a0a7b1b --- /dev/null +++ b/src/searching/ternary_search_min_max_recursive.rs @@ -0,0 +1,108 @@ +/// Recursive ternary search algorithm for finding maximum of unimodal function +pub fn ternary_search_max_rec( + f: fn(f32) -> f32, + start: f32, + end: f32, + absolute_precision: f32, +) -> f32 { + if (end - start).abs() >= absolute_precision { + let mid1 = start + (end - start) / 3.0; + 
/// Recursive ternary search algorithm for finding minimum of unimodal function
///
/// Evaluates `f` at two interior points of the interval between `start` and
/// `end`, drops the third of the interval that cannot contain the minimum and
/// recurses on the rest until the interval is narrower than
/// `absolute_precision`.
///
/// The recursion also stops as soon as a step can no longer move either bound.
/// This happens when `absolute_precision` is smaller than the spacing of `f32`
/// values around the bounds; without this check the function would call itself
/// with identical arguments until the stack overflows.
///
/// # Returns
///
/// The value of `f` at the final lower bound of the interval (`f(start)`).
pub fn ternary_search_min_rec(
    f: fn(f32) -> f32,
    start: f32,
    end: f32,
    absolute_precision: f32,
) -> f32 {
    if (end - start).abs() >= absolute_precision {
        let third = (end - start) / 3.0;
        let mid1 = start + third;
        let mid2 = end - third;

        let r1 = f(mid1);
        let r2 = f(mid2);

        let (next_start, next_end) = if r1 < r2 {
            // The minimum cannot be in the last third.
            (start, mid2)
        } else if r1 > r2 {
            // The minimum cannot be in the first third.
            (mid1, end)
        } else {
            (mid1, mid2)
        };

        // Recurse only while the interval actually changes.
        if next_start != start || next_end != end {
            return ternary_search_min_rec(f, next_start, next_end, absolute_precision);
        }
    }
    f(start)
}
/// Recursively searches `list[start..=end]` for `target` using ternary search.
///
/// The range is probed at two points, one third of the range width in from each
/// end, and the search continues in the third that can still contain `target`.
/// The comparisons assume `list` is ordered ascending.
///
/// # Parameters
///
/// - `target`: The value to look for.
/// - `list`: The slice to search.
/// - `start`: First index of the range to search (inclusive).
/// - `end`: Last index of the range to search (inclusive). Values past the last
///   valid index are clamped to it.
///
/// # Returns
///
/// `Some(index)` of an element equal to `target`, or `None` if `list` is empty,
/// the range is empty (`start > end`), or `target` is not in the range.
pub fn ternary_search_rec<T: Ord>(
    target: &T,
    list: &[T],
    start: usize,
    end: usize,
) -> Option<usize> {
    // `checked_sub` doubles as the empty-list check.
    let last_index = list.len().checked_sub(1)?;
    // Never probe past the end of the slice, even if the caller's bound is too large.
    let end = end.min(last_index);

    if start > end {
        return None;
    }

    let third = (end - start) / 3;
    let mid1 = start + third;
    let mid2 = end - third;

    match target.cmp(&list[mid1]) {
        Ordering::Equal => Some(mid1),
        // `mid1` can be 0 here; there is nothing to the left of index 0 to search.
        Ordering::Less => ternary_search_rec(target, list, start, mid1.checked_sub(1)?),
        Ordering::Greater => match target.cmp(&list[mid2]) {
            Ordering::Equal => Some(mid2),
            Ordering::Greater => ternary_search_rec(target, list, mid2 + 1, end),
            // `target` is above `list[mid1]` and below `list[mid2]`, so `mid2 > mid1 >= 0`
            // and `mid2 - 1` cannot underflow.
            Ordering::Less => ternary_search_rec(target, list, mid1 + 1, mid2 - 1),
        },
    }
}
assert_eq!(index, Some(0)); + } + + #[test] + fn returns_last_index() { + let index = ternary_search_rec(&3, &[1, 2, 3], 0, 2); + assert_eq!(index, Some(2)); + } + + #[test] + fn returns_last_index_if_end_out_of_bounds() { + let index = ternary_search_rec(&3, &[1, 2, 3], 0, 3); + assert_eq!(index, Some(2)); + } + + #[test] + fn returns_middle_index() { + let index = ternary_search_rec(&2, &[1, 2, 3], 0, 2); + assert_eq!(index, Some(1)); + } + + #[test] + fn returns_middle_index_if_end_out_of_bounds() { + let index = ternary_search_rec(&2, &[1, 2, 3], 0, 3); + assert_eq!(index, Some(1)); + } +} diff --git a/src/sorting/README.md b/src/sorting/README.md index 996ebf5840e..4b0e248db35 100644 --- a/src/sorting/README.md +++ b/src/sorting/README.md @@ -1,5 +1,15 @@ ## Sort Algorithms +### [Bogo-sort](./bogo_sort.rs) +![alt text][bogo-image] + +From [Wikipedia][bogo-wiki]: In computer science, bogosort is a sorting algorithm based on the generate and test paradigm. The function successively generates permutations of its input until it finds one that is sorted. It is not considered useful for sorting, but may be used for educational purposes, to contrast it with more efficient algorithms. + +__Properties__ +* Worst case performance (unbounded in randomized version) +* Best case performance O(n) +* Average case performance O((n+1)!) + ### [Bubble](./bubble_sort.rs) ![alt text][bubble-image] @@ -15,6 +25,32 @@ __Properties__ +### [Cocktail-Shaker](./cocktail_shaker_sort.rs) +![alt text][shaker-image] + +From [Wikipedia][shaker-wiki]: Cocktail shaker sort, also known as bidirectional bubble sort, cocktail sort, shaker sort (which can also refer to a variant of selection sort), ripple sort, shuffle sort, or shuttle sort, is an extension of bubble sort. The algorithm extends bubble sort by operating in two directions. While it improves on bubble sort by more quickly moving items to the beginning of the list, it provides only marginal performance improvements. 
+ +__Properties__ +* Worst case performance O(n^2) +* Best case performance O(n) +* Average case performance O(n^2) + + + +### [Comb-sort](./comb_sort.rs) +![comb sort][comb-sort] + +From [wikipedia][comb-sort-wiki]: Comb sort is a relatively simple sorting algorithm and improves on bubble sort in the same way that shell sort improves on insertion sort. The basic idea of comb sort is that the gap(distance from two compared elements) can be much more than 1. And the inner loop of bubble sort, which does actual `swap`, is modified such that the gap between swapped elements goes down in steps of a `shrink factor k: [n/k, n/k^2, ..., 1]`. And the gap is divided by the shrink factor in every loop, and the process repeats until the gap is 1. At this point, comb sort continues using a gap of 1 until the list is fully sorted. The final stage of the sort is thus equivalent to a bubble sort, but this time most turtles have been dealt with, so a bubble sort will be efficient. And the shrink factor has a great effect on the efficiency of comb sort and `k=1.3` has been suggested as an ideal value. + +__Properties__ +* Worst case performance O(n^2) +* Best case performance O(n log n) +* Average case performance O(n^2/2^p) + +where `p` is the number of increments. + + + ### [Counting](./counting_sort.rs) From [Wikipedia][counting-wiki]: In computer science, counting sort is an algorithm for sorting a collection of objects according to keys that are small integers; that is, it is an integer sorting algorithm. It operates by counting the number of objects that have each distinct key value, and using arithmetic on those counts to determine the positions of each key value in the output sequence. Its running time is linear in the number of items and the difference between the maximum and minimum key values, so it is only suitable for direct use in situations where the variation in keys is not significantly greater than the number of items. 
However, it is often used as a subroutine in another sorting algorithm, radix sort, that can handle larger keys more efficiently. @@ -41,6 +77,18 @@ __Properties__ ###### View the algorithm in [action][insertion-toptal] +### [Gnome](./gnome_sort.rs) +![alt text][gnome-image] + +From [Wikipedia][gnome-wiki]: The gnome sort is a sorting algorithm which is similar to insertion sort in that it works with one item at a time but gets the item to the proper place by a series of swaps, similar to a bubble sort. It is conceptually simple, requiring no nested loops. The average running time is O(n^2) but tends towards O(n) if the list is initially almost sorted + +__Properties__ +* Worst case performance O(n^2) +* Best case performance O(n) +* Average case performance O(n^2) + + + ### [Merge](./merge_sort.rs) ![alt text][merge-image] @@ -54,6 +102,24 @@ __Properties__ ###### View the algorithm in [action][merge-toptal] +### [Odd-even](./odd_even_sort.rs) +![alt text][odd-even-image] + +From [Wikipedia][odd-even-wiki]: In computing, an odd–even sort or odd–even transposition sort (also known as brick sort or parity sort) is a relatively simple sorting algorithm, developed originally for use on parallel processors with local interconnections. It is a comparison sort related to bubble sort, with which it shares many characteristics. It functions by comparing all odd/even indexed pairs of adjacent elements in the list and, if a pair is in the wrong order (the first is larger than the second) the elements are switched. The next step repeats this for even/odd indexed pairs (of adjacent elements). Then it alternates between odd/even and even/odd steps until the list is sorted. + +NOTE: The implementation is an adaptation of the algorithm for a single-processor machine, while the original algorithm was devised to be executed on many processors simultaneously. 
+__Properties__ +* Worst case performance O(n^2) +* Best case performance O(n) +* Average case performance O(n^2) + + +### [Pancake](./pancake_sort.rs) +![alt text][pancake-image] + +From [Wikipedia][pancake-wiki]: All sorting methods require pairs of elements to be compared. For the traditional sorting problem, the usual problem studied is to minimize the number of comparisons required to sort a list. The number of actual operations, such as swapping two elements, is then irrelevant. For pancake sorting problems, in contrast, the aim is to minimize the number of operations, where the only allowed operations are reversals of the elements of some prefix of the sequence. Now, the number of comparisons is irrelevant. + + ### [Quick](./quick_sort.rs) ![alt text][quick-image] @@ -100,16 +166,67 @@ __Properties__ ###### View the algorithm in [action][shell-toptal] +### [Stooge](./stooge_sort.rs) +![alt text][stooge-image] + +From [Wikipedia][stooge-wiki]: Stooge sort is a recursive sorting algorithm. It is notable for its exceptionally bad time complexity of O(n^(log 3 / log 1.5)) = O(n^2.7095...). The running time of the algorithm is thus slower compared to reasonable sorting algorithms, and is slower than Bubble sort, a canonical example of a fairly inefficient sort. It is however more efficient than Slowsort. The name comes from The Three Stooges. + +__Properties__ +* Worst case performance O(n^(log(3) / log(1.5))) + +### [Tim](./tim_sort.rs) +![alt text][tim-image] + +From [Wikipedia][tim-wiki]: Timsort is a hybrid stable sorting algorithm, derived from merge sort and insertion sort, designed to perform well on many kinds of real-world data. It was implemented by Tim Peters in 2002 for use in the Python programming language. The algorithm finds subsequences of the data that are already ordered (runs) and uses them to sort the remainder more efficiently. This is done by merging runs until certain criteria are fulfilled. 
Timsort has been Python's standard sorting algorithm since version 2.3. It is also used to sort arrays of non-primitive type in Java SE 7, on the Android platform, in GNU Octave, on V8, Swift, and Rust. + +__Properties__ +* Worst-case performance O(n log n) +* Best-case performance O(n) + +### [Sleep](./sleep_sort.rs) +![alt text][sleep-image] + +From [Wikipedia][sleep-sort-wiki]: This is an idea that was originally posted on the message board 4chan, replacing the bucket in bucket sort with time instead of memory space. +It is actually possible to sort by "maximum of all elements x unit time to sleep". The only case where this would be useful would be in examples. + +### [Patience](./patience_sort.rs) +[patience-video] + + +From [Wikipedia][patience-sort-wiki]: The algorithm's name derives from a simplified variant of the patience card game. The game begins with a shuffled deck of cards. The cards are dealt one by one into a sequence of piles on the table, according to the following rules. + +1. Initially, there are no piles. The first card dealt forms a new pile consisting of the single card. +2. Each subsequent card is placed on the leftmost existing pile whose top card has a value greater than or equal to the new card's value, or to the right of all of the existing piles, thus forming a new pile. +3. When there are no more cards remaining to deal, the game ends. + +This card game is turned into a two-phase sorting algorithm, as follows. Given an array of n elements from some totally ordered domain, consider this array as a collection of cards and simulate the patience sorting game. When the game is over, recover the sorted sequence by repeatedly picking off the minimum visible card; in other words, perform a k-way merge of the p piles, each of which is internally sorted.
+ +__Properties__ +* Worst case performance O(n log n) +* Best case performance O(n) + +[bogo-wiki]: https://en.wikipedia.org/wiki/Bogosort +[bogo-image]: https://upload.wikimedia.org/wikipedia/commons/7/7b/Bogo_sort_animation.gif + [bubble-toptal]: https://www.toptal.com/developers/sorting-algorithms/bubble-sort [bubble-wiki]: https://en.wikipedia.org/wiki/Bubble_sort [bubble-image]: https://upload.wikimedia.org/wikipedia/commons/thumb/8/83/Bubblesort-edited-color.svg/220px-Bubblesort-edited-color.svg.png "Bubble Sort" +[shaker-wiki]: https://en.wikipedia.org/wiki/Cocktail_shaker_sort +[shaker-image]: https://upload.wikimedia.org/wikipedia/commons/e/ef/Sorting_shaker_sort_anim.gif + [counting-wiki]: https://en.wikipedia.org/wiki/Counting_sort [insertion-toptal]: https://www.toptal.com/developers/sorting-algorithms/insertion-sort [insertion-wiki]: https://en.wikipedia.org/wiki/Insertion_sort [insertion-image]: https://upload.wikimedia.org/wikipedia/commons/7/7e/Insertionsort-edited.png "Insertion Sort" +[gnome-wiki]: https://en.wikipedia.org/wiki/Gnome_sort +[gnome-image]: https://upload.wikimedia.org/wikipedia/commons/3/37/Sorting_gnomesort_anim.gif "Insertion Sort" + +[pancake-wiki]: https://en.wikipedia.org/wiki/Pancake_sorting +[pancake-image]: https://upload.wikimedia.org/wikipedia/commons/0/0f/Pancake_sort_operation.png + [quick-toptal]: https://www.toptal.com/developers/sorting-algorithms/quick-sort [quick-wiki]: https://en.wikipedia.org/wiki/Quicksort [quick-image]: https://upload.wikimedia.org/wikipedia/commons/6/6a/Sorting_quicksort_anim.gif "Quick Sort" @@ -118,6 +235,9 @@ __Properties__ [merge-wiki]: https://en.wikipedia.org/wiki/Merge_sort [merge-image]: https://upload.wikimedia.org/wikipedia/commons/c/cc/Merge-sort-example-300px.gif "Merge Sort" +[odd-even-image]: https://upload.wikimedia.org/wikipedia/commons/1/1b/Odd_even_sort_animation.gif +[odd-even-wiki]: https://en.wikipedia.org/wiki/Odd%E2%80%93even_sort + [radix-wiki]: 
https://en.wikipedia.org/wiki/Radix_sort [radix-image]: https://ds055uzetaobb.cloudfront.net/brioche/uploads/IEZs8xJML3-radixsort_ed.png?width=400 "Radix Sort" @@ -128,3 +248,18 @@ __Properties__ [shell-toptal]: https://www.toptal.com/developers/sorting-algorithms/shell-sort [shell-wiki]: https://en.wikipedia.org/wiki/Shellsort [shell-image]: https://upload.wikimedia.org/wikipedia/commons/d/d8/Sorting_shellsort_anim.gif "Shell Sort" + +[stooge-image]: https://upload.wikimedia.org/wikipedia/commons/f/f8/Sorting_stoogesort_anim.gif +[stooge-wiki]: https://en.wikipedia.org/wiki/Stooge_sort + +[tim-image]: https://thumbs.gfycat.com/BruisedFrigidBlackrhino-size_restricted.gif +[tim-wiki]: https://en.wikipedia.org/wiki/Timsort + +[comb-sort]: https://upload.wikimedia.org/wikipedia/commons/4/46/Comb_sort_demo.gif +[comb-sort-wiki]: https://en.wikipedia.org/wiki/Comb_sort + +[sleep-sort]: +[sleep-sort-wiki]: https://ja.m.wikipedia.org/wiki/バケットソート#.E3.82.B9.E3.83.AA.E3.83.BC.E3.83.97.E3.82.BD.E3.83.BC.E3.83.88 + +[patience-sort-wiki]: https://en.wikipedia.org/wiki/Patience_sorting +[patience-video]: https://user-images.githubusercontent.com/67539676/212542208-d3f7a824-60d8-467c-8097-841945514ae9.mp4 diff --git a/src/sorting/bead_sort.rs b/src/sorting/bead_sort.rs new file mode 100644 index 00000000000..c8c4017942e --- /dev/null +++ b/src/sorting/bead_sort.rs @@ -0,0 +1,59 @@ +//Bead sort only works for sequences of non-negative integers. 
/// Sorts a slice of non-negative integers in ascending order using bead sort.
///
/// Bead sort only works for sequences of non-negative integers.
/// See <https://en.wikipedia.org/wiki/Bead_sort>.
///
/// Each element `v` is treated as a rod carrying one bead on each of the levels
/// `0..v`. Counting the beads on every level and letting them "fall" towards the
/// end of the slice yields the sorted sequence: a level holding `count` beads
/// raises the last `count` slots to that level.
///
/// An empty slice is left untouched. Slots that receive no bead end up as `0`,
/// so inputs containing zeros are sorted correctly.
///
/// The function allocates one counter per level, i.e. as many counters as the
/// largest element.
pub fn bead_sort(a: &mut [usize]) {
    // Find the maximum element; an empty slice has nothing to sort.
    let max = match a.iter().max() {
        Some(&max) => max,
        None => return,
    };

    // beads_per_level[j] is the number of elements that are greater than `j`.
    let mut beads_per_level = vec![0usize; max];
    for &value in a.iter() {
        for beads in &mut beads_per_level[..value] {
            *beads += 1;
        }
    }

    // Start from an empty frame so that slots without beads become zero.
    a.fill(0);

    // Let the beads fall: the `count` beads of a level land on the last `count` slots.
    let len = a.len();
    for (level, &count) in beads_per_level.iter().enumerate() {
        for slot in &mut a[len - count..] {
            *slot = level + 1;
        }
    }
}
/// Finds the minimum and maximum element of the slice.
///
/// `bingo` and `next_bingo` must both be initialised with the first element of
/// `vec`; the remaining elements are then folded in, leaving the smallest value
/// in `bingo` and the largest value in `next_bingo`.
fn max_min(vec: &[i32], bingo: &mut i32, next_bingo: &mut i32) {
    for &element in vec.iter().skip(1) {
        *bingo = min(*bingo, element);
        *next_bingo = max(*next_bingo, element);
    }
}

/// Sorts the slice in ascending order using bingo sort.
///
/// Every pass moves all elements equal to the current smallest remaining value
/// (the "bingo") to the front of the unsorted part, and at the same time picks
/// the smallest value greater than the bingo as the bingo of the next pass.
/// The number of passes therefore equals the number of distinct values, not the
/// numeric distance between the minimum and the maximum.
pub fn bingo_sort(vec: &mut [i32]) {
    if vec.is_empty() {
        return;
    }

    let mut bingo = vec[0];
    let mut next_bingo = vec[0];

    max_min(vec, &mut bingo, &mut next_bingo);

    let largest_element = next_bingo;
    let mut next_element_pos = 0;

    // Once the bingo reaches the largest value, everything left equals it and is in place.
    while bingo < largest_element {
        let start_pos = next_element_pos;
        // Upper bound for the candidate; lowered by every non-bingo element seen below.
        next_bingo = largest_element;

        for i in start_pos..vec.len() {
            if vec[i] == bingo {
                vec.swap(i, next_element_pos);
                next_element_pos += 1;
            } else if vec[i] < next_bingo {
                next_bingo = vec[i];
            }
        }

        bingo = next_bingo;
    }
}
/// Puts `array[left]` and `array[right]` into the requested order.
///
/// The two elements are swapped when they are in descending order although
/// `ascending` is requested, or in ascending order although it is not.
fn _comp_and_swap<T: Ord>(array: &mut [T], left: usize, right: usize, ascending: bool) {
    if (ascending && array[left] > array[right]) || (!ascending && array[left] < array[right]) {
        array.swap(left, right);
    }
}

/// Merges the `length` elements starting at `low` into the requested order.
///
/// Each element of the first half is compared with its partner in the second
/// half, then both halves are merged recursively.
fn _bitonic_merge<T: Ord>(array: &mut [T], low: usize, length: usize, ascending: bool) {
    if length > 1 {
        let middle = length / 2;
        for i in low..(low + middle) {
            _comp_and_swap(array, i, i + middle, ascending);
        }
        _bitonic_merge(array, low, middle, ascending);
        _bitonic_merge(array, low + middle, middle, ascending);
    }
}

/// Sorts `length` elements of `array`, starting at index `low`, using bitonic sort.
///
/// The first half of the range is sorted ascending and the second half
/// descending, and the two halves are then merged into the order selected by
/// `ascending`.
///
/// Note that this program works only when size of input is a power of 2.
///
/// # Parameters
///
/// - `array`: The slice holding the elements.
/// - `low`: Index of the first element of the range to sort.
/// - `length`: Number of elements in the range; `low + length` must not exceed
///   `array.len()`.
/// - `ascending`: `true` to sort ascending, `false` to sort descending.
pub fn bitonic_sort<T: Ord>(array: &mut [T], low: usize, length: usize, ascending: bool) {
    if length > 1 {
        let middle = length / 2;
        bitonic_sort(array, low, middle, true);
        bitonic_sort(array, low + middle, middle, false);
        _bitonic_merge(array, low, length, ascending);
    }
}
/// Checks whether the first `len` elements of `arr` are in non-decreasing order.
///
/// Ranges with fewer than two elements, including the empty one, count as
/// sorted. If `len` is larger than `arr.len()`, the whole slice is checked.
fn is_sorted<T: Ord>(arr: &[T], len: usize) -> bool {
    // `windows(2)` yields nothing for fewer than two elements, so there is no
    // `len - 1` that could underflow when `len` is 0.
    arr[..len.min(arr.len())]
        .windows(2)
        .all(|pair| pair[0] <= pair[1])
}
+ */ +fn permute_randomly(arr: &mut [T], len: usize, generator: &mut PCG32) { + for i in (1..len).rev() { + let j = generate_index(i + 1, generator); + arr.swap(i, j); + } +} + +pub fn bogo_sort(arr: &mut [T]) { + let seed = match SystemTime::now().duration_since(UNIX_EPOCH) { + Ok(duration) => duration.as_millis() as u64, + Err(_) => DEFAULT, + }; + + let mut random_generator = PCG32::new_default(seed); + + let arr_length = arr.len(); + while !is_sorted(arr, arr_length) { + permute_randomly(arr, arr_length, &mut random_generator); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn random_array() { + let mut arr = [1, 8, 3, 2, 7, 4, 6, 5]; + bogo_sort(&mut arr); + + for i in 0..arr.len() - 1 { + assert!(arr[i] <= arr[i + 1]); + } + } + + #[test] + fn sorted_array() { + let mut arr = [1, 2, 3, 4, 5, 6, 7, 8]; + bogo_sort(&mut arr); + + for i in 0..arr.len() - 1 { + assert!(arr[i] <= arr[i + 1]); + } + } +} diff --git a/src/sorting/bubble_sort.rs b/src/sorting/bubble_sort.rs index 76b2f29e6fa..0df7cec07a1 100644 --- a/src/sorting/bubble_sort.rs +++ b/src/sorting/bubble_sort.rs @@ -1,34 +1,49 @@ pub fn bubble_sort(arr: &mut [T]) { - for i in 0..arr.len() { - for j in 0..arr.len() - 1 - i { - if arr[j] > arr[j + 1] { - arr.swap(j, j + 1); + if arr.is_empty() { + return; + } + let mut sorted = false; + let mut n = arr.len(); + while !sorted { + sorted = true; + for i in 0..n - 1 { + if arr[i] > arr[i + 1] { + arr.swap(i, i + 1); + sorted = false; } } + n -= 1; } } #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn descending() { //descending let mut ve1 = vec![6, 5, 4, 3, 2, 1]; + let cloned = ve1.clone(); bubble_sort(&mut ve1); - for i in 0..ve1.len() - 1 { - assert!(ve1[i] <= ve1[i + 1]); - } + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn ascending() { //pre-sorted let mut ve2 = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); bubble_sort(&mut 
ve2); - for i in 0..ve2.len() - 1 { - assert!(ve2[i] <= ve2[i + 1]); - } + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); + } + #[test] + fn empty() { + let mut ve3: Vec = vec![]; + let cloned = ve3.clone(); + bubble_sort(&mut ve3); + assert!(is_sorted(&ve3) && have_same_elements(&ve3, &cloned)); } } diff --git a/src/sorting/bucket_sort.rs b/src/sorting/bucket_sort.rs new file mode 100644 index 00000000000..05fb30272ad --- /dev/null +++ b/src/sorting/bucket_sort.rs @@ -0,0 +1,87 @@ +/// Sort a slice using bucket sort algorithm. +/// +/// Time complexity is `O(n + k)` on average, where `n` is the number of elements, +/// `k` is the number of buckets used in process. +/// +/// Space complexity is `O(n + k)`, as it sorts not in-place. +pub fn bucket_sort(arr: &[usize]) -> Vec { + if arr.is_empty() { + return vec![]; + } + + let max = *arr.iter().max().unwrap(); + let len = arr.len(); + let mut buckets = vec![vec![]; len + 1]; + + for x in arr { + buckets[len * *x / max].push(*x); + } + + for bucket in buckets.iter_mut() { + super::insertion_sort(bucket); + } + + let mut result = vec![]; + for bucket in buckets { + for x in bucket { + result.push(x); + } + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn empty() { + let arr: [usize; 0] = []; + let cloned = arr; + let res = bucket_sort(&arr); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let arr: [usize; 1] = [4]; + let cloned = arr; + let res = bucket_sort(&arr); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn already_sorted() { + let arr: [usize; 3] = [10, 19, 105]; + let cloned = arr; + let res = bucket_sort(&arr); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn basic() { + let arr: [usize; 4] = [35, 53, 1, 0]; + let cloned = arr; + let res = bucket_sort(&arr); + 
assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn odd_number_of_elements() { + let arr: [usize; 5] = [1, 21, 5, 11, 58]; + let cloned = arr; + let res = bucket_sort(&arr); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn repeated_elements() { + let arr: [usize; 4] = [542, 542, 542, 542]; + let cloned = arr; + let res = bucket_sort(&arr); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } +} diff --git a/src/sorting/cocktail_shaker_sort.rs b/src/sorting/cocktail_shaker_sort.rs new file mode 100644 index 00000000000..dd65fc0fa99 --- /dev/null +++ b/src/sorting/cocktail_shaker_sort.rs @@ -0,0 +1,74 @@ +pub fn cocktail_shaker_sort(arr: &mut [T]) { + let len = arr.len(); + + if len == 0 { + return; + } + + loop { + let mut swapped = false; + + for i in 0..(len - 1).clamp(0, len) { + if arr[i] > arr[i + 1] { + arr.swap(i, i + 1); + swapped = true; + } + } + + if !swapped { + break; + } + + swapped = false; + + for i in (0..(len - 1).clamp(0, len)).rev() { + if arr[i] > arr[i + 1] { + arr.swap(i, i + 1); + swapped = true; + } + } + + if !swapped { + break; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut arr = vec![5, 2, 1, 3, 4, 6]; + let cloned = arr.clone(); + cocktail_shaker_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn empty() { + let mut arr = Vec::::new(); + let cloned = arr.clone(); + cocktail_shaker_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn one_element() { + let mut arr = vec![1]; + let cloned = arr.clone(); + cocktail_shaker_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut arr = vec![1, 2, 3, 4, 5, 6]; + let cloned = arr.clone(); + cocktail_shaker_sort(&mut arr); + 
assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } +} diff --git a/src/sorting/comb_sort.rs b/src/sorting/comb_sort.rs new file mode 100644 index 00000000000..d84522ce2ee --- /dev/null +++ b/src/sorting/comb_sort.rs @@ -0,0 +1,54 @@ +pub fn comb_sort(arr: &mut [T]) { + let mut gap = arr.len(); + let shrink = 1.3; + let mut sorted = false; + + while !sorted { + gap = (gap as f32 / shrink).floor() as usize; + if gap <= 1 { + gap = 1; + sorted = true; + } + for i in 0..arr.len() - gap { + let j = i + gap; + if arr[i] > arr[j] { + arr.swap(i, j); + sorted = false; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn descending() { + //descending + let mut ve1 = vec![6, 5, 4, 3, 2, 1]; + let cloned = ve1.clone(); + comb_sort(&mut ve1); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); + } + + #[test] + fn ascending() { + //pre-sorted + let mut ve2 = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); + comb_sort(&mut ve2); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); + } + + #[test] + fn duplicates() { + //pre-sorted + let mut ve3 = vec![2, 2, 2, 2, 2, 1]; + let cloned = ve3.clone(); + comb_sort(&mut ve3); + assert!(is_sorted(&ve3) && have_same_elements(&ve3, &cloned)); + } +} diff --git a/src/sorting/counting_sort.rs b/src/sorting/counting_sort.rs index 6488ed5a5ae..e1c22373f7c 100644 --- a/src/sorting/counting_sort.rs +++ b/src/sorting/counting_sort.rs @@ -51,38 +51,43 @@ pub fn generic_counting_sort + From + AddAssign + Copy>( #[cfg(test)] mod test { - use super::super::is_sorted; use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn counting_sort_descending() { let mut ve1 = vec![6, 5, 4, 3, 2, 1]; + let cloned = ve1.clone(); counting_sort(&mut ve1, 6); - assert!(is_sorted(&ve1)); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn 
counting_sort_pre_sorted() { let mut ve2 = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); counting_sort(&mut ve2, 6); - assert!(is_sorted(&ve2)); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); } #[test] fn generic_counting_sort() { let mut ve1: Vec = vec![100, 30, 60, 10, 20, 120, 1]; + let cloned = ve1.clone(); super::generic_counting_sort(&mut ve1, 120); - assert!(is_sorted(&ve1)); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn presorted_u64_counting_sort() { let mut ve2: Vec = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); super::generic_counting_sort(&mut ve2, 6); - assert!(is_sorted(&ve2)); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); } } diff --git a/src/sorting/cycle_sort.rs b/src/sorting/cycle_sort.rs new file mode 100644 index 00000000000..44c0947613c --- /dev/null +++ b/src/sorting/cycle_sort.rs @@ -0,0 +1,60 @@ +// sorts with the minimum number of rewrites. Runs through all values in the array, placing them in their correct spots. O(n^2). 
+ +pub fn cycle_sort(arr: &mut [i32]) { + for cycle_start in 0..arr.len() { + let mut item = arr[cycle_start]; + let mut pos = cycle_start; + for i in arr.iter().skip(cycle_start + 1) { + if *i < item { + pos += 1; + } + } + if pos == cycle_start { + continue; + } + while item == arr[pos] { + pos += 1; + } + std::mem::swap(&mut arr[pos], &mut item); + while pos != cycle_start { + pos = cycle_start; + for i in arr.iter().skip(cycle_start + 1) { + if *i < item { + pos += 1; + } + } + while item == arr[pos] { + pos += 1; + } + std::mem::swap(&mut arr[pos], &mut item); + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn it_works() { + let mut arr1 = [6, 5, 4, 3, 2, 1]; + let cloned = arr1; + cycle_sort(&mut arr1); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); + arr1 = [12, 343, 21, 90, 3, 21]; + let cloned = arr1; + cycle_sort(&mut arr1); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); + let mut arr2 = [1]; + let cloned = arr2; + cycle_sort(&mut arr2); + assert!(is_sorted(&arr2) && have_same_elements(&arr2, &cloned)); + let mut arr3 = [213, 542, 90, -23412, -32, 324, -34, 3324, 54]; + let cloned = arr3; + cycle_sort(&mut arr3); + assert!(is_sorted(&arr3) && have_same_elements(&arr3, &cloned)); + } +} diff --git a/src/sorting/dutch_national_flag_sort.rs b/src/sorting/dutch_national_flag_sort.rs new file mode 100644 index 00000000000..7d24d6d0321 --- /dev/null +++ b/src/sorting/dutch_national_flag_sort.rs @@ -0,0 +1,67 @@ +/* +A Rust implementation of the Dutch National Flag sorting algorithm. 
+ +Reference implementation: https://github.com/TheAlgorithms/Python/blob/master/sorts/dutch_national_flag_sort.py +More info: https://en.wikipedia.org/wiki/Dutch_national_flag_problem +*/ + +#[derive(PartialOrd, PartialEq, Eq)] +pub enum Colors { + Red, // \ + White, // | Define the three colors of the Dutch Flag: 🇳🇱 + Blue, // / +} +use Colors::{Blue, Red, White}; + +// Algorithm implementation +pub fn dutch_national_flag_sort(mut sequence: Vec) -> Vec { + // We take ownership of `sequence` because the original `sequence` will be modified and then returned + let length = sequence.len(); + if length <= 1 { + return sequence; // Arrays of length 0 or 1 are automatically sorted + } + let mut low = 0; + let mut mid = 0; + let mut high = length - 1; + while mid <= high { + match sequence[mid] { + Red => { + sequence.swap(low, mid); + low += 1; + mid += 1; + } + White => { + mid += 1; + } + Blue => { + sequence.swap(mid, high); + high -= 1; + } + } + } + sequence +} + +#[cfg(test)] +mod tests { + use super::super::is_sorted; + use super::*; + + #[test] + fn random_array() { + let arr = vec![ + Red, Blue, White, White, Blue, Blue, Red, Red, White, Blue, White, Red, White, Blue, + ]; + let arr = dutch_national_flag_sort(arr); + assert!(is_sorted(&arr)) + } + + #[test] + fn sorted_array() { + let arr = vec![ + Red, Red, Red, Red, Red, White, White, White, White, White, Blue, Blue, Blue, Blue, + ]; + let arr = dutch_national_flag_sort(arr); + assert!(is_sorted(&arr)) + } +} diff --git a/src/sorting/exchange_sort.rs b/src/sorting/exchange_sort.rs new file mode 100644 index 00000000000..bbd0348b9a5 --- /dev/null +++ b/src/sorting/exchange_sort.rs @@ -0,0 +1,38 @@ +// sorts through swapping the first value until it is at the right position, and repeating for all the following. 
+ +pub fn exchange_sort(arr: &mut [i32]) { + let length = arr.len(); + for number1 in 0..length { + for number2 in (number1 + 1)..length { + if arr[number2] < arr[number1] { + arr.swap(number1, number2) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + #[test] + fn it_works() { + let mut arr1 = [6, 5, 4, 3, 2, 1]; + let cloned = arr1; + exchange_sort(&mut arr1); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); + arr1 = [12, 343, 21, 90, 3, 21]; + let cloned = arr1; + exchange_sort(&mut arr1); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); + let mut arr2 = [1]; + let cloned = arr2; + exchange_sort(&mut arr2); + assert!(is_sorted(&arr2) && have_same_elements(&arr2, &cloned)); + let mut arr3 = [213, 542, 90, -23412, -32, 324, -34, 3324, 54]; + let cloned = arr3; + exchange_sort(&mut arr3); + assert!(is_sorted(&arr3) && have_same_elements(&arr3, &cloned)); + } +} diff --git a/src/sorting/gnome_sort.rs b/src/sorting/gnome_sort.rs new file mode 100644 index 00000000000..cc43ad7963b --- /dev/null +++ b/src/sorting/gnome_sort.rs @@ -0,0 +1,67 @@ +use std::cmp; + +pub fn gnome_sort(arr: &[T]) -> Vec +where + T: cmp::PartialEq + cmp::PartialOrd + Clone, +{ + let mut arr = arr.to_vec(); + let mut i: usize = 1; + let mut j: usize = 2; + + while i < arr.len() { + if arr[i - 1] < arr[i] { + i = j; + j = i + 1; + } else { + arr.swap(i - 1, i); + i -= 1; + if i == 0 { + i = j; + j += 1; + } + } + } + arr +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let original = [6, 5, -8, 3, 2, 3]; + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); + } + + #[test] + fn already_sorted() { + let original = gnome_sort(&["a", "b", "c"]); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && 
have_same_elements(&res, &original)); + } + + #[test] + fn odd_number_of_elements() { + let original = gnome_sort(&["d", "a", "c", "e", "b"]); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); + } + + #[test] + fn one_element() { + let original = gnome_sort(&[3]); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); + } + + #[test] + fn empty() { + let original = gnome_sort(&Vec::::new()); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); + } +} diff --git a/src/sorting/heap_sort.rs b/src/sorting/heap_sort.rs index d860dad0465..7b37a7c5149 100644 --- a/src/sorting/heap_sort.rs +++ b/src/sorting/heap_sort.rs @@ -1,139 +1,114 @@ -/// Sort a mutable slice using heap sort. -/// -/// Heap sort is an in-place O(n log n) sorting algorithm. It is based on a -/// max heap, a binary tree data structure whose main feature is that -/// parent nodes are always greater or equal to their child nodes. -/// -/// # Max Heap Implementation +//! This module provides functions for heap sort algorithm. + +use std::cmp::Ordering; + +/// Builds a heap from the provided array. /// -/// A max heap can be efficiently implemented with an array. -/// For example, the binary tree: -/// ```text -/// 1 -/// 2 3 -/// 4 5 6 7 -/// ``` +/// This function builds either a max heap or a min heap based on the `is_max_heap` parameter. /// -/// ... is represented by the following array: -/// ```text -/// 1 23 4567 -/// ``` +/// # Arguments /// -/// Given the index `i` of a node, parent and child indices can be calculated -/// as follows: -/// ```text -/// parent(i) = (i-1) / 2 -/// left_child(i) = 2*i + 1 -/// right_child(i) = 2*i + 2 -/// ``` +/// * `arr` - A mutable reference to the array to be sorted. +/// * `is_max_heap` - A boolean indicating whether to build a max heap (`true`) or a min heap (`false`). 
+fn build_heap(arr: &mut [T], is_max_heap: bool) { + let mut i = (arr.len() - 1) / 2; + while i > 0 { + heapify(arr, i, is_max_heap); + i -= 1; + } + heapify(arr, 0, is_max_heap); +} -/// # Algorithm +/// Fixes a heap violation starting at the given index. /// -/// Heap sort has two steps: -/// 1. Convert the input array to a max heap. -/// 2. Partition the array into heap part and sorted part. Initially the -/// heap consists of the whole array and the sorted part is empty: -/// ```text -/// arr: [ heap |] -/// ``` +/// This function adjusts the heap rooted at index `i` to fix the heap property violation. +/// It assumes that the subtrees rooted at left and right children of `i` are already heaps. /// -/// Repeatedly swap the root (i.e. the largest) element of the heap with -/// the last element of the heap and increase the sorted part by one: -/// ```text -/// arr: [ root ... last | sorted ] -/// --> [ last ... | root sorted ] -/// ``` +/// # Arguments /// -/// After each swap, fix the heap to make it a valid max heap again. -/// Once the heap is empty, `arr` is completely sorted. -pub fn heap_sort(arr: &mut [T]) { - if arr.len() <= 1 { - return; // already sorted - } +/// * `arr` - A mutable reference to the array representing the heap. +/// * `i` - The index to start fixing the heap violation. +/// * `is_max_heap` - A boolean indicating whether to maintain a max heap or a min heap. +fn heapify(arr: &mut [T], i: usize, is_max_heap: bool) { + let comparator: fn(&T, &T) -> Ordering = if !is_max_heap { + |a, b| b.cmp(a) + } else { + |a, b| a.cmp(b) + }; - heapify(arr); + let mut idx = i; + let l = 2 * i + 1; + let r = 2 * i + 2; - for end in (1..arr.len()).rev() { - arr.swap(0, end); - move_down(&mut arr[..end], 0); + if l < arr.len() && comparator(&arr[l], &arr[idx]) == Ordering::Greater { + idx = l; } -} -/// Convert `arr` into a max heap. 
-fn heapify(arr: &mut [T]) { - let last_parent = (arr.len() - 2) / 2; - for i in (0..=last_parent).rev() { - move_down(arr, i); + if r < arr.len() && comparator(&arr[r], &arr[idx]) == Ordering::Greater { + idx = r; + } + + if idx != i { + arr.swap(i, idx); + heapify(arr, idx, is_max_heap); } } -/// Move the element at `root` down until `arr` is a max heap again. +/// Sorts the given array using heap sort algorithm. /// -/// This assumes that the subtrees under `root` are valid max heaps already. -fn move_down(arr: &mut [T], mut root: usize) { - let last = arr.len() - 1; - loop { - let left = 2 * root + 1; - if left > last { - break; - } - let right = left + 1; - let max = if right <= last && arr[right] > arr[left] { - right - } else { - left - }; +/// This function sorts the array either in ascending or descending order based on the `ascending` parameter. +/// +/// # Arguments +/// +/// * `arr` - A mutable reference to the array to be sorted. +/// * `ascending` - A boolean indicating whether to sort in ascending order (`true`) or descending order (`false`). +pub fn heap_sort(arr: &mut [T], ascending: bool) { + if arr.len() <= 1 { + return; + } - if arr[max] > arr[root] { - arr.swap(root, max); - } - root = max; + // Build heap based on the order + build_heap(arr, ascending); + + let mut end = arr.len() - 1; + while end > 0 { + arr.swap(0, end); + heapify(&mut arr[..end], 0, ascending); + end -= 1; } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn empty() { - let mut arr: Vec = Vec::new(); - heap_sort(&mut arr); - assert_eq!(&arr, &[]); - } - - #[test] - fn single_element() { - let mut arr = vec![1]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1]); - } + use crate::sorting::{have_same_elements, heap_sort, is_descending_sorted, is_sorted}; - #[test] - fn sorted_array() { - let mut arr = vec![1, 2, 3, 4]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1, 2, 3, 4]); - } + macro_rules! 
test_heap_sort { + ($($name:ident: $input:expr,)*) => { + $( + #[test] + fn $name() { + let input_array = $input; + let mut arr_asc = input_array.clone(); + heap_sort(&mut arr_asc, true); + assert!(is_sorted(&arr_asc) && have_same_elements(&arr_asc, &input_array)); - #[test] - fn unsorted_array() { - let mut arr = vec![3, 4, 2, 1]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1, 2, 3, 4]); - } - - #[test] - fn odd_number_of_elements() { - let mut arr = vec![3, 4, 2, 1, 7]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1, 2, 3, 4, 7]); + let mut arr_dsc = input_array.clone(); + heap_sort(&mut arr_dsc, false); + assert!(is_descending_sorted(&arr_dsc) && have_same_elements(&arr_dsc, &input_array)); + } + )* + } } - #[test] - fn repeated_elements() { - let mut arr = vec![542, 542, 542, 542]; - heap_sort(&mut arr); - assert_eq!(&arr, &vec![542, 542, 542, 542]); + test_heap_sort! { + empty_array: Vec::::new(), + single_element_array: vec![5], + sorted: vec![1, 2, 3, 4, 5], + sorted_desc: vec![5, 4, 3, 2, 1, 0], + basic_0: vec![9, 8, 7, 6, 5], + basic_1: vec![8, 3, 1, 5, 7], + basic_2: vec![4, 5, 7, 1, 2, 3, 2, 8, 5, 4, 9, 9, 100, 1, 2, 3, 6, 4, 3], + duplicated_elements: vec![5, 5, 5, 5, 5], + strings: vec!["aa", "a", "ba", "ab"], } } diff --git a/src/sorting/insertion.rs b/src/sorting/insertion.rs deleted file mode 100644 index baa99024d3b..00000000000 --- a/src/sorting/insertion.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::cmp; - -#[allow(dead_code)] -pub fn insertion_sort(arr: &[T]) -> Vec -where - T: cmp::PartialEq + cmp::PartialOrd + Clone, -{ - // The resulting vector should contain the same amount of elements as - // the slice that is being sorted, so enough room is preallocated - let mut result: Vec = Vec::with_capacity(arr.len()); - - // Iterate over the elements to sort and - // put a clone of the element to insert in elem. - for elem in arr.iter().cloned() { - // How many elements have already been inserted? 
- let n_inserted = result.len(); - - // Loop over the inserted elements and one more index. - for i in 0..=n_inserted { - // If at the end or result[i] is larger than the current element, - // we have found the right spot: - if i == n_inserted || result[i] > elem { - // Insert the element at i, - // move the rest to higher indexes: - result.insert(i, elem); - break; - } - } - } - - result -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn empty() { - let res = insertion_sort(&Vec::::new()); - assert_eq!(res, vec![]); - } - - #[test] - fn one_element() { - let res = insertion_sort(&vec!["a"]); - assert_eq!(res, vec!["a"]); - } - - #[test] - fn already_sorted() { - let res = insertion_sort(&vec!["a", "b", "c"]); - assert_eq!(res, vec!["a", "b", "c"]); - } - - #[test] - fn basic() { - let res = insertion_sort(&vec!["d", "a", "c", "b"]); - assert_eq!(res, vec!["a", "b", "c", "d"]); - } - - #[test] - fn odd_number_of_elements() { - let res = insertion_sort(&vec!["d", "a", "c", "e", "b"]); - assert_eq!(res, vec!["a", "b", "c", "d", "e"]); - } - - #[test] - fn repeated_elements() { - let res = insertion_sort(&vec![542, 542, 542, 542]); - assert_eq!(res, vec![542, 542, 542, 542]); - } -} diff --git a/src/sorting/insertion_sort.rs b/src/sorting/insertion_sort.rs new file mode 100644 index 00000000000..ab33241b045 --- /dev/null +++ b/src/sorting/insertion_sort.rs @@ -0,0 +1,72 @@ +/// Sorts a mutable slice using in-place insertion sort algorithm. +/// +/// Time complexity is `O(n^2)`, where `n` is the number of elements. +/// Space complexity is `O(1)` as it sorts elements in-place. 
+pub fn insertion_sort(arr: &mut [T]) { + for i in 1..arr.len() { + let mut j = i; + let cur = arr[i]; + + while j > 0 && cur < arr[j - 1] { + arr[j] = arr[j - 1]; + j -= 1; + } + + arr[j] = cur; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn empty() { + let mut arr: [u8; 0] = []; + let cloned = arr; + insertion_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn one_element() { + let mut arr: [char; 1] = ['a']; + let cloned = arr; + insertion_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn already_sorted() { + let mut arr: [&str; 3] = ["a", "b", "c"]; + let cloned = arr; + insertion_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn basic() { + let mut arr: [&str; 4] = ["d", "a", "c", "b"]; + let cloned = arr; + insertion_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn odd_number_of_elements() { + let mut arr: Vec<&str> = vec!["d", "a", "c", "e", "b"]; + let cloned = arr.clone(); + insertion_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn repeated_elements() { + let mut arr: Vec = vec![542, 542, 542, 542]; + let cloned = arr.clone(); + insertion_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } +} diff --git a/src/sorting/intro_sort.rs b/src/sorting/intro_sort.rs new file mode 100755 index 00000000000..d6b31c17a02 --- /dev/null +++ b/src/sorting/intro_sort.rs @@ -0,0 +1,107 @@ +// Intro Sort (Also known as Introspective Sort) +// Introspective Sort is hybrid sort (Quick Sort + Heap Sort + Insertion Sort) +// https://en.wikipedia.org/wiki/Introsort +fn insertion_sort(arr: &mut [T]) { + for i in 1..arr.len() { + let mut j = i; + while j > 0 && arr[j] < arr[j - 1] { + arr.swap(j, 
j - 1); + j -= 1; + } + } +} + +fn heapify(arr: &mut [T], n: usize, i: usize) { + let mut largest = i; + let left = 2 * i + 1; + let right = 2 * i + 2; + + if left < n && arr[left] > arr[largest] { + largest = left; + } + + if right < n && arr[right] > arr[largest] { + largest = right; + } + + if largest != i { + arr.swap(i, largest); + heapify(arr, n, largest); + } +} + +fn heap_sort(arr: &mut [T]) { + let n = arr.len(); + + // Build a max-heap + for i in (0..n / 2).rev() { + heapify(arr, n, i); + } + + // Extract elements from the heap one by one + for i in (0..n).rev() { + arr.swap(0, i); + heapify(arr, i, 0); + } +} + +pub fn intro_sort(arr: &mut [T]) { + let len = arr.len(); + let max_depth = (2.0 * len as f64).log2() as usize + 1; + + fn intro_sort_recursive(arr: &mut [T], max_depth: usize) { + let len = arr.len(); + + if len <= 16 { + insertion_sort(arr); + } else if max_depth == 0 { + heap_sort(arr); + } else { + let pivot = partition(arr); + intro_sort_recursive(&mut arr[..pivot], max_depth - 1); + intro_sort_recursive(&mut arr[pivot + 1..], max_depth - 1); + } + } + + fn partition(arr: &mut [T]) -> usize { + let len = arr.len(); + let pivot_index = len / 2; + arr.swap(pivot_index, len - 1); + + let mut i = 0; + for j in 0..len - 1 { + if arr[j] <= arr[len - 1] { + arr.swap(i, j); + i += 1; + } + } + + arr.swap(i, len - 1); + i + } + + intro_sort_recursive(arr, max_depth); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_intro_sort() { + // Test with integers + let mut arr1 = vec![67, 34, 29, 15, 21, 9, 99]; + intro_sort(&mut arr1); + assert_eq!(arr1, vec![9, 15, 21, 29, 34, 67, 99]); + + // Test with strings + let mut arr2 = vec!["sydney", "london", "tokyo", "beijing", "mumbai"]; + intro_sort(&mut arr2); + assert_eq!(arr2, vec!["beijing", "london", "mumbai", "sydney", "tokyo"]); + + // Test with an empty array + let mut arr3: Vec = vec![]; + intro_sort(&mut arr3); + assert_eq!(arr3, vec![]); + } +} diff --git 
a/src/sorting/merge_sort.rs b/src/sorting/merge_sort.rs index 6d95a49b8b1..4c184c86110 100644 --- a/src/sorting/merge_sort.rs +++ b/src/sorting/merge_sort.rs @@ -1,107 +1,169 @@ -fn _merge(arr: &mut [T], lo: usize, mid: usize, hi: usize) { - // create temporary arrays to support merge - let mut left_half = Vec::new(); - let mut right_half = Vec::new(); - for v in arr.iter().take(mid + 1).skip(lo) { - left_half.push(*v); - } - for v in arr.iter().take(hi + 1).skip(mid + 1) { - right_half.push(*v); - } +fn merge(arr: &mut [T], mid: usize) { + // Create temporary vectors to support the merge. + let left_half = arr[..mid].to_vec(); + let right_half = arr[mid..].to_vec(); - let lsize = left_half.len(); - let rsize = right_half.len(); - - // pointers to track the positions while merging + // Indexes to track the positions while merging. let mut l = 0; let mut r = 0; - let mut a = lo; - // pick smaller element one by one from either left or right half - while l < lsize && r < rsize { - if left_half[l] < right_half[r] { - arr[a] = left_half[l]; + for v in arr { + // Choose either the smaller element, or from whichever vec is not exhausted. + if r == right_half.len() || (l < left_half.len() && left_half[l] < right_half[r]) { + *v = left_half[l]; l += 1; } else { - arr[a] = right_half[r]; + *v = right_half[r]; r += 1; } - a += 1; - } - - // put all the remaining ones - while l < lsize { - arr[a] = left_half[l]; - l += 1; - a += 1; - } - - while r < rsize { - arr[a] = right_half[r]; - r += 1; - a += 1; } } -fn _merge_sort(arr: &mut [T], lo: usize, hi: usize) { - if lo < hi { - let mid = lo + (hi - lo) / 2; - _merge_sort(arr, lo, mid); - _merge_sort(arr, mid + 1, hi); - _merge(arr, lo, mid, hi); +pub fn top_down_merge_sort(arr: &mut [T]) { + if arr.len() > 1 { + let mid = arr.len() / 2; + // Sort the left half recursively. + top_down_merge_sort(&mut arr[..mid]); + // Sort the right half recursively. + top_down_merge_sort(&mut arr[mid..]); + // Combine the two halves. 
+ merge(arr, mid); } } -pub fn merge_sort(arr: &mut [T]) { - let len = arr.len(); - if len > 1 { - _merge_sort(arr, 0, len - 1); +pub fn bottom_up_merge_sort(a: &mut [T]) { + if a.len() > 1 { + let len: usize = a.len(); + let mut sub_array_size: usize = 1; + while sub_array_size < len { + let mut start_index: usize = 0; + // still have more than one sub-arrays to merge + while len - start_index > sub_array_size { + let end_idx: usize = if start_index + 2 * sub_array_size > len { + len + } else { + start_index + 2 * sub_array_size + }; + // merge a[start_index..start_index+sub_array_size] and a[start_index+sub_array_size..end_idx] + // NOTE: mid is a relative index number starting from `start_index` + merge(&mut a[start_index..end_idx], sub_array_size); + // update `start_index` to merge the next sub-arrays + start_index = end_idx; + } + sub_array_size *= 2; + } } } #[cfg(test)] mod tests { - use super::*; + #[cfg(test)] + mod top_down_merge_sort { + use super::super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; - #[test] - fn basic() { - let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; - merge_sort(&mut res); - assert_eq!(res, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); - } + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn basic_string() { - let mut res = vec!["a", "bb", "d", "cc"]; - merge_sort(&mut res); - assert_eq!(res, vec!["a", "bb", "cc", "d"]); - } + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn empty() { - let mut res = Vec::::new(); - merge_sort(&mut res); - assert_eq!(res, vec![]); - } + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + 
top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn one_element() { - let mut res = vec![1]; - merge_sort(&mut res); - assert_eq!(res, vec![1]); - } + #[test] + fn one_element() { + let mut res = vec![1]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn pre_sorted() { - let mut res = vec![1, 2, 3, 4]; - merge_sort(&mut res); - assert_eq!(res, vec![1, 2, 3, 4]); + #[test] + fn pre_sorted() { + let mut res = vec![1, 2, 3, 4]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = vec![4, 3, 2, 1]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } } - #[test] - fn reverse_sorted() { - let mut res = vec![4, 3, 2, 1]; - merge_sort(&mut res); - assert_eq!(res, vec![1, 2, 3, 4]); + #[cfg(test)] + mod bottom_up_merge_sort { + use super::super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let mut res = vec![1]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] 
+ fn pre_sorted() { + let mut res = vec![1, 2, 3, 4]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = vec![4, 3, 2, 1]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } } } diff --git a/src/sorting/mod.rs b/src/sorting/mod.rs index ea608d74f87..79be2b0b9e6 100644 --- a/src/sorting/mod.rs +++ b/src/sorting/mod.rs @@ -1,43 +1,114 @@ +mod bead_sort; +mod binary_insertion_sort; +mod bingo_sort; +mod bitonic_sort; +mod bogo_sort; mod bubble_sort; +mod bucket_sort; +mod cocktail_shaker_sort; +mod comb_sort; mod counting_sort; +mod cycle_sort; +mod dutch_national_flag_sort; +mod exchange_sort; +mod gnome_sort; mod heap_sort; -mod insertion; +mod insertion_sort; +mod intro_sort; mod merge_sort; +mod odd_even_sort; +mod pancake_sort; +mod patience_sort; +mod pigeonhole_sort; mod quick_sort; +mod quick_sort_3_ways; mod radix_sort; mod selection_sort; +mod shell_sort; +mod sleep_sort; +#[cfg(test)] +mod sort_utils; +mod stooge_sort; +mod tim_sort; +mod tree_sort; +mod wave_sort; +mod wiggle_sort; -use std::cmp; - +pub use self::bead_sort::bead_sort; +pub use self::binary_insertion_sort::binary_insertion_sort; +pub use self::bingo_sort::bingo_sort; +pub use self::bitonic_sort::bitonic_sort; +pub use self::bogo_sort::bogo_sort; pub use self::bubble_sort::bubble_sort; +pub use self::bucket_sort::bucket_sort; +pub use self::cocktail_shaker_sort::cocktail_shaker_sort; +pub use self::comb_sort::comb_sort; pub use self::counting_sort::counting_sort; pub use self::counting_sort::generic_counting_sort; +pub use self::cycle_sort::cycle_sort; +pub use self::dutch_national_flag_sort::dutch_national_flag_sort; +pub use self::exchange_sort::exchange_sort; +pub use self::gnome_sort::gnome_sort; pub use self::heap_sort::heap_sort; -pub use self::insertion::insertion_sort; -pub 
use self::merge_sort::merge_sort; -pub use self::quick_sort::quick_sort; +pub use self::insertion_sort::insertion_sort; +pub use self::intro_sort::intro_sort; +pub use self::merge_sort::bottom_up_merge_sort; +pub use self::merge_sort::top_down_merge_sort; +pub use self::odd_even_sort::odd_even_sort; +pub use self::pancake_sort::pancake_sort; +pub use self::patience_sort::patience_sort; +pub use self::pigeonhole_sort::pigeonhole_sort; +pub use self::quick_sort::{partition, quick_sort}; +pub use self::quick_sort_3_ways::quick_sort_3_ways; pub use self::radix_sort::radix_sort; pub use self::selection_sort::selection_sort; +pub use self::shell_sort::shell_sort; +pub use self::sleep_sort::sleep_sort; +pub use self::stooge_sort::stooge_sort; +pub use self::tim_sort::tim_sort; +pub use self::tree_sort::tree_sort; +pub use self::wave_sort::wave_sort; +pub use self::wiggle_sort::wiggle_sort; -pub fn is_sorted(arr: &[T]) -> bool +#[cfg(test)] +use std::cmp; + +#[cfg(test)] +pub fn have_same_elements(a: &[T], b: &[T]) -> bool where - T: cmp::PartialOrd, + // T: cmp::PartialOrd, + // If HashSet is used + T: cmp::PartialOrd + cmp::Eq + std::hash::Hash, { - if arr.is_empty() { - return true; - } - - let mut prev = &arr[0]; + use std::collections::HashSet; - for item in arr.iter().skip(1) { - if prev > item { - return false; - } + if a.len() == b.len() { + // This is O(n^2) but performs better on smaller data sizes + //b.iter().all(|item| a.contains(item)) - prev = item; + // This is O(n), performs well on larger data sizes + let set_a: HashSet<&T> = a.iter().collect(); + let set_b: HashSet<&T> = b.iter().collect(); + set_a == set_b + } else { + false } +} - true +#[cfg(test)] +pub fn is_sorted(arr: &[T]) -> bool +where + T: cmp::PartialOrd, +{ + arr.windows(2).all(|w| w[0] <= w[1]) +} + +#[cfg(test)] +pub fn is_descending_sorted(arr: &[T]) -> bool +where + T: cmp::PartialOrd, +{ + arr.windows(2).all(|w| w[0] >= w[1]) } #[cfg(test)] @@ -51,7 +122,7 @@ mod tests { 
assert!(is_sorted(&[1, 2, 3])); assert!(is_sorted(&[0, 1, 1])); - assert_eq!(is_sorted(&[1, 0]), false); - assert_eq!(is_sorted(&[2, 3, 1, -1, 5]), false); + assert!(!is_sorted(&[1, 0])); + assert!(!is_sorted(&[2, 3, 1, -1, 5])); } } diff --git a/src/sorting/odd_even_sort.rs b/src/sorting/odd_even_sort.rs new file mode 100644 index 00000000000..c22a1c4daa1 --- /dev/null +++ b/src/sorting/odd_even_sort.rs @@ -0,0 +1,64 @@ +pub fn odd_even_sort(arr: &mut [T]) { + let len = arr.len(); + if len == 0 { + return; + } + + let mut sorted = false; + while !sorted { + sorted = true; + + for i in (1..len - 1).step_by(2) { + if arr[i] > arr[i + 1] { + arr.swap(i, i + 1); + sorted = false; + } + } + + for i in (0..len - 1).step_by(2) { + if arr[i] > arr[i + 1] { + arr.swap(i, i + 1); + sorted = false; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut arr = vec![3, 5, 1, 2, 4, 6]; + let cloned = arr.clone(); + odd_even_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn empty() { + let mut arr = Vec::::new(); + let cloned = arr.clone(); + odd_even_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn one_element() { + let mut arr = vec![3]; + let cloned = arr.clone(); + odd_even_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut arr = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let cloned = arr.clone(); + odd_even_sort(&mut arr); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); + } +} diff --git a/src/sorting/pancake_sort.rs b/src/sorting/pancake_sort.rs new file mode 100644 index 00000000000..c37b646ca1a --- /dev/null +++ b/src/sorting/pancake_sort.rs @@ -0,0 +1,60 @@ +use std::cmp; + +pub fn pancake_sort(arr: &mut [T]) -> Vec +where + T: cmp::PartialEq + cmp::Ord + 
cmp::PartialOrd + Clone, +{ + let len = arr.len(); + if len < 2 { + arr.to_vec(); + } + for i in (0..len).rev() { + let max_index = arr + .iter() + .take(i + 1) + .enumerate() + .max_by_key(|&(_, elem)| elem) + .map(|(idx, _)| idx) + .unwrap(); + if max_index != i { + arr[0..=max_index].reverse(); + arr[0..=i].reverse(); + } + } + arr.to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + let res = pancake_sort(&mut [6, 5, -8, 3, 2, 3]); + assert_eq!(res, vec![-8, 2, 3, 3, 5, 6]); + } + + #[test] + fn already_sorted() { + let res = pancake_sort(&mut ["a", "b", "c"]); + assert_eq!(res, vec!["a", "b", "c"]); + } + + #[test] + fn odd_number_of_elements() { + let res = pancake_sort(&mut ["d", "a", "c", "e", "b"]); + assert_eq!(res, vec!["a", "b", "c", "d", "e"]); + } + + #[test] + fn one_element() { + let res = pancake_sort(&mut [3]); + assert_eq!(res, vec![3]); + } + + #[test] + fn empty() { + let res = pancake_sort(&mut [] as &mut [u8]); + assert_eq!(res, vec![]); + } +} diff --git a/src/sorting/patience_sort.rs b/src/sorting/patience_sort.rs new file mode 100644 index 00000000000..662d8ceefcb --- /dev/null +++ b/src/sorting/patience_sort.rs @@ -0,0 +1,86 @@ +use std::vec; + +pub fn patience_sort(arr: &mut [T]) { + if arr.is_empty() { + return; + } + + // collect piles from arr + let mut piles: Vec> = Vec::new(); + for &card in arr.iter() { + let mut left = 0usize; + let mut right = piles.len(); + + while left < right { + let mid = left + (right - left) / 2; + if piles[mid][piles[mid].len() - 1] >= card { + right = mid; + } else { + left = mid + 1; + } + } + + if left == piles.len() { + piles.push(vec![card]); + } else { + piles[left].push(card); + } + } + + // merge the piles + let mut idx = 0usize; + while let Some((min_id, pile)) = piles + .iter() + .enumerate() + .min_by_key(|(_, pile)| *pile.last().unwrap()) + { + arr[idx] = *pile.last().unwrap(); + idx += 1; + piles[min_id].pop(); + + if piles[min_id].is_empty() { + _ = 
piles.remove(min_id); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut array = vec![ + -2, 7, 15, -14, 0, 15, 0, 10_033, 7, -7, -4, -13, 5, 8, -14, 12, + ]; + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + + #[test] + fn empty() { + let mut array = Vec::::new(); + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + + #[test] + fn one_element() { + let mut array = vec![3]; + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut array = vec![-123_456, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } +} diff --git a/src/sorting/pigeonhole_sort.rs b/src/sorting/pigeonhole_sort.rs new file mode 100644 index 00000000000..0a7a7de3066 --- /dev/null +++ b/src/sorting/pigeonhole_sort.rs @@ -0,0 +1,38 @@ +// From Wikipedia: Pigeonhole sorting is a sorting algorithm that is suitable for sorting lists of elements where the number of elements (n) and the length of the range of possible key values (N) are approximately the same. It requires O(n + N) time. 
+ +pub fn pigeonhole_sort(array: &mut [i32]) { + if let (Some(min), Some(max)) = (array.iter().min(), array.iter().max()) { + let holes_range: usize = (max - min + 1) as usize; + let mut holes = vec![0; holes_range]; + let mut holes_repeat = vec![0; holes_range]; + for i in array.iter() { + let index = *i - min; + holes[index as usize] = *i; + holes_repeat[index as usize] += 1; + } + let mut index = 0; + for i in 0..holes_range { + while holes_repeat[i] > 0 { + array[index] = holes[i]; + index += 1; + holes_repeat[i] -= 1; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::super::is_sorted; + use super::*; + + #[test] + fn test1() { + let mut arr1 = [3, 3, 3, 1, 2, 6, 5, 5, 5, 4, 1, 6, 3]; + pigeonhole_sort(&mut arr1); + assert!(is_sorted(&arr1)); + let mut arr2 = [6, 5, 4, 3, 2, 1]; + pigeonhole_sort(&mut arr2); + assert!(is_sorted(&arr2)); + } +} diff --git a/src/sorting/quick_sort.rs b/src/sorting/quick_sort.rs index 0c2cab6bba7..102edc6337d 100644 --- a/src/sorting/quick_sort.rs +++ b/src/sorting/quick_sort.rs @@ -1,34 +1,137 @@ -fn _partition(arr: &mut [T], lo: isize, hi: isize) -> isize { - let pivot = hi as usize; - let mut i = lo - 1; - let mut j = hi; +pub fn partition(arr: &mut [T], lo: usize, hi: usize) -> usize { + let pivot = hi; + let mut i = lo; + let mut j = hi - 1; loop { - i += 1; - while arr[i as usize] < arr[pivot] { + while arr[i] < arr[pivot] { i += 1; } - j -= 1; - while j >= 0 && arr[j as usize] > arr[pivot] { + while j > 0 && arr[j] > arr[pivot] { j -= 1; } - if i >= j { + if j == 0 || i >= j { break; + } else if arr[i] == arr[j] { + i += 1; + j -= 1; } else { - arr.swap(i as usize, j as usize); + arr.swap(i, j); } } - arr.swap(i as usize, pivot as usize); + arr.swap(i, pivot); i } -fn _quick_sort(arr: &mut [T], lo: isize, hi: isize) { - if lo < hi { - let p = _partition(arr, lo, hi); - _quick_sort(arr, lo, p - 1); - _quick_sort(arr, p + 1, hi); + +fn _quick_sort(arr: &mut [T], mut lo: usize, mut hi: usize) { + while lo < hi { + let 
pivot = partition(arr, lo, hi); + + if pivot - lo < hi - pivot { + if pivot > 0 { + _quick_sort(arr, lo, pivot - 1); + } + lo = pivot + 1; + } else { + _quick_sort(arr, pivot + 1, hi); + hi = pivot - 1; + } } } + pub fn quick_sort(arr: &mut [T]) { let len = arr.len(); - _quick_sort(arr, 0, (len - 1) as isize); + if len > 1 { + _quick_sort(arr, 0, len - 1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + use crate::sorting::sort_utils; + + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let mut res = vec![1]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut res = vec![1, 2, 3, 4]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = vec![4, 3, 2, 1]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn large_elements() { + let mut res = sort_utils::generate_random_vec(300000, 0, 1000000); + let cloned = res.clone(); + sort_utils::log_timed("large elements test", || { + quick_sort(&mut res); + }); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn nearly_ordered_elements() { + let mut res = 
sort_utils::generate_nearly_ordered_vec(3000, 10); + let cloned = res.clone(); + + sort_utils::log_timed("nearly ordered elements test", || { + quick_sort(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn repeated_elements() { + let mut res = sort_utils::generate_repeated_elements_vec(1000000, 3); + let cloned = res.clone(); + + sort_utils::log_timed("repeated elements test", || { + quick_sort(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } } diff --git a/src/sorting/quick_sort_3_ways.rs b/src/sorting/quick_sort_3_ways.rs new file mode 100644 index 00000000000..cf333114170 --- /dev/null +++ b/src/sorting/quick_sort_3_ways.rs @@ -0,0 +1,156 @@ +use std::cmp::{Ord, Ordering}; + +use rand::Rng; + +fn _quick_sort_3_ways(arr: &mut [T], lo: usize, hi: usize) { + if lo >= hi { + return; + } + + let mut rng = rand::rng(); + arr.swap(lo, rng.random_range(lo..=hi)); + + let mut lt = lo; // arr[lo+1, lt] < v + let mut gt = hi + 1; // arr[gt, r] > v + let mut i = lo + 1; // arr[lt + 1, i) == v + + while i < gt { + match arr[i].cmp(&arr[lo]) { + Ordering::Less => { + arr.swap(i, lt + 1); + i += 1; + lt += 1; + } + Ordering::Greater => { + arr.swap(i, gt - 1); + gt -= 1; + } + Ordering::Equal => { + i += 1; + } + } + } + + arr.swap(lo, lt); + + if lt > 1 { + _quick_sort_3_ways(arr, lo, lt - 1); + } + + _quick_sort_3_ways(arr, gt, hi); +} + +pub fn quick_sort_3_ways(arr: &mut [T]) { + let len = arr.len(); + if len > 1 { + _quick_sort_3_ways(arr, 0, len - 1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + use crate::sorting::sort_utils; + + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + sort_utils::log_timed("basic", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn 
basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + sort_utils::log_timed("basic string", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + sort_utils::log_timed("empty", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let mut res = sort_utils::generate_random_vec(1, 0, 1); + let cloned = res.clone(); + sort_utils::log_timed("one element", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut res = sort_utils::generate_nearly_ordered_vec(300000, 0); + let cloned = res.clone(); + sort_utils::log_timed("pre sorted", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = sort_utils::generate_reverse_ordered_vec(300000); + let cloned = res.clone(); + sort_utils::log_timed("reverse sorted", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn large_elements() { + let mut res = sort_utils::generate_random_vec(300000, 0, 1000000); + let cloned = res.clone(); + sort_utils::log_timed("large elements test", || { + quick_sort_3_ways(&mut res); + }); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn nearly_ordered_elements() { + let mut res = sort_utils::generate_nearly_ordered_vec(300000, 10); + let cloned = res.clone(); + + sort_utils::log_timed("nearly ordered elements test", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn repeated_elements() { + let mut res = 
sort_utils::generate_repeated_elements_vec(1000000, 3); + let cloned = res.clone(); + + sort_utils::log_timed("repeated elements test", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } +} diff --git a/src/sorting/radix_sort.rs b/src/sorting/radix_sort.rs index 19c081a591b..8f14efb2058 100644 --- a/src/sorting/radix_sort.rs +++ b/src/sorting/radix_sort.rs @@ -36,27 +36,31 @@ pub fn radix_sort(arr: &mut [u64]) { #[cfg(test)] mod tests { - use super::super::is_sorted; use super::radix_sort; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn empty() { let mut a: [u64; 0] = []; + let cloned = a; radix_sort(&mut a); - assert!(is_sorted(&a)); + assert!(is_sorted(&a) && have_same_elements(&a, &cloned)); } #[test] fn descending() { let mut v = vec![201, 127, 64, 37, 24, 4, 1]; + let cloned = v.clone(); radix_sort(&mut v); - assert!(is_sorted(&v)); + assert!(is_sorted(&v) && have_same_elements(&v, &cloned)); } #[test] fn ascending() { let mut v = vec![1, 4, 24, 37, 64, 127, 201]; + let cloned = v.clone(); radix_sort(&mut v); - assert!(is_sorted(&v)); + assert!(is_sorted(&v) && have_same_elements(&v, &cloned)); } } diff --git a/src/sorting/selection_sort.rs b/src/sorting/selection_sort.rs index 59fa6e0bf48..eebaebed3aa 100644 --- a/src/sorting/selection_sort.rs +++ b/src/sorting/selection_sort.rs @@ -14,32 +14,38 @@ pub fn selection_sort(arr: &mut [T]) { #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut res = vec!["d", "a", "c", "b"]; + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec!["a", "b", "c", "d"]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn empty() { let mut res = Vec::::new(); + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec![]); + assert!(is_sorted(&res) && have_same_elements(&res, 
&cloned)); } #[test] fn one_element() { let mut res = vec!["a"]; + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec!["a"]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn pre_sorted() { let mut res = vec!["a", "b", "c"]; + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec!["a", "b", "c"]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } } diff --git a/src/sorting/shell_sort.rs b/src/sorting/shell_sort.rs index 37c99731593..40a61b223b5 100644 --- a/src/sorting/shell_sort.rs +++ b/src/sorting/shell_sort.rs @@ -1,13 +1,13 @@ -pub fn shell_sort(values: &mut Vec) { +pub fn shell_sort(values: &mut [T]) { // shell sort works by swiping the value at a given gap and decreasing the gap to 1 - fn insertion(values: &mut Vec, start: usize, gap: usize) { + fn insertion(values: &mut [T], start: usize, gap: usize) { for i in ((start + gap)..values.len()).step_by(gap) { let val_current = values[i]; let mut pos = i; // make swaps while pos >= gap && values[pos - gap] > val_current { values[pos] = values[pos - gap]; - pos = pos - gap; + pos -= gap; } values[pos] = val_current; } @@ -24,39 +24,39 @@ pub fn shell_sort(values: &mut Vec) { #[cfg(test)] mod test { - use super::*; + use super::shell_sort; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut vec = vec![3, 5, 6, 3, 1, 4]; - shell_sort(&mut ve3); - for i in 0..ve3.len() - 1 { - assert!(ve3[i] <= ve3[i + 1]); - } + let cloned = vec.clone(); + shell_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn empty() { let mut vec: Vec = vec![]; + let cloned = vec.clone(); shell_sort(&mut vec); - assert_eq!(vec, vec![]); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn reverse() { let mut vec = vec![6, 5, 4, 3, 2, 1]; - shell_sort(&mut ve1); - for i in 0..ve1.len() - 1 { - assert!(ve1[i] <= ve1[i + 1]); - } + let cloned 
= vec.clone(); + shell_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn already_sorted() { let mut vec = vec![1, 2, 3, 4, 5, 6]; - shell_sort(&mut ve2); - for i in 0..ve2.len() - 1 { - assert!(ve2[i] <= ve2[i + 1]); - } + let cloned = vec.clone(); + shell_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } } diff --git a/src/sorting/sleep_sort.rs b/src/sorting/sleep_sort.rs new file mode 100644 index 00000000000..184eedb2539 --- /dev/null +++ b/src/sorting/sleep_sort.rs @@ -0,0 +1,70 @@ +use std::sync::mpsc; +use std::thread; +use std::time::Duration; + +pub fn sleep_sort(vec: &[usize]) -> Vec { + let len = vec.len(); + let (tx, rx) = mpsc::channel(); + + for &x in vec.iter() { + let tx: mpsc::Sender = tx.clone(); + thread::spawn(move || { + thread::sleep(Duration::from_millis((20 * x) as u64)); + tx.send(x).expect("panic"); + }); + } + let mut sorted_list: Vec = Vec::new(); + + for _ in 0..len { + sorted_list.push(rx.recv().unwrap()) + } + + sorted_list +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let res = sleep_sort(&[]); + assert_eq!(res, &[]); + } + + #[test] + fn single_element() { + let res = sleep_sort(&[1]); + assert_eq!(res, &[1]); + } + + #[test] + fn sorted_array() { + let res = sleep_sort(&[1, 2, 3, 4]); + assert_eq!(res, &[1, 2, 3, 4]); + } + + #[test] + fn unsorted_array() { + let res = sleep_sort(&[3, 4, 2, 1]); + assert_eq!(res, &[1, 2, 3, 4]); + } + + #[test] + fn odd_number_of_elements() { + let res = sleep_sort(&[3, 1, 7]); + assert_eq!(res, &[1, 3, 7]); + } + + #[test] + fn repeated_elements() { + let res = sleep_sort(&[1, 1, 1, 1]); + assert_eq!(res, &[1, 1, 1, 1]); + } + + #[test] + fn random_elements() { + let res = sleep_sort(&[5, 3, 7, 10, 1, 0, 8]); + assert_eq!(res, &[0, 1, 3, 5, 7, 8, 10]); + } +} diff --git a/src/sorting/sort_utils.rs b/src/sorting/sort_utils.rs new file mode 100644 index 00000000000..140d10a7f33 --- /dev/null 
+++ b/src/sorting/sort_utils.rs @@ -0,0 +1,63 @@ +use rand::Rng; +use std::time::Instant; + +#[cfg(test)] +pub fn generate_random_vec(n: u32, range_l: i32, range_r: i32) -> Vec { + let mut arr = Vec::::with_capacity(n as usize); + let mut rng = rand::rng(); + let mut count = n; + + while count > 0 { + arr.push(rng.random_range(range_l..=range_r)); + count -= 1; + } + + arr +} + +#[cfg(test)] +pub fn generate_nearly_ordered_vec(n: u32, swap_times: u32) -> Vec { + let mut arr: Vec = (0..n as i32).collect(); + let mut rng = rand::rng(); + + let mut count = swap_times; + + while count > 0 { + arr.swap( + rng.random_range(0..n as usize), + rng.random_range(0..n as usize), + ); + count -= 1; + } + + arr +} + +#[cfg(test)] +pub fn generate_ordered_vec(n: u32) -> Vec { + generate_nearly_ordered_vec(n, 0) +} + +#[cfg(test)] +pub fn generate_reverse_ordered_vec(n: u32) -> Vec { + let mut arr = generate_ordered_vec(n); + arr.reverse(); + arr +} + +#[cfg(test)] +pub fn generate_repeated_elements_vec(n: u32, unique_elements: u8) -> Vec { + let mut rng = rand::rng(); + let v = rng.random_range(0..n as i32); + generate_random_vec(n, v, v + unique_elements as i32) +} + +#[cfg(test)] +pub fn log_timed(test_name: &str, f: F) +where + F: FnOnce(), +{ + let before = Instant::now(); + f(); + println!("Elapsed time of {:?} is {:?}", test_name, before.elapsed()); +} diff --git a/src/sorting/stooge_sort.rs b/src/sorting/stooge_sort.rs new file mode 100644 index 00000000000..bae8b5bfdac --- /dev/null +++ b/src/sorting/stooge_sort.rs @@ -0,0 +1,63 @@ +fn _stooge_sort(arr: &mut [T], start: usize, end: usize) { + if arr[start] > arr[end] { + arr.swap(start, end); + } + + if start + 1 >= end { + return; + } + + let k = (end - start + 1) / 3; + + _stooge_sort(arr, start, end - k); + _stooge_sort(arr, start + k, end); + _stooge_sort(arr, start, end - k); +} + +pub fn stooge_sort(arr: &mut [T]) { + let len = arr.len(); + if len == 0 { + return; + } + + _stooge_sort(arr, 0, len - 1); +} + 
+#[cfg(test)] +mod test { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut vec = vec![3, 5, 6, 3, 1, 4]; + let cloned = vec.clone(); + stooge_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); + } + + #[test] + fn empty() { + let mut vec: Vec = vec![]; + let cloned = vec.clone(); + stooge_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); + } + + #[test] + fn reverse() { + let mut vec = vec![6, 5, 4, 3, 2, 1]; + let cloned = vec.clone(); + stooge_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); + } + + #[test] + fn already_sorted() { + let mut vec = vec![1, 2, 3, 4, 5, 6]; + let cloned = vec.clone(); + stooge_sort(&mut vec); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); + } +} diff --git a/src/sorting/tim_sort.rs b/src/sorting/tim_sort.rs new file mode 100644 index 00000000000..04398a6aba9 --- /dev/null +++ b/src/sorting/tim_sort.rs @@ -0,0 +1,167 @@ +//! Implements Tim sort algorithm. +//! +//! Tim sort is a hybrid sorting algorithm derived from merge sort and insertion sort. +//! It is designed to perform well on many kinds of real-world data. + +use crate::sorting::insertion_sort; +use std::cmp; + +static MIN_MERGE: usize = 32; + +/// Calculates the minimum run length for Tim sort based on the length of the array. +/// +/// The minimum run length is determined using a heuristic that ensures good performance. +/// +/// # Arguments +/// +/// * `array_length` - The length of the array. +/// +/// # Returns +/// +/// The minimum run length. +fn compute_min_run_length(array_length: usize) -> usize { + let mut remaining_length = array_length; + let mut result = 0; + + while remaining_length >= MIN_MERGE { + result |= remaining_length & 1; + remaining_length >>= 1; + } + + remaining_length + result +} + +/// Merges two sorted subarrays into a single sorted subarray. 
+/// +/// This function merges two sorted subarrays of the provided slice into a single sorted subarray. +/// +/// # Arguments +/// +/// * `arr` - The slice containing the subarrays to be merged. +/// * `left` - The starting index of the first subarray. +/// * `mid` - The ending index of the first subarray. +/// * `right` - The ending index of the second subarray. +fn merge(arr: &mut [T], left: usize, mid: usize, right: usize) { + let left_slice = arr[left..=mid].to_vec(); + let right_slice = arr[mid + 1..=right].to_vec(); + let mut i = 0; + let mut j = 0; + let mut k = left; + + while i < left_slice.len() && j < right_slice.len() { + if left_slice[i] <= right_slice[j] { + arr[k] = left_slice[i]; + i += 1; + } else { + arr[k] = right_slice[j]; + j += 1; + } + k += 1; + } + + // Copy any remaining elements from the left subarray + while i < left_slice.len() { + arr[k] = left_slice[i]; + k += 1; + i += 1; + } + + // Copy any remaining elements from the right subarray + while j < right_slice.len() { + arr[k] = right_slice[j]; + k += 1; + j += 1; + } +} + +/// Sorts a slice using Tim sort algorithm. +/// +/// This function sorts the provided slice in-place using the Tim sort algorithm. +/// +/// # Arguments +/// +/// * `arr` - The slice to be sorted. 
+pub fn tim_sort(arr: &mut [T]) { + let n = arr.len(); + let min_run = compute_min_run_length(MIN_MERGE); + + // Perform insertion sort on small subarrays + let mut i = 0; + while i < n { + insertion_sort(&mut arr[i..cmp::min(i + MIN_MERGE, n)]); + i += min_run; + } + + // Merge sorted subarrays + let mut size = min_run; + while size < n { + let mut left = 0; + while left < n { + let mid = left + size - 1; + let right = cmp::min(left + 2 * size - 1, n - 1); + if mid < right { + merge(arr, left, mid, right); + } + + left += 2 * size; + } + size *= 2; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::{have_same_elements, is_sorted}; + + #[test] + fn min_run_length_returns_correct_value() { + assert_eq!(compute_min_run_length(0), 0); + assert_eq!(compute_min_run_length(10), 10); + assert_eq!(compute_min_run_length(33), 17); + assert_eq!(compute_min_run_length(64), 16); + } + + macro_rules! test_merge { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input_arr, l, m, r, expected) = $inputs; + let mut arr = input_arr.clone(); + merge(&mut arr, l, m, r); + assert_eq!(arr, expected); + } + )* + } + } + + test_merge! { + left_and_right_subarrays_into_array: (vec![0, 2, 4, 1, 3, 5], 0, 2, 5, vec![0, 1, 2, 3, 4, 5]), + with_empty_left_subarray: (vec![1, 2, 3], 0, 0, 2, vec![1, 2, 3]), + with_empty_right_subarray: (vec![1, 2, 3], 0, 2, 2, vec![1, 2, 3]), + with_empty_left_and_right_subarrays: (vec![1, 2, 3], 1, 0, 0, vec![1, 2, 3]), + } + + macro_rules! test_tim_sort { + ($($name:ident: $input:expr,)*) => { + $( + #[test] + fn $name() { + let mut array = $input; + let cloned = array.clone(); + tim_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + )* + } + } + + test_tim_sort! 
{ + sorts_basic_array_correctly: vec![-2, 7, 15, -14, 0, 15, 0, 7, -7, -4, -13, 5, 8, -14, 12], + sorts_long_array_correctly: vec![-2, 7, 15, -14, 0, 15, 0, 7, -7, -4, -13, 5, 8, -14, 12, 5, 3, 9, 22, 1, 1, 2, 3, 9, 6, 5, 4, 5, 6, 7, 8, 9, 1], + handles_empty_array: Vec::::new(), + handles_single_element_array: vec![3], + handles_pre_sorted_array: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + } +} diff --git a/src/sorting/tree_sort.rs b/src/sorting/tree_sort.rs new file mode 100644 index 00000000000..a04067a5835 --- /dev/null +++ b/src/sorting/tree_sort.rs @@ -0,0 +1,114 @@ +// Author : cyrixninja +// Tree Sort Algorithm +// https://en.wikipedia.org/wiki/Tree_sort +// Wikipedia :A tree sort is a sort algorithm that builds a binary search tree from the elements to be sorted, and then traverses the tree (in-order) so that the elements come out in sorted order. +// Its typical use is sorting elements online: after each insertion, the set of elements seen so far is available in sorted order. + +struct TreeNode { + value: T, + left: Option>>, + right: Option>>, +} + +impl TreeNode { + fn new(value: T) -> Self { + TreeNode { + value, + left: None, + right: None, + } + } +} + +struct BinarySearchTree { + root: Option>>, +} + +impl BinarySearchTree { + fn new() -> Self { + BinarySearchTree { root: None } + } + + fn insert(&mut self, value: T) { + self.root = Some(Self::insert_recursive(self.root.take(), value)); + } + + fn insert_recursive(root: Option>>, value: T) -> Box> { + match root { + None => Box::new(TreeNode::new(value)), + Some(mut node) => { + if value <= node.value { + node.left = Some(Self::insert_recursive(node.left.take(), value)); + } else { + node.right = Some(Self::insert_recursive(node.right.take(), value)); + } + node + } + } + } + + fn in_order_traversal(&self, result: &mut Vec) { + Self::in_order_recursive(&self.root, result); + } + + fn in_order_recursive(root: &Option>>, result: &mut Vec) { + if let Some(node) = root { + Self::in_order_recursive(&node.left, 
result); + result.push(node.value.clone()); + Self::in_order_recursive(&node.right, result); + } + } +} + +pub fn tree_sort(arr: &mut Vec) { + let mut tree = BinarySearchTree::new(); + + for elem in arr.iter().cloned() { + tree.insert(elem); + } + + let mut result = Vec::new(); + tree.in_order_traversal(&mut result); + + *arr = result; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_array() { + let mut arr: Vec = vec![]; + tree_sort(&mut arr); + assert_eq!(arr, vec![]); + } + + #[test] + fn test_single_element() { + let mut arr = vec![8]; + tree_sort(&mut arr); + assert_eq!(arr, vec![8]); + } + + #[test] + fn test_already_sorted() { + let mut arr = vec![1, 2, 3, 4, 5]; + tree_sort(&mut arr); + assert_eq!(arr, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_reverse_sorted() { + let mut arr = vec![5, 4, 3, 2, 1]; + tree_sort(&mut arr); + assert_eq!(arr, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_random() { + let mut arr = vec![9, 6, 10, 11, 2, 19]; + tree_sort(&mut arr); + assert_eq!(arr, vec![2, 6, 9, 10, 11, 19]); + } +} diff --git a/src/sorting/wave_sort.rs b/src/sorting/wave_sort.rs new file mode 100644 index 00000000000..06e2e3dec97 --- /dev/null +++ b/src/sorting/wave_sort.rs @@ -0,0 +1,70 @@ +/// Wave Sort Algorithm +/// +/// Wave Sort is a sorting algorithm that works in O(n log n) time assuming +/// the sort function used works in O(n log n) time. +/// It arranges elements in an array into a sequence where every alternate +/// element is either greater or smaller than its adjacent elements. 
+/// +/// Reference: +/// [Wave Sort Algorithm - GeeksforGeeks](https://www.geeksforgeeks.org/sort-array-wave-form-2/) +/// +/// # Examples +/// +/// use the_algorithms_rust::sorting::wave_sort; +/// let array = vec![10, 90, 49, 2, 1, 5, 23]; +/// let result = wave_sort(array); +/// // Result: [2, 1, 10, 5, 49, 23, 90] +/// +pub fn wave_sort(arr: &mut [T]) { + let n = arr.len(); + arr.sort(); + + for i in (0..n - 1).step_by(2) { + arr.swap(i, i + 1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_case_1() { + let mut array = vec![10, 90, 49, 2, 1, 5, 23]; + wave_sort(&mut array); + let expected = vec![2, 1, 10, 5, 49, 23, 90]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_2() { + let mut array = vec![1, 3, 4, 2, 7, 8]; + wave_sort(&mut array); + let expected = vec![2, 1, 4, 3, 8, 7]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_3() { + let mut array = vec![3, 3, 3, 3]; + wave_sort(&mut array); + let expected = vec![3, 3, 3, 3]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_4() { + let mut array = vec![9, 4, 6, 8, 14, 3]; + wave_sort(&mut array); + let expected = vec![4, 3, 8, 6, 14, 9]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_5() { + let mut array = vec![5, 10, 15, 20, 25]; + wave_sort(&mut array); + let expected = vec![10, 5, 20, 15, 25]; + assert_eq!(&array, &expected); + } +} diff --git a/src/sorting/wiggle_sort.rs b/src/sorting/wiggle_sort.rs new file mode 100644 index 00000000000..7f1bf1bf921 --- /dev/null +++ b/src/sorting/wiggle_sort.rs @@ -0,0 +1,80 @@ +//Wiggle Sort. +//Given an unsorted array nums, reorder it such +//that nums[0] < nums[1] > nums[2] < nums[3].... +//For example: +//if input numbers = [3, 5, 2, 1, 6, 4] +//one possible Wiggle Sorted answer is [3, 5, 1, 6, 2, 4]. + +pub fn wiggle_sort(nums: &mut Vec) -> &mut Vec { + //Rust implementation of wiggle. 
/// Reorders `nums` in place so that
/// `nums[0] <= nums[1] >= nums[2] <= nums[3] ...` and returns the same vector.
///
/// A single left-to-right pass swaps each adjacent pair that violates the
/// relation required at its position: odd indices hold a local maximum, even
/// indices a local minimum. The inequalities are strict only where neighbouring
/// values differ, so an input such as `[5, 5, 5, 5]` is returned unchanged.
///
/// # Examples
///
/// ```text
/// wiggle_sort([0, 5, 3, 2, 2]) -> [0, 5, 2, 3, 2]
/// wiggle_sort([])              -> []
/// wiggle_sort([-2, -5, -45])   -> [-5, -2, -45]
/// ```
pub fn wiggle_sort(nums: &mut Vec<i32>) -> &mut Vec<i32> {
    for i in 1..nums.len() {
        let wants_peak = i % 2 == 1;
        let descending = nums[i - 1] > nums[i];
        // At a peak position the pair must not descend; at a valley position it
        // must not ascend. Equal neighbours at a valley are swapped as a no-op.
        if wants_peak == descending {
            nums.swap(i - 1, i);
        }
    }
    nums
}
Corasick in 1975. It is a kind of dictionary-matching algorithm that locates elements of a finite set of strings (the "dictionary") within an input text. It matches all strings simultaneously. + +[aho-corasick-wiki]: https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm + + +### [Burrows-Wheeler transform](./burrows_wheeler_transform.rs) +From [Wikipedia][burrows-wheeler-wiki]: The Burrows–Wheeler transform (BWT, also called block-sorting compression) rearranges a character string into runs of similar characters. This is useful for compression, since it tends to be easy to compress a string that has runs of repeated characters by techniques such as move-to-front transform and run-length encoding. More importantly, the transformation is reversible, without needing to store any additional data except the position of the first original character. The BWT is thus a "free" method of improving the efficiency of text compression algorithms, costing only some extra computation. + +__Properties__ +* Worst-case performance O(n) + +[burrows-wheeler-wiki]: https://en.wikipedia.org/wiki/Burrows%E2%80%93Wheeler_transform + + ### [Knuth Morris Pratt](./knuth_morris_pratt.rs) From [Wikipedia][kmp-wiki]: searches for occurrences of a "word" W within a main "text string" S by employing the observation that when a mismatch occurs, the word itself embodies sufficient information to determine where the next match could begin, thus bypassing re-examination of previously matched characters. Knuth Morris Pratt search runs in linear time in the length of W and S. @@ -9,3 +24,32 @@ __Properties__ * Case space complexity O(w) [kmp-wiki]: https://en.wikipedia.org/wiki/Knuth–Morris–Pratt_algorithm + + + +### [Manacher](./manacher.rs) +From [Wikipedia][manacher-wiki]: find a longest palindrome in a string in linear time.
+ +__Properties__ +* Worst-case time complexity is O(n) +* Worst-case space complexity is O(n) + +[manacher-wiki]: https://en.wikipedia.org/wiki/Longest_palindromic_substring#Manacher's_algorithm + + +### [Rabin Karp](./rabin_karp.rs) +From [Wikipedia][rabin-karp-wiki]: a string-searching algorithm created by Richard M. Karp and Michael O. Rabin that uses hashing +to find an exact match of a pattern string in a text. + +[rabin-karp-wiki]: https://en.wikipedia.org/wiki/Rabin%E2%80%93Karp_algorithm + + +### [Hamming Distance](./hamming_distance.rs) +From [Wikipedia][hamming-distance-wiki]: In information theory, the Hamming distance between two strings of equal length is the number of positions at which the corresponding symbols are different. In other words, it measures the minimum number of substitutions required to change one string into the other, or the minimum number of errors that could have transformed one string into the other. In a more general context, the Hamming distance is one of several string metrics for measuring the edit distance between two sequences. It is named after the American mathematician Richard Hamming. + +[run-length-encoding-wiki]: https://en.wikipedia.org/wiki/Run-length_encoding + +### [Run Length Encoding](./run_length_encoding.rs) +From [Wikipedia][run-length-encoding-wiki]: a form of lossless data compression in which runs of data (sequences in which the same data value occurs in many consecutive data elements) are stored as a single data value and count, rather than as the original run. 
/// One state of the Aho–Corasick automaton: a trie node plus its failure link.
#[derive(Default)]
struct ACNode {
    /// Outgoing trie edges, keyed by the next character.
    trans: BTreeMap<char, Rc<RefCell<ACNode>>>,
    /// The suffix (fail) link. Held weakly so the links do not form `Rc` cycles;
    /// it never upgrades for the root.
    suffix: Weak<RefCell<ACNode>>,
    /// Byte lengths of every pattern that ends at this state, including those
    /// inherited from the state the suffix link points to.
    lengths: Vec<usize>,
}

/// Multi-pattern matcher that finds all dictionary words in a text in one pass.
#[derive(Default)]
pub struct AhoCorasick {
    root: Rc<RefCell<ACNode>>,
}

impl AhoCorasick {
    /// Builds the automaton for the given dictionary.
    pub fn new(words: &[&str]) -> Self {
        let root = Rc::new(RefCell::new(ACNode::default()));
        for word in words {
            let mut cur = Rc::clone(&root);
            for c in word.chars() {
                let next = Rc::clone(cur.borrow_mut().trans.entry(c).or_default());
                cur = next;
            }
            // Byte length, because `search` slices the text by byte offsets.
            cur.borrow_mut().lengths.push(word.len());
        }
        Self::build_suffix(Rc::clone(&root));
        Self { root }
    }

    /// Fills in the suffix links and inherited pattern lengths, breadth-first,
    /// so that a parent's link is always final before its children are visited.
    fn build_suffix(root: Rc<RefCell<ACNode>>) {
        let mut queue = VecDeque::new();
        queue.push_back(Rc::clone(&root));
        while let Some(parent_rc) = queue.pop_front() {
            let parent = parent_rc.borrow();
            for (c, child_rc) in &parent.trans {
                queue.push_back(Rc::clone(child_rc));

                // Follow the parent's suffix chain until some state has an edge
                // labelled `c`; fall back to the root when the chain runs out.
                let mut candidate = parent.suffix.upgrade();
                let target = loop {
                    match candidate.take() {
                        None => break Rc::clone(&root),
                        Some(node) => {
                            let hit = node.borrow().trans.get(c).cloned();
                            if let Some(found) = hit {
                                break found;
                            }
                            candidate = node.borrow().suffix.upgrade();
                        }
                    }
                };

                // Copy first, then borrow the child mutably, so the two borrows
                // never overlap.
                let inherited = target.borrow().lengths.clone();
                let mut child = child_rc.borrow_mut();
                child.lengths.extend(inherited);
                child.suffix = Rc::downgrade(&target);
            }
        }
    }

    /// Returns every occurrence of a dictionary word in `s`, ordered by the
    /// position at which the match ends. The returned slices borrow from `s`.
    pub fn search<'a>(&self, s: &'a str) -> Vec<&'a str> {
        let mut ans = vec![];
        let mut cur = Rc::clone(&self.root);
        // Byte offset just past the character consumed so far.
        let mut position: usize = 0;
        for c in s.chars() {
            loop {
                let next = cur.borrow().trans.get(&c).cloned();
                if let Some(child) = next {
                    cur = child;
                    break;
                }
                let fallback = cur.borrow().suffix.upgrade();
                match fallback {
                    Some(node) => cur = node,
                    // Only the root has no suffix link: `c` starts no pattern.
                    None => break,
                }
            }
            position += c.len_utf8();
            ans.extend(
                cur.borrow()
                    .lengths
                    .iter()
                    .map(|&len| &s[position - len..position]),
            );
        }
        ans
    }
}
+pub fn check_anagram(s: &str, t: &str) -> Result { + let s_cleaned = clean_string(s)?; + let t_cleaned = clean_string(t)?; + + Ok(char_count(&s_cleaned) == char_count(&t_cleaned)) +} + +/// Cleans the input string by removing spaces and converting to lowercase. +/// Returns an error if any non-alphabetic character is found. +/// +/// # Arguments +/// +/// * `s` - Input string to clean. +/// +/// # Returns +/// +/// * `Ok(String)` containing the cleaned string (no spaces, lowercase). +/// * `Err(AnagramError)` if the string contains non-alphabetic characters. +fn clean_string(s: &str) -> Result { + s.chars() + .filter(|c| !c.is_whitespace()) + .map(|c| { + if c.is_alphabetic() { + Ok(c.to_ascii_lowercase()) + } else { + Err(AnagramError::NonAlphabeticCharacter) + } + }) + .collect() +} + +/// Computes the histogram of characters in a string. +/// +/// # Arguments +/// +/// * `s` - Input string. +/// +/// # Returns +/// +/// * A `HashMap` where the keys are characters and values are their count. +fn char_count(s: &str) -> HashMap { + let mut res = HashMap::new(); + for c in s.chars() { + *res.entry(c).or_insert(0) += 1; + } + res +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $test_case; + assert_eq!(check_anagram(s, t), expected); + assert_eq!(check_anagram(t, s), expected); + } + )* + } + } + + test_cases! 
{ + empty_strings: ("", "", Ok(true)), + empty_and_non_empty: ("", "Ted Morgan", Ok(false)), + single_char_same: ("z", "Z", Ok(true)), + single_char_diff: ("g", "h", Ok(false)), + valid_anagram_lowercase: ("cheater", "teacher", Ok(true)), + valid_anagram_with_spaces: ("madam curie", "radium came", Ok(true)), + valid_anagram_mixed_cases: ("Satan", "Santa", Ok(true)), + valid_anagram_with_spaces_and_mixed_cases: ("Anna Madrigal", "A man and a girl", Ok(true)), + new_york_times: ("New York Times", "monkeys write", Ok(true)), + church_of_scientology: ("Church of Scientology", "rich chosen goofy cult", Ok(true)), + mcdonalds_restaurants: ("McDonald's restaurants", "Uncle Sam's standard rot", Err(AnagramError::NonAlphabeticCharacter)), + coronavirus: ("coronavirus", "carnivorous", Ok(true)), + synonym_evil: ("evil", "vile", Ok(true)), + synonym_gentleman: ("a gentleman", "elegant man", Ok(true)), + antigram: ("restful", "fluster", Ok(true)), + sentences: ("William Shakespeare", "I am a weakish speller", Ok(true)), + part_of_speech_adj_to_verb: ("silent", "listen", Ok(true)), + anagrammatized: ("Anagrams", "Ars magna", Ok(true)), + non_anagram: ("rat", "car", Ok(false)), + invalid_anagram_with_special_char: ("hello!", "world", Err(AnagramError::NonAlphabeticCharacter)), + invalid_anagram_with_numeric_chars: ("test123", "321test", Err(AnagramError::NonAlphabeticCharacter)), + invalid_anagram_with_symbols: ("check@anagram", "check@nagaram", Err(AnagramError::NonAlphabeticCharacter)), + non_anagram_length_mismatch: ("abc", "abcd", Ok(false)), + } +} diff --git a/src/string/autocomplete_using_trie.rs b/src/string/autocomplete_using_trie.rs new file mode 100644 index 00000000000..630b6e1dd79 --- /dev/null +++ b/src/string/autocomplete_using_trie.rs @@ -0,0 +1,125 @@ +/* + It autocomplete by prefix using added words. 
/// A prefix tree over `char`s.
///
/// The end of a word is recorded with an explicit flag rather than a reserved
/// sentinel character, so every character (including `'#'`) may appear inside
/// a stored word.
#[derive(Debug, Default)]
struct Trie {
    children: HashMap<char, Trie>,
    is_word: bool,
}

impl Trie {
    fn new() -> Self {
        Self::default()
    }

    /// Adds `text` to the set of stored words.
    fn insert(&mut self, text: &str) {
        let mut node = self;
        for c in text.chars() {
            node = node.children.entry(c).or_default();
        }
        node.is_word = true;
    }

    /// Returns every stored word that starts with `prefix`, in unspecified order.
    fn find(&self, prefix: &str) -> Vec<String> {
        let mut node = self;
        for c in prefix.chars() {
            match node.children.get(&c) {
                Some(child) => node = child,
                None => return vec![],
            }
        }

        let mut results = Vec::new();
        let mut current = prefix.to_owned();
        node.collect_words(&mut current, &mut results);
        results
    }

    /// Depth-first walk that appends each complete word below `self` to `results`.
    /// `current` holds the characters on the path from the root to `self`.
    fn collect_words(&self, current: &mut String, results: &mut Vec<String>) {
        if self.is_word {
            results.push(current.clone());
        }
        for (&c, child) in &self.children {
            current.push(c);
            child.collect_words(current, results);
            current.pop();
        }
    }
}

/// Prefix-based autocompletion over a set of inserted words.
pub struct Autocomplete {
    trie: Trie,
}

impl Autocomplete {
    fn new() -> Self {
        Self { trie: Trie::new() }
    }

    /// Adds all `words` to the dictionary.
    pub fn insert_words<T: AsRef<str>>(&mut self, words: &[T]) {
        for word in words {
            self.trie.insert(word.as_ref());
        }
    }

    /// Returns every inserted word starting with `prefix`, in unspecified order.
    pub fn find_words(&self, prefix: &str) -> Vec<String> {
        self.trie.find(prefix)
    }
}

impl Default for Autocomplete {
    fn default() -> Self {
        Self::new()
    }
}
/// Builds the bad character table for the Boyer-Moore algorithm.
///
/// # Arguments
/// * `pat` - The pattern as a slice of characters.
///
/// # Returns
/// A map from each character of the pattern to the index of its last
/// occurrence within the pattern.
fn build_bad_char_table(pat: &[char]) -> HashMap<char, isize> {
    // Later entries overwrite earlier ones, which leaves the last occurrence.
    pat.iter()
        .enumerate()
        .map(|(i, &ch)| (ch, i as isize))
        .collect()
}

/// Calculates how far to shift the pattern after a full match.
///
/// # Arguments
/// * `shift` - The current shift of the pattern on the text.
/// * `pat_len` - The length of the pattern, in characters.
/// * `text_len` - The length of the text, in characters.
/// * `bad_char_table` - The bad character table built for the pattern.
/// * `text` - The text as a slice of characters.
///
/// # Returns
/// The number of positions to shift the pattern; always at least 1.
fn calc_match_shift(
    shift: isize,
    pat_len: isize,
    text_len: isize,
    bad_char_table: &HashMap<char, isize>,
    text: &[char],
) -> isize {
    // The pattern already touches the end of the text: nothing to look ahead at.
    if shift + pat_len >= text_len {
        return 1;
    }
    // Align the character just past the match with its last occurrence.
    let next_ch = text[(shift + pat_len) as usize];
    pat_len - bad_char_table.get(&next_ch).copied().unwrap_or(-1)
}

/// Calculates how far to shift the pattern after a mismatch (bad character rule).
///
/// # Arguments
/// * `mis_idx` - The mismatch index in the pattern.
/// * `shift` - The current shift of the pattern on the text.
/// * `text` - The text as a slice of characters.
/// * `bad_char_table` - The bad character table built for the pattern.
///
/// # Returns
/// The number of positions to shift the pattern; always at least 1.
fn calc_mismatch_shift(
    mis_idx: isize,
    shift: isize,
    text: &[char],
    bad_char_table: &HashMap<char, isize>,
) -> isize {
    let mis_ch = text[(shift + mis_idx) as usize];
    let last_seen = bad_char_table.get(&mis_ch).copied().unwrap_or(-1);
    std::cmp::max(1, mis_idx - last_seen)
}

/// Finds all (possibly overlapping) occurrences of `pat` in `text` using the
/// Boyer-Moore bad character rule.
///
/// # Arguments
/// * `text` - The text to search within.
/// * `pat` - The pattern to search for.
///
/// # Returns
/// The starting positions of every occurrence, in ascending order, counted in
/// characters (`char`s) from the start of `text`. The result is empty when
/// either input is empty or the pattern is longer than the text.
pub fn boyer_moore_search(text: &str, pat: &str) -> Vec<usize> {
    let mut positions = Vec::new();

    // Work on characters so that indices stay valid for non-ASCII input. The
    // lengths must be taken from these vectors, not from the byte lengths of
    // the original strings.
    let pat: Vec<char> = pat.chars().collect();
    let text: Vec<char> = text.chars().collect();
    let text_len = text.len() as isize;
    let pat_len = pat.len() as isize;

    if text_len == 0 || pat_len == 0 || pat_len > text_len {
        return positions;
    }

    let bad_char_table = build_bad_char_table(&pat);

    let mut shift = 0;
    while shift <= text_len - pat_len {
        // Compare the pattern against the text from right to left.
        let mut j = pat_len - 1;
        while j >= 0 && pat[j as usize] == text[(shift + j) as usize] {
            j -= 1;
        }

        if j < 0 {
            // Every character matched.
            positions.push(shift as usize);
            shift += calc_match_shift(shift, pat_len, text_len, &bad_char_table, &text);
        } else {
            shift += calc_mismatch_shift(j, shift, &text, &bad_char_table);
        }
    }

    positions
}
/// Iterates over the cyclic rotation of `chars` that starts at index `start`.
fn rotation(chars: &[char], start: usize) -> impl Iterator<Item = &char> {
    chars[start..].iter().chain(&chars[..start])
}

/// Computes the Burrows–Wheeler transform of `input`.
///
/// All cyclic rotations of the input are sorted by character (code point)
/// order. The result is the last character of each sorted rotation, together
/// with the row at which the original string appears.
///
/// The ordering is case-sensitive so that it agrees with the ordering used by
/// [`inv_burrows_wheeler_transform`]; the two must match for the transform to
/// be reversible. Rotations are taken over `char`s, so non-ASCII input is
/// handled.
///
/// # Returns
///
/// `(encoded, index)`; both are needed to invert the transform. For an empty
/// input this is `("", 0)`.
pub fn burrows_wheeler_transform(input: &str) -> (String, usize) {
    let chars: Vec<char> = input.chars().collect();
    let len = chars.len();

    // Sort rotation start offsets instead of materialising every rotation.
    let mut starts: Vec<usize> = (0..len).collect();
    starts.sort_by(|&a, &b| rotation(&chars, a).cmp(rotation(&chars, b)));

    let mut encoded = String::with_capacity(input.len());
    let mut index: usize = 0;
    for (row, &start) in starts.iter().enumerate() {
        // The last character of a rotation is the one cyclically before its start.
        encoded.push(chars[(start + len - 1) % len]);
        // For periodic inputs several rows equal the input; keep the last one.
        if rotation(&chars, start).eq(chars.iter()) {
            index = row;
        }
    }

    (encoded, index)
}

/// Inverts [`burrows_wheeler_transform`].
///
/// # Arguments
///
/// * `input` - The encoded string and the row index returned by the transform.
///
/// # Panics
///
/// Panics if the encoded string is non-empty and the index is not smaller than
/// its length in characters.
pub fn inv_burrows_wheeler_transform<T: AsRef<str>>(input: (T, usize)) -> String {
    let encoded: &str = input.0.as_ref();

    // Pair every character with its position in the encoded string, then sort
    // stably by character. Entry `i` is then the first character of sorted
    // rotation `i`, and its stored position is the row holding the next rotation.
    let mut table: Vec<(usize, char)> = encoded.chars().enumerate().collect();
    table.sort_by_key(|&(_, ch)| ch);

    let mut decoded = String::with_capacity(encoded.len());
    let mut idx = input.1;
    for _ in 0..table.len() {
        let (next, ch) = table[idx];
        decoded.push(ch);
        idx = next;
    }

    decoded
}
/// Factorizes a string into its Lyndon words using Duval's algorithm.
///
/// The string is converted to a vector of characters first, so the
/// factorization works on `char`s and not on bytes.
///
/// # Arguments
///
/// * `s` - A string slice representing the input text.
///
/// # Returns
///
/// A vector of strings, each one a Lyndon word; concatenated in order they
/// give back the input. An empty input yields an empty vector.
pub fn duval_algorithm(s: &str) -> Vec<String> {
    factorize_duval(&s.chars().collect::<Vec<char>>())
}

/// Performs Duval's factorization on a slice of characters.
///
/// # Arguments
///
/// * `s` - A slice of characters representing the input string.
///
/// # Returns
///
/// A vector of strings, each representing a Lyndon word in the factorization.
fn factorize_duval(s: &[char]) -> Vec<String> {
    let mut start = 0;
    let mut factors: Vec<String> = Vec::new();

    while start < s.len() {
        // `end` scans forward; `repeat` trails it, comparing the scanned text
        // against the beginning of the current candidate word.
        let mut end = start + 1;
        let mut repeat = start;

        while end < s.len() && s[repeat] <= s[end] {
            if s[repeat] < s[end] {
                // A strictly larger character: everything scanned so far is
                // one longer Lyndon word, so compare from its start again.
                repeat = start;
            } else {
                repeat += 1;
            }
            end += 1;
        }

        // The scanned prefix is a repetition of a word of this length,
        // possibly followed by a proper prefix of that word. Emit the repeats.
        let factor_len = end - repeat;
        while start <= repeat {
            factors.push(s[start..start + factor_len].iter().collect());
            start += factor_len;
        }
    }

    factors
}
test_duval_algorithm { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (text, expected) = $inputs; + assert_eq!(duval_algorithm(text), expected); + } + )* + } + } + + test_duval_algorithm! { + repeating_with_suffix: ("abcdabcdababc", ["abcd", "abcd", "ababc"]), + single_repeating_char: ("aaa", ["a", "a", "a"]), + single: ("ababb", ["ababb"]), + unicode: ("അഅഅ", ["അ", "അ", "അ"]), + empty_string: ("", Vec::::new()), + single_char: ("x", ["x"]), + palindrome: ("racecar", ["r", "acecar"]), + long_repeating: ("aaaaaa", ["a", "a", "a", "a", "a", "a"]), + mixed_repeating: ("ababcbabc", ["ababcbabc"]), + non_repeating_sorted: ("abcdefg", ["abcdefg"]), + alternating_increasing: ("abababab", ["ab", "ab", "ab", "ab"]), + long_repeating_lyndon: ("abcabcabcabc", ["abc", "abc", "abc", "abc"]), + decreasing_order: ("zyxwvutsrqponm", ["z", "y", "x", "w", "v", "u", "t", "s", "r", "q", "p", "o", "n", "m"]), + alphanumeric_mixed: ("a1b2c3a1", ["a", "1b2c3a", "1"]), + special_characters: ("a@b#c$d", ["a", "@b", "#c$d"]), + unicode_complex: ("αβγδ", ["αβγδ"]), + long_string_performance: (&"a".repeat(1_000_000), vec!["a"; 1_000_000]), + palindrome_repeating_prefix: ("abccba", ["abccb", "a"]), + interrupted_lyndon: ("abcxabc", ["abcx", "abc"]), + } +} diff --git a/src/string/hamming_distance.rs b/src/string/hamming_distance.rs new file mode 100644 index 00000000000..3137d5abc7c --- /dev/null +++ b/src/string/hamming_distance.rs @@ -0,0 +1,57 @@ +/// Error type for Hamming distance calculation. +#[derive(Debug, PartialEq)] +pub enum HammingDistanceError { + InputStringsHaveDifferentLength, +} + +/// Calculates the Hamming distance between two strings. +/// +/// The Hamming distance is defined as the number of positions at which the corresponding characters of the two strings are different. 
+pub fn hamming_distance(string_a: &str, string_b: &str) -> Result { + if string_a.len() != string_b.len() { + return Err(HammingDistanceError::InputStringsHaveDifferentLength); + } + + let distance = string_a + .chars() + .zip(string_b.chars()) + .filter(|(a, b)| a != b) + .count(); + + Ok(distance) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_hamming_distance { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (str_a, str_b, expected) = $tc; + assert_eq!(hamming_distance(str_a, str_b), expected); + assert_eq!(hamming_distance(str_b, str_a), expected); + } + )* + } + } + + test_hamming_distance! { + empty_inputs: ("", "", Ok(0)), + different_length: ("0", "", Err(HammingDistanceError::InputStringsHaveDifferentLength)), + length_1_inputs_identical: ("a", "a", Ok(0)), + length_1_inputs_different: ("a", "b", Ok(1)), + same_strings: ("rust", "rust", Ok(0)), + regular_input_0: ("karolin", "kathrin", Ok(3)), + regular_input_1: ("kathrin", "kerstin", Ok(4)), + regular_input_2: ("00000", "11111", Ok(5)), + different_case: ("x", "X", Ok(1)), + strings_with_no_common_chars: ("abcd", "wxyz", Ok(4)), + long_strings_one_diff: (&"a".repeat(1000), &("a".repeat(999) + "b"), Ok(1)), + long_strings_many_diffs: (&("a".repeat(500) + &"b".repeat(500)), &("b".repeat(500) + &"a".repeat(500)), Ok(1000)), + strings_with_special_chars_identical: ("!@#$%^", "!@#$%^", Ok(0)), + strings_with_special_chars_diff: ("!@#$%^", "&*()_+", Ok(6)), + } +} diff --git a/src/string/isogram.rs b/src/string/isogram.rs new file mode 100644 index 00000000000..30b8d66bdff --- /dev/null +++ b/src/string/isogram.rs @@ -0,0 +1,104 @@ +//! This module provides functionality to check if a given string is an isogram. +//! An isogram is a word or phrase in which no letter occurs more than once. + +use std::collections::HashMap; + +/// Enum representing possible errors that can occur while checking for isograms. 
/// Enum representing possible errors that can occur while checking for isograms.
#[derive(Debug, PartialEq, Eq)]
pub enum IsogramError {
    /// Indicates that the input contains a non-alphabetic character.
    NonAlphabeticCharacter,
}

/// Counts the occurrences of each ASCII letter in a given string.
///
/// # Arguments
///
/// * `s` - A string slice that contains the input to count characters from.
///
/// # Errors
///
/// Returns `IsogramError::NonAlphabeticCharacter` if the input contains any
/// character that is neither an ASCII letter nor whitespace. This includes
/// non-ASCII letters.
///
/// # Note
///
/// Uppercase and lowercase letters are counted together (case-insensitive).
/// Whitespace is ignored and does not affect the counts.
fn count_letters(s: &str) -> Result<HashMap<char, usize>, IsogramError> {
    let mut letter_counts = HashMap::new();

    for ch in s.chars() {
        if ch.is_ascii_alphabetic() {
            *letter_counts.entry(ch.to_ascii_lowercase()).or_insert(0) += 1;
        } else if !ch.is_whitespace() {
            return Err(IsogramError::NonAlphabeticCharacter);
        }
    }

    Ok(letter_counts)
}

/// Checks if the given input string is an isogram, i.e. no letter occurs more
/// than once (ignoring case and whitespace).
///
/// # Arguments
///
/// * `s` - A string slice that contains the input to check.
///
/// # Return
///
/// - `Ok(true)` if every letter appears exactly once (also for input without
///   any letters), or `Ok(false)` if any letter appears more than once.
/// - `Err(IsogramError::NonAlphabeticCharacter)` if the input contains any
///   character other than ASCII letters and whitespace. The whole input is
///   validated, so this error takes precedence over a repeated letter.
pub fn is_isogram(s: &str) -> Result<bool, IsogramError> {
    let letter_counts = count_letters(s)?;
    Ok(letter_counts.values().all(|&count| count == 1))
}
isogram_tests { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $tc; + assert_eq!(is_isogram(input), expected); + } + )* + }; + } + + isogram_tests! { + isogram_simple: ("isogram", Ok(true)), + isogram_case_insensitive: ("Isogram", Ok(true)), + isogram_with_spaces: ("a b c d e", Ok(true)), + isogram_mixed: ("Dermatoglyphics", Ok(true)), + isogram_long: ("Subdermatoglyphic", Ok(true)), + isogram_german_city: ("Malitzschkendorf", Ok(true)), + perfect_pangram: ("Cwm fjord bank glyphs vext quiz", Ok(true)), + isogram_sentences: ("The big dwarf only jumps", Ok(true)), + isogram_french: ("Lampez un fort whisky", Ok(true)), + isogram_portuguese: ("Velho traduz sim", Ok(true)), + isogram_spanis: ("Centrifugadlos", Ok(true)), + invalid_isogram_with_repeated_char: ("hello", Ok(false)), + invalid_isogram_with_numbers: ("abc123", Err(IsogramError::NonAlphabeticCharacter)), + invalid_isogram_with_special_char: ("abc!", Err(IsogramError::NonAlphabeticCharacter)), + invalid_isogram_with_comma: ("Velho, traduz sim", Err(IsogramError::NonAlphabeticCharacter)), + invalid_isogram_with_spaces: ("a b c d a", Ok(false)), + invalid_isogram_with_repeated_phrase: ("abcabc", Ok(false)), + isogram_empty_string: ("", Ok(true)), + isogram_single_character: ("a", Ok(true)), + invalid_isogram_multiple_same_characters: ("aaaa", Ok(false)), + invalid_isogram_with_symbols: ("abc@#$%", Err(IsogramError::NonAlphabeticCharacter)), + } +} diff --git a/src/string/isomorphism.rs b/src/string/isomorphism.rs new file mode 100644 index 00000000000..8583ece1e1c --- /dev/null +++ b/src/string/isomorphism.rs @@ -0,0 +1,83 @@ +//! This module provides functionality to determine whether two strings are isomorphic. +//! +//! Two strings are considered isomorphic if the characters in one string can be replaced +//! by some mapping relation to obtain the other string. +use std::collections::HashMap; + +/// Determines whether two strings are isomorphic. 
+/// +/// # Arguments +/// +/// * `s` - The first string. +/// * `t` - The second string. +/// +/// # Returns +/// +/// `true` if the strings are isomorphic, `false` otherwise. +pub fn is_isomorphic(s: &str, t: &str) -> bool { + let s_chars: Vec = s.chars().collect(); + let t_chars: Vec = t.chars().collect(); + if s_chars.len() != t_chars.len() { + return false; + } + let mut s_to_t_map = HashMap::new(); + let mut t_to_s_map = HashMap::new(); + for (s_char, t_char) in s_chars.into_iter().zip(t_chars) { + if !check_mapping(&mut s_to_t_map, s_char, t_char) + || !check_mapping(&mut t_to_s_map, t_char, s_char) + { + return false; + } + } + true +} + +/// Checks the mapping between two characters and updates the map. +/// +/// # Arguments +/// +/// * `map` - The HashMap to store the mapping. +/// * `key` - The key character. +/// * `value` - The value character. +/// +/// # Returns +/// +/// `true` if the mapping is consistent, `false` otherwise. +fn check_mapping(map: &mut HashMap, key: char, value: char) -> bool { + match map.get(&key) { + Some(&mapped_char) => mapped_char == value, + None => { + map.insert(key, value); + true + } + } +} + +#[cfg(test)] +mod tests { + use super::is_isomorphic; + macro_rules! test_is_isomorphic { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $inputs; + assert_eq!(is_isomorphic(s, t), expected); + assert_eq!(is_isomorphic(t, s), expected); + assert!(is_isomorphic(s, s)); + assert!(is_isomorphic(t, t)); + } + )* + } + } + test_is_isomorphic! 
{ + isomorphic: ("egg", "add", true), + isomorphic_long: ("abcdaabdcdbbabababacdadad", "AbCdAAbdCdbbAbAbAbACdAdAd", true), + not_isomorphic: ("egg", "adc", false), + non_isomorphic_long: ("abcdaabdcdbbabababacdadad", "AACdAAbdCdbbAbAbAbACdAdAd", false), + isomorphic_unicode: ("天苍苍", "野茫茫", true), + isomorphic_unicode_different_byte_size: ("abb", "野茫茫", true), + empty: ("", "", true), + different_length: ("abc", "abcd", false), + } +} diff --git a/src/string/jaro_winkler_distance.rs b/src/string/jaro_winkler_distance.rs new file mode 100644 index 00000000000..e00e526e676 --- /dev/null +++ b/src/string/jaro_winkler_distance.rs @@ -0,0 +1,84 @@ +// In computer science and statistics, +// the Jaro–Winkler distance is a string metric measuring an edit distance +// between two sequences. +// It is a variant proposed in 1990 by William E. Winkler +// of the Jaro distance metric (1989, Matthew A. Jaro). + +pub fn jaro_winkler_distance(str1: &str, str2: &str) -> f64 { + if str1.is_empty() || str2.is_empty() { + return 0.0; + } + fn get_matched_characters(s1: &str, s2: &str) -> String { + let mut s2 = s2.to_string(); + let mut matched: Vec = Vec::new(); + let limit = std::cmp::min(s1.len(), s2.len()) / 2; + for (i, l) in s1.chars().enumerate() { + let left = std::cmp::max(0, i as i32 - limit as i32) as usize; + let right = std::cmp::min(i + limit + 1, s2.len()); + if s2[left..right].contains(l) { + matched.push(l); + let a = &s2[0..s2.find(l).expect("this exists")]; + let b = &s2[(s2.find(l).expect("this exists") + 1)..]; + s2 = format!("{a} {b}"); + } + } + matched.iter().collect::() + } + + let matching_1 = get_matched_characters(str1, str2); + let matching_2 = get_matched_characters(str2, str1); + let match_count = matching_1.len(); + + // transposition + let transpositions = { + let mut count = 0; + for (c1, c2) in matching_1.chars().zip(matching_2.chars()) { + if c1 != c2 { + count += 1; + } + } + count / 2 + }; + + let jaro: f64 = { + if match_count == 0 { + return 
0.0; + } + (1_f64 / 3_f64) + * (match_count as f64 / str1.len() as f64 + + match_count as f64 / str2.len() as f64 + + (match_count - transpositions) as f64 / match_count as f64) + }; + + let mut prefix_len = 0.0; + let bound = std::cmp::min(std::cmp::min(str1.len(), str2.len()), 4); + for (c1, c2) in str1[..bound].chars().zip(str2[..bound].chars()) { + if c1 == c2 { + prefix_len += 1.0; + } else { + break; + } + } + jaro + (0.1 * prefix_len * (1.0 - jaro)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_jaro_winkler_distance() { + let a = jaro_winkler_distance("hello", "world"); + assert_eq!(a, 0.4666666666666666); + let a = jaro_winkler_distance("martha", "marhta"); + assert_eq!(a, 0.9611111111111111); + let a = jaro_winkler_distance("martha", "marhat"); + assert_eq!(a, 0.9611111111111111); + let a = jaro_winkler_distance("test", "test"); + assert_eq!(a, 1.0); + let a = jaro_winkler_distance("test", ""); + assert_eq!(a, 0.0); + let a = jaro_winkler_distance("hello world", "HeLLo W0rlD"); + assert_eq!(a, 0.6363636363636364); + } +} diff --git a/src/string/knuth_morris_pratt.rs b/src/string/knuth_morris_pratt.rs index 1bd183ab3f1..ec3fba0c9f1 100644 --- a/src/string/knuth_morris_pratt.rs +++ b/src/string/knuth_morris_pratt.rs @@ -1,97 +1,146 @@ -pub fn knuth_morris_pratt(st: String, pat: String) -> Vec { - if st.is_empty() || pat.is_empty() { +//! Knuth-Morris-Pratt string matching algorithm implementation in Rust. +//! +//! This module contains the implementation of the KMP algorithm, which is used for finding +//! occurrences of a pattern string within a text string efficiently. The algorithm preprocesses +//! the pattern to create a partial match table, which allows for efficient searching. + +/// Finds all occurrences of the pattern in the given string using the Knuth-Morris-Pratt algorithm. +/// +/// # Arguments +/// +/// * `string` - The string to search within. +/// * `pattern` - The pattern string to search for. 
+/// +/// # Returns +/// +/// A vector of starting indices where the pattern is found in the string. If the pattern or the +/// string is empty, an empty vector is returned. +pub fn knuth_morris_pratt(string: &str, pattern: &str) -> Vec { + if string.is_empty() || pattern.is_empty() { return vec![]; } - let string = st.into_bytes(); - let pattern = pat.into_bytes(); + let text_chars = string.chars().collect::>(); + let pattern_chars = pattern.chars().collect::>(); + let partial_match_table = build_partial_match_table(&pattern_chars); + find_pattern(&text_chars, &pattern_chars, &partial_match_table) +} - // build the partial match table - let mut partial = vec![0]; - for i in 1..pattern.len() { - let mut j = partial[i - 1]; - while j > 0 && pattern[j] != pattern[i] { - j = partial[j - 1]; - } - partial.push(if pattern[j] == pattern[i] { j + 1 } else { j }); - } +/// Builds the partial match table (also known as "prefix table") for the given pattern. +/// +/// The partial match table is used to skip characters while matching the pattern in the text. +/// Each entry at index `i` in the table indicates the length of the longest proper prefix of +/// the substring `pattern[0..i]` which is also a suffix of this substring. +/// +/// # Arguments +/// +/// * `pattern_chars` - The pattern string as a slice of characters. +/// +/// # Returns +/// +/// A vector representing the partial match table. 
+fn build_partial_match_table(pattern_chars: &[char]) -> Vec { + let mut partial_match_table = vec![0]; + pattern_chars + .iter() + .enumerate() + .skip(1) + .for_each(|(index, &char)| { + let mut length = partial_match_table[index - 1]; + while length > 0 && pattern_chars[length] != char { + length = partial_match_table[length - 1]; + } + partial_match_table.push(if pattern_chars[length] == char { + length + 1 + } else { + length + }); + }); + partial_match_table +} - // and read 'string' to find 'pattern' - let mut ret = vec![]; - let mut j = 0; +/// Finds all occurrences of the pattern in the given string using the precomputed partial match table. +/// +/// This function iterates through the string and uses the partial match table to efficiently find +/// all starting indices of the pattern in the string. +/// +/// # Arguments +/// +/// * `text_chars` - The string to search within as a slice of characters. +/// * `pattern_chars` - The pattern string to search for as a slice of characters. +/// * `partial_match_table` - The precomputed partial match table for the pattern. +/// +/// # Returns +/// +/// A vector of starting indices where the pattern is found in the string. 
+fn find_pattern( + text_chars: &[char], + pattern_chars: &[char], + partial_match_table: &[usize], +) -> Vec { + let mut result_indices = vec![]; + let mut match_length = 0; - for (i, &c) in string.iter().enumerate() { - while j > 0 && c != pattern[j] { - j = partial[j - 1]; - } - if c == pattern[j] { - j += 1; - } - if j == pattern.len() { - ret.push(i + 1 - j); - j = partial[j - 1]; - } - } + text_chars + .iter() + .enumerate() + .for_each(|(text_index, &text_char)| { + while match_length > 0 && text_char != pattern_chars[match_length] { + match_length = partial_match_table[match_length - 1]; + } + if text_char == pattern_chars[match_length] { + match_length += 1; + } + if match_length == pattern_chars.len() { + result_indices.push(text_index + 1 - match_length); + match_length = partial_match_table[match_length - 1]; + } + }); - ret + result_indices } #[cfg(test)] mod tests { use super::*; - #[test] - fn each_letter_matches() { - let index = knuth_morris_pratt("aaa".to_string(), "a".to_string()); - assert_eq!(index, vec![0, 1, 2]); - } - - #[test] - fn a_few_separate_matches() { - let index = knuth_morris_pratt("abababa".to_string(), "ab".to_string()); - assert_eq!(index, vec![0, 2, 4]); - } - - #[test] - fn one_match() { - let index = - knuth_morris_pratt("ABC ABCDAB ABCDABCDABDE".to_string(), "ABCDABD".to_string()); - assert_eq!(index, vec![15]); - } - - #[test] - fn lots_of_matches() { - let index = knuth_morris_pratt("aaabaabaaaaa".to_string(), "aa".to_string()); - assert_eq!(index, vec![0, 1, 4, 7, 8, 9, 10]); - } - - #[test] - fn lots_of_intricate_matches() { - let index = knuth_morris_pratt("ababababa".to_string(), "aba".to_string()); - assert_eq!(index, vec![0, 2, 4, 6]); - } - - #[test] - fn not_found0() { - let index = knuth_morris_pratt("abcde".to_string(), "f".to_string()); - assert_eq!(index, vec![]); - } - - #[test] - fn not_found1() { - let index = knuth_morris_pratt("abcde".to_string(), "ac".to_string()); - assert_eq!(index, vec![]); - } - - 
#[test] - fn not_found2() { - let index = knuth_morris_pratt("ababab".to_string(), "bababa".to_string()); - assert_eq!(index, vec![]); + macro_rules! test_knuth_morris_pratt { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, pattern, expected) = $inputs; + assert_eq!(knuth_morris_pratt(input, pattern), expected); + } + )* + } } - #[test] - fn empty_string() { - let index = knuth_morris_pratt("".to_string(), "abcdef".to_string()); - assert_eq!(index, vec![]); + test_knuth_morris_pratt! { + each_letter_matches: ("aaa", "a", vec![0, 1, 2]), + a_few_seperate_matches: ("abababa", "ab", vec![0, 2, 4]), + unicode: ("അഅഅ", "അ", vec![0, 1, 2]), + unicode_no_match_but_similar_bytes: ( + &String::from_utf8(vec![224, 180, 133]).unwrap(), + &String::from_utf8(vec![224, 180, 132]).unwrap(), + vec![] + ), + one_match: ("ABC ABCDAB ABCDABCDABDE", "ABCDABD", vec![15]), + lots_of_matches: ("aaabaabaaaaa", "aa", vec![0, 1, 4, 7, 8, 9, 10]), + lots_of_intricate_matches: ("ababababa", "aba", vec![0, 2, 4, 6]), + not_found0: ("abcde", "f", vec![]), + not_found1: ("abcde", "ac", vec![]), + not_found2: ("ababab", "bababa", vec![]), + empty_string: ("", "abcdef", vec![]), + empty_pattern: ("abcdef", "", vec![]), + single_character_string: ("a", "a", vec![0]), + single_character_pattern: ("abcdef", "d", vec![3]), + pattern_at_start: ("abcdef", "abc", vec![0]), + pattern_at_end: ("abcdef", "def", vec![3]), + pattern_in_middle: ("abcdef", "cd", vec![2]), + no_match_with_repeated_characters: ("aaaaaa", "b", vec![]), + pattern_longer_than_string: ("abc", "abcd", vec![]), + very_long_string: (&"a".repeat(10000), "a", (0..10000).collect::>()), + very_long_pattern: (&"a".repeat(10000), &"a".repeat(9999), (0..2).collect::>()), } } diff --git a/src/string/levenshtein_distance.rs b/src/string/levenshtein_distance.rs new file mode 100644 index 00000000000..1a1ccefaee4 --- /dev/null +++ b/src/string/levenshtein_distance.rs @@ -0,0 +1,164 @@ +//! 
Provides functions to calculate the Levenshtein distance between two strings. +//! +//! The Levenshtein distance is a measure of the similarity between two strings by calculating the minimum number of single-character +//! edits (insertions, deletions, or substitutions) required to change one string into the other. + +use std::cmp::min; + +/// Calculates the Levenshtein distance between two strings using a naive dynamic programming approach. +/// +/// The Levenshtein distance is a measure of the similarity between two strings by calculating the minimum number of single-character +/// edits (insertions, deletions, or substitutions) required to change one string into the other. +/// +/// # Arguments +/// +/// * `string1` - A reference to the first string. +/// * `string2` - A reference to the second string. +/// +/// # Returns +/// +/// The Levenshtein distance between the two input strings. +/// +/// This function computes the Levenshtein distance by constructing a dynamic programming matrix and iteratively filling it in. +/// It follows the standard top-to-bottom, left-to-right approach for filling in the matrix. +/// +/// # Complexity +/// +/// - Time complexity: O(nm), +/// - Space complexity: O(nm), +/// +/// where n and m are lengths of `string1` and `string2`. +/// +/// Note that this implementation uses a straightforward dynamic programming approach without any space optimization. +/// It may consume more memory for larger input strings compared to the optimized version. 
+pub fn naive_levenshtein_distance(string1: &str, string2: &str) -> usize { + let distance_matrix: Vec> = (0..=string1.len()) + .map(|i| { + (0..=string2.len()) + .map(|j| { + if i == 0 { + j + } else if j == 0 { + i + } else { + 0 + } + }) + .collect() + }) + .collect(); + + let updated_matrix = (1..=string1.len()).fold(distance_matrix, |matrix, i| { + (1..=string2.len()).fold(matrix, |mut inner_matrix, j| { + let cost = usize::from(string1.as_bytes()[i - 1] != string2.as_bytes()[j - 1]); + inner_matrix[i][j] = (inner_matrix[i - 1][j - 1] + cost) + .min(inner_matrix[i][j - 1] + 1) + .min(inner_matrix[i - 1][j] + 1); + inner_matrix + }) + }); + + updated_matrix[string1.len()][string2.len()] +} + +/// Calculates the Levenshtein distance between two strings using an optimized dynamic programming approach. +/// +/// This edit distance is defined as 1 point per insertion, substitution, or deletion required to make the strings equal. +/// +/// # Arguments +/// +/// * `string1` - The first string. +/// * `string2` - The second string. +/// +/// # Returns +/// +/// The Levenshtein distance between the two input strings. +/// For a detailed explanation, check the example on [Wikipedia](https://en.wikipedia.org/wiki/Levenshtein_distance). +/// This function iterates over the bytes in the string, so it may not behave entirely as expected for non-ASCII strings. +/// +/// Note that this implementation utilizes an optimized dynamic programming approach, significantly reducing the space complexity from O(nm) to O(n), where n and m are the lengths of `string1` and `string2`. +/// +/// Additionally, it minimizes space usage by leveraging the shortest string horizontally and the longest string vertically in the computation matrix. +/// +/// # Complexity +/// +/// - Time complexity: O(nm), +/// - Space complexity: O(n), +/// +/// where n and m are lengths of `string1` and `string2`. 
+pub fn optimized_levenshtein_distance(string1: &str, string2: &str) -> usize { + if string1.is_empty() { + return string2.len(); + } + let l1 = string1.len(); + let mut prev_dist: Vec = (0..=l1).collect(); + + for (row, c2) in string2.chars().enumerate() { + // we'll keep a reference to matrix[i-1][j-1] (top-left cell) + let mut prev_substitution_cost = prev_dist[0]; + // diff with empty string, since `row` starts at 0, it's `row + 1` + prev_dist[0] = row + 1; + + for (col, c1) in string1.chars().enumerate() { + // "on the left" in the matrix (i.e. the value we just computed) + let deletion_cost = prev_dist[col] + 1; + // "on the top" in the matrix (means previous) + let insertion_cost = prev_dist[col + 1] + 1; + let substitution_cost = if c1 == c2 { + // last char is the same on both ends, so the min_distance is left unchanged from matrix[i-1][i+1] + prev_substitution_cost + } else { + // substitute the last character + prev_substitution_cost + 1 + }; + // save the old value at (i-1, j-1) + prev_substitution_cost = prev_dist[col + 1]; + prev_dist[col + 1] = _min3(deletion_cost, insertion_cost, substitution_cost); + } + } + prev_dist[l1] +} + +#[inline] +fn _min3(a: T, b: T, c: T) -> T { + min(a, min(b, c)) +} + +#[cfg(test)] +mod tests { + const LEVENSHTEIN_DISTANCE_TEST_CASES: &[(&str, &str, usize)] = &[ + ("", "", 0), + ("Hello, World!", "Hello, World!", 0), + ("", "Rust", 4), + ("horse", "ros", 3), + ("tan", "elephant", 6), + ("execute", "intention", 8), + ]; + + macro_rules! 
levenshtein_distance_tests { + ($function:ident) => { + mod $function { + use super::*; + + fn run_test_case(string1: &str, string2: &str, expected_distance: usize) { + assert_eq!(super::super::$function(string1, string2), expected_distance); + assert_eq!(super::super::$function(string2, string1), expected_distance); + assert_eq!(super::super::$function(string1, string1), 0); + assert_eq!(super::super::$function(string2, string2), 0); + } + + #[test] + fn test_levenshtein_distance() { + for &(string1, string2, expected_distance) in + LEVENSHTEIN_DISTANCE_TEST_CASES.iter() + { + run_test_case(string1, string2, expected_distance); + } + } + } + }; + } + + levenshtein_distance_tests!(naive_levenshtein_distance); + levenshtein_distance_tests!(optimized_levenshtein_distance); +} diff --git a/src/string/lipogram.rs b/src/string/lipogram.rs new file mode 100644 index 00000000000..9a486c2a62d --- /dev/null +++ b/src/string/lipogram.rs @@ -0,0 +1,112 @@ +use std::collections::HashSet; + +/// Represents possible errors that can occur when checking for lipograms. +#[derive(Debug, PartialEq, Eq)] +pub enum LipogramError { + /// Indicates that a non-alphabetic character was found in the input. + NonAlphabeticCharacter, + /// Indicates that a missing character is not in lowercase. + NonLowercaseMissingChar, +} + +/// Computes the set of missing alphabetic characters from the input string. +/// +/// # Arguments +/// +/// * `in_str` - A string slice that contains the input text. +/// +/// # Returns +/// +/// Returns a `HashSet` containing the lowercase alphabetic characters that are not present in `in_str`. +fn compute_missing(in_str: &str) -> HashSet { + let alphabet: HashSet = ('a'..='z').collect(); + + let letters_used: HashSet = in_str + .to_lowercase() + .chars() + .filter(|c| c.is_ascii_alphabetic()) + .collect(); + + alphabet.difference(&letters_used).cloned().collect() +} + +/// Checks if the provided string is a lipogram, meaning it is missing specific characters. 
+/// +/// # Arguments +/// +/// * `lipogram_str` - A string slice that contains the text to be checked for being a lipogram. +/// * `missing_chars` - A reference to a `HashSet` containing the expected missing characters. +/// +/// # Returns +/// +/// Returns `Ok(true)` if the string is a lipogram that matches the provided missing characters, +/// `Ok(false)` if it does not match, or a `LipogramError` if the input contains invalid characters. +pub fn is_lipogram( + lipogram_str: &str, + missing_chars: &HashSet, +) -> Result { + for &c in missing_chars { + if !c.is_lowercase() { + return Err(LipogramError::NonLowercaseMissingChar); + } + } + + for c in lipogram_str.chars() { + if !c.is_ascii_alphabetic() && !c.is_whitespace() { + return Err(LipogramError::NonAlphabeticCharacter); + } + } + + let missing = compute_missing(lipogram_str); + Ok(missing == *missing_chars) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_lipogram { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (input, missing_chars, expected) = $tc; + assert_eq!(is_lipogram(input, &missing_chars), expected); + } + )* + } + } + + test_lipogram! 
{ + perfect_pangram: ( + "The quick brown fox jumps over the lazy dog", + HashSet::from([]), + Ok(true) + ), + lipogram_single_missing: ( + "The quick brown fox jumped over the lazy dog", + HashSet::from(['s']), + Ok(true) + ), + lipogram_multiple_missing: ( + "The brown fox jumped over the lazy dog", + HashSet::from(['q', 'i', 'c', 'k', 's']), + Ok(true) + ), + long_lipogram_single_missing: ( + "A jovial swain should not complain of any buxom fair who mocks his pain and thinks it gain to quiz his awkward air", + HashSet::from(['e']), + Ok(true) + ), + invalid_non_lowercase_chars: ( + "The quick brown fox jumped over the lazy dog", + HashSet::from(['X']), + Err(LipogramError::NonLowercaseMissingChar) + ), + invalid_non_alphabetic_input: ( + "The quick brown fox jumps over the lazy dog 123@!", + HashSet::from([]), + Err(LipogramError::NonAlphabeticCharacter) + ), + } +} diff --git a/src/string/manacher.rs b/src/string/manacher.rs new file mode 100644 index 00000000000..e45a3f15612 --- /dev/null +++ b/src/string/manacher.rs @@ -0,0 +1,91 @@ +pub fn manacher(s: String) -> String { + let l = s.len(); + if l <= 1 { + return s; + } + + // MEMO: We need to detect odd palindrome as well, + // therefore, inserting dummy string so that + // we can find a pair with dummy center character. + let mut chars: Vec = Vec::with_capacity(s.len() * 2 + 1); + for c in s.chars() { + chars.push('#'); + chars.push(c); + } + chars.push('#'); + + // List: storing the length of palindrome at each index of string + let mut length_of_palindrome = vec![1usize; chars.len()]; + // Integer: Current checking palindrome's center index + let mut current_center: usize = 0; + // Integer: Right edge index existing the radius away from current center + let mut right_from_current_center: usize = 0; + + for i in 0..chars.len() { + // 1: Check if we are looking at right side of palindrome. + if right_from_current_center > i && i > current_center { + // 1-1: If so copy from the left side of palindrome. 
+ // If the value + index exceeds the right edge index, we should cut and check palindrome later #3. + length_of_palindrome[i] = std::cmp::min( + right_from_current_center - i, + length_of_palindrome[2 * current_center - i], + ); + // 1-2: Move the checking palindrome to new index if it exceeds the right edge. + if length_of_palindrome[i] + i >= right_from_current_center { + current_center = i; + right_from_current_center = length_of_palindrome[i] + i; + // 1-3: If radius exceeds the end of list, it means checking is over. + // You will never get the larger value because the string will get only shorter. + if right_from_current_center >= chars.len() - 1 { + break; + } + } else { + // 1-4: If the checking index doesn't exceeds the right edge, + // it means the length is just as same as the left side. + // You don't need to check anymore. + continue; + } + } + + // Integer: Current radius from checking index + // If it's copied from left side and more than 1, + // it means it's ensured so you don't need to check inside radius. + let mut radius: usize = (length_of_palindrome[i] - 1) / 2; + radius += 1; + // 2: Checking palindrome. + // Need to care about overflow usize. + while i >= radius && i + radius <= chars.len() - 1 && chars[i - radius] == chars[i + radius] + { + length_of_palindrome[i] += 2; + radius += 1; + } + } + + // 3: Find the maximum length and generate answer. 
+ let center_of_max = length_of_palindrome + .iter() + .enumerate() + .max_by_key(|(_, &value)| value) + .map(|(idx, _)| idx) + .unwrap(); + let radius_of_max = (length_of_palindrome[center_of_max] - 1) / 2; + let answer = &chars[(center_of_max - radius_of_max)..=(center_of_max + radius_of_max)] + .iter() + .collect::(); + answer.replace('#', "") +} + +#[cfg(test)] +mod tests { + use super::manacher; + + #[test] + fn get_longest_palindrome_by_manacher() { + assert_eq!(manacher("babad".to_string()), "aba".to_string()); + assert_eq!(manacher("cbbd".to_string()), "bb".to_string()); + assert_eq!(manacher("a".to_string()), "a".to_string()); + + let ac_ans = manacher("ac".to_string()); + assert!(ac_ans == *"a" || ac_ans == *"c"); + } +} diff --git a/src/string/mod.rs b/src/string/mod.rs index 5ea328dff2e..6ba37f39f29 100644 --- a/src/string/mod.rs +++ b/src/string/mod.rs @@ -1,3 +1,53 @@ +mod aho_corasick; +mod anagram; +mod autocomplete_using_trie; +mod boyer_moore_search; +mod burrows_wheeler_transform; +mod duval_algorithm; +mod hamming_distance; +mod isogram; +mod isomorphism; +mod jaro_winkler_distance; mod knuth_morris_pratt; +mod levenshtein_distance; +mod lipogram; +mod manacher; +mod palindrome; +mod pangram; +mod rabin_karp; +mod reverse; +mod run_length_encoding; +mod shortest_palindrome; +mod suffix_array; +mod suffix_array_manber_myers; +mod suffix_tree; +mod z_algorithm; +pub use self::aho_corasick::AhoCorasick; +pub use self::anagram::check_anagram; +pub use self::autocomplete_using_trie::Autocomplete; +pub use self::boyer_moore_search::boyer_moore_search; +pub use self::burrows_wheeler_transform::{ + burrows_wheeler_transform, inv_burrows_wheeler_transform, +}; +pub use self::duval_algorithm::duval_algorithm; +pub use self::hamming_distance::hamming_distance; +pub use self::isogram::is_isogram; +pub use self::isomorphism::is_isomorphic; +pub use self::jaro_winkler_distance::jaro_winkler_distance; pub use self::knuth_morris_pratt::knuth_morris_pratt; +pub 
use self::levenshtein_distance::{naive_levenshtein_distance, optimized_levenshtein_distance}; +pub use self::lipogram::is_lipogram; +pub use self::manacher::manacher; +pub use self::palindrome::is_palindrome; +pub use self::pangram::is_pangram; +pub use self::pangram::PangramStatus; +pub use self::rabin_karp::rabin_karp; +pub use self::reverse::reverse; +pub use self::run_length_encoding::{run_length_decoding, run_length_encoding}; +pub use self::shortest_palindrome::shortest_palindrome; +pub use self::suffix_array::generate_suffix_array; +pub use self::suffix_array_manber_myers::generate_suffix_array_manber_myers; +pub use self::suffix_tree::{Node, SuffixTree}; +pub use self::z_algorithm::match_pattern; +pub use self::z_algorithm::z_array; diff --git a/src/string/palindrome.rs b/src/string/palindrome.rs new file mode 100644 index 00000000000..6ee2d0be7ca --- /dev/null +++ b/src/string/palindrome.rs @@ -0,0 +1,74 @@ +//! A module for checking if a given string is a palindrome. + +/// Checks if the given string is a palindrome. +/// +/// A palindrome is a sequence that reads the same backward as forward. +/// This function ignores non-alphanumeric characters and is case-insensitive. +/// +/// # Arguments +/// +/// * `s` - A string slice that represents the input to be checked. +/// +/// # Returns +/// +/// * `true` if the string is a palindrome; otherwise, `false`. +pub fn is_palindrome(s: &str) -> bool { + let mut chars = s + .chars() + .filter(|c| c.is_alphanumeric()) + .map(|c| c.to_ascii_lowercase()); + + while let (Some(c1), Some(c2)) = (chars.next(), chars.next_back()) { + if c1 != c2 { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! palindrome_tests { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $inputs; + assert_eq!(is_palindrome(input), expected); + } + )* + } + } + + palindrome_tests! 
{ + odd_palindrome: ("madam", true), + even_palindrome: ("deified", true), + single_character_palindrome: ("x", true), + single_word_palindrome: ("eye", true), + case_insensitive_palindrome: ("RaceCar", true), + mixed_case_and_punctuation_palindrome: ("A man, a plan, a canal, Panama!", true), + mixed_case_and_space_palindrome: ("No 'x' in Nixon", true), + empty_string: ("", true), + pompeii_palindrome: ("Roma-Olima-Milo-Amor", true), + napoleon_palindrome: ("Able was I ere I saw Elba", true), + john_taylor_palindrome: ("Lewd did I live, & evil I did dwel", true), + well_know_english_palindrome: ("Never odd or even", true), + palindromic_phrase: ("Rats live on no evil star", true), + names_palindrome: ("Hannah", true), + prime_minister_of_cambodia: ("Lon Nol", true), + japanese_novelist_and_manga_writer: ("Nisio Isin", true), + actor: ("Robert Trebor", true), + rock_vocalist: ("Ola Salo", true), + pokemon_species: ("Girafarig", true), + lychrel_num_56: ("121", true), + universal_palindrome_date: ("02/02/2020", true), + french_palindrome: ("une Slave valse nu", true), + finnish_palindrome: ("saippuakivikauppias", true), + non_palindrome_simple: ("hello", false), + non_palindrome_with_punctuation: ("hello!", false), + non_palindrome_mixed_case: ("Hello, World", false), + } +} diff --git a/src/string/pangram.rs b/src/string/pangram.rs new file mode 100644 index 00000000000..19ccad4a688 --- /dev/null +++ b/src/string/pangram.rs @@ -0,0 +1,92 @@ +//! This module provides functionality to check if a given string is a pangram. +//! +//! A pangram is a sentence that contains every letter of the alphabet at least once. +//! This module can distinguish between a non-pangram, a regular pangram, and a +//! perfect pangram, where each letter appears exactly once. + +use std::collections::HashSet; + +/// Represents the status of a string in relation to the pangram classification. 
+#[derive(PartialEq, Debug)] +pub enum PangramStatus { + NotPangram, + Pangram, + PerfectPangram, +} + +fn compute_letter_counts(pangram_str: &str) -> std::collections::HashMap { + let mut letter_counts = std::collections::HashMap::new(); + + for ch in pangram_str + .to_lowercase() + .chars() + .filter(|c| c.is_ascii_alphabetic()) + { + *letter_counts.entry(ch).or_insert(0) += 1; + } + + letter_counts +} + +/// Determines if the input string is a pangram, and classifies it as either a regular or perfect pangram. +/// +/// # Arguments +/// +/// * `pangram_str` - A reference to the string slice to be checked for pangram status. +/// +/// # Returns +/// +/// A `PangramStatus` enum indicating whether the string is a pangram, and if so, whether it is a perfect pangram. +pub fn is_pangram(pangram_str: &str) -> PangramStatus { + let letter_counts = compute_letter_counts(pangram_str); + + let alphabet: HashSet = ('a'..='z').collect(); + let used_letters: HashSet<_> = letter_counts.keys().cloned().collect(); + + if used_letters != alphabet { + return PangramStatus::NotPangram; + } + + if letter_counts.values().all(|&count| count == 1) { + PangramStatus::PerfectPangram + } else { + PangramStatus::Pangram + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! pangram_tests { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $tc; + assert_eq!(is_pangram(input), expected); + } + )* + }; + } + + pangram_tests! 
{ + test_not_pangram_simple: ("This is not a pangram", PangramStatus::NotPangram), + test_not_pangram_day: ("today is a good day", PangramStatus::NotPangram), + test_not_pangram_almost: ("this is almost a pangram but it does not have bcfghjkqwxy and the last letter", PangramStatus::NotPangram), + test_pangram_standard: ("The quick brown fox jumps over the lazy dog", PangramStatus::Pangram), + test_pangram_boxer: ("A mad boxer shot a quick, gloved jab to the jaw of his dizzy opponent", PangramStatus::Pangram), + test_pangram_discotheques: ("Amazingly few discotheques provide jukeboxes", PangramStatus::Pangram), + test_pangram_zebras: ("How vexingly quick daft zebras jump", PangramStatus::Pangram), + test_perfect_pangram_jock: ("Mr. Jock, TV quiz PhD, bags few lynx", PangramStatus::PerfectPangram), + test_empty_string: ("", PangramStatus::NotPangram), + test_repeated_letter: ("aaaaa", PangramStatus::NotPangram), + test_non_alphabetic: ("12345!@#$%", PangramStatus::NotPangram), + test_mixed_case_pangram: ("ThE QuiCk BroWn FoX JumPs OveR tHe LaZy DoG", PangramStatus::Pangram), + test_perfect_pangram_with_symbols: ("Mr. Jock, TV quiz PhD, bags few lynx!", PangramStatus::PerfectPangram), + test_long_non_pangram: (&"a".repeat(1000), PangramStatus::NotPangram), + test_near_pangram_missing_one_letter: ("The quick brown fox jumps over the lazy do", PangramStatus::NotPangram), + test_near_pangram_missing_two_letters: ("The quick brwn f jumps ver the lazy dg", PangramStatus::NotPangram), + test_near_pangram_with_special_characters: ("Th3 qu!ck brown f0x jumps 0v3r th3 l@zy d0g.", PangramStatus::NotPangram), + } +} diff --git a/src/string/rabin_karp.rs b/src/string/rabin_karp.rs new file mode 100644 index 00000000000..9901849990a --- /dev/null +++ b/src/string/rabin_karp.rs @@ -0,0 +1,123 @@ +//! This module implements the Rabin-Karp string searching algorithm. +//! It uses a rolling hash technique to find all occurrences of a pattern +//! within a target string efficiently. 
/// Modulus applied to every hash value.
const MOD: usize = 101;
/// Base of the polynomial hash: one "digit" per byte.
const RADIX: usize = 256;

/// Finds all starting indices where the `pattern` appears in the `text`.
///
/// Matching is done on the UTF-8 bytes of both strings, so the function never
/// slices `text` at a non-character boundary and cannot panic on multi-byte
/// input. Overlapping occurrences are all reported.
///
/// # Arguments
/// * `text` - The string where the search is performed.
/// * `pattern` - The substring pattern to search for.
///
/// # Returns
/// A vector of starting byte indices where the pattern is found, in ascending
/// order. It is empty when either argument is empty or the pattern is longer
/// than the text.
pub fn rabin_karp(text: &str, pattern: &str) -> Vec<usize> {
    let haystack = text.as_bytes();
    let needle = pattern.as_bytes();
    let window = needle.len();

    if haystack.is_empty() || needle.is_empty() || window > haystack.len() {
        return vec![];
    }

    let pat_hash = compute_hash(pattern);

    // RADIX^(window - 1) % MOD: the weight of the byte that leaves the window.
    let radix_pow = (1..window).fold(1usize, |acc, _| (acc * RADIX) % MOD);

    let mut rolling_hash = hash_bytes(&haystack[..window]);
    let mut result = vec![];
    for i in 0..=haystack.len() - window {
        if i > 0 {
            rolling_hash = update_hash(text, i - 1, i + window - 1, rolling_hash, radix_pow);
        }
        // Equal hashes only make a match possible; confirm byte by byte.
        if rolling_hash == pat_hash && haystack[i..i + window] == *needle {
            result.push(i);
        }
    }
    result
}

/// Polynomial hash of a byte slice, reduced modulo `MOD`.
fn hash_bytes(bytes: &[u8]) -> usize {
    bytes
        .iter()
        .fold(0, |hash_val, &byte| (hash_val * RADIX + usize::from(byte)) % MOD)
}

/// Calculates the hash of a string using the Rabin-Karp formula.
///
/// # Arguments
/// * `s` - The string to calculate the hash for.
///
/// # Returns
/// The hash value of the string's bytes modulo `MOD`.
fn compute_hash(s: &str) -> usize {
    hash_bytes(s.as_bytes())
}

/// Updates the rolling hash when shifting the search window.
///
/// # Arguments
/// * `s` - The full text where the search is performed.
/// * `old_idx` - The byte index of the character that is leaving the window.
/// * `new_idx` - The byte index of the new character entering the window.
/// * `old_hash` - The hash of the previous substring.
/// * `radix_pow` - The precomputed value of RADIX^(n-1) % MOD.
///
/// # Returns
/// The updated hash for the new substring.
fn update_hash(
    s: &str,
    old_idx: usize,
    new_idx: usize,
    old_hash: usize,
    radix_pow: usize,
) -> usize {
    let bytes = s.as_bytes();
    let old_char = usize::from(bytes[old_idx]);
    let new_char = usize::from(bytes[new_idx]);
    // Adding MOD before subtracting keeps the intermediate value non-negative.
    let without_old = (old_hash + MOD - (old_char * radix_pow % MOD)) % MOD;
    (without_old * RADIX + new_char) % MOD
}

#[cfg(test)]
mod tests {
    use super::*;

    macro_rules! test_cases {
        ($($name:ident: $inputs:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (text, pattern, expected) = $inputs;
                    assert_eq!(rabin_karp(text, pattern), expected);
                }
            )*
        };
    }

    test_cases! {
        single_match_at_start: ("hello world", "hello", vec![0]),
        single_match_at_end: ("hello world", "world", vec![6]),
        single_match_in_middle: ("abc def ghi", "def", vec![4]),
        multiple_matches: ("ababcabc", "abc", vec![2, 5]),
        overlapping_matches: ("aaaaa", "aaa", vec![0, 1, 2]),
        no_match: ("abcdefg", "xyz", vec![]),
        pattern_is_entire_string: ("abc", "abc", vec![0]),
        target_is_multiple_patterns: ("abcabcabc", "abc", vec![0, 3, 6]),
        empty_text: ("", "abc", vec![]),
        empty_pattern: ("abc", "", vec![]),
        empty_text_and_pattern: ("", "", vec![]),
        pattern_larger_than_text: ("abc", "abcd", vec![]),
        large_text_small_pattern: (&("a".repeat(1000) + "b"), "b", vec![1000]),
        single_char_match: ("a", "a", vec![0]),
        single_char_no_match: ("a", "b", vec![]),
        large_pattern_no_match: ("abc", "defghi", vec![]),
        repeating_chars: ("aaaaaa", "aa", vec![0, 1, 2, 3, 4]),
        special_characters: ("abc$def@ghi", "$def@", vec![3]),
        numeric_and_alphabetic_mix: ("abc123abc456", "123abc", vec![3]),
        case_sensitivity: ("AbcAbc", "abc", vec![]),
        multi_byte_text: ("héllo wörld", "wörld", vec![7]),
        multi_byte_text_ascii_pattern: ("é", "a", vec![]),
    }
}

// ---- src/string/reverse.rs (new file) ----
// Reverses the given string.
/// Reverses the given string.
///
/// # Arguments
///
/// * `text` - A string slice that holds the string to be reversed.
///
/// # Returns
///
/// * A new `String` that is the reverse of the input string, reversed by
///   Unicode scalar value (`char`).
pub fn reverse(text: &str) -> String {
    text.chars().rev().collect()
}

#[cfg(test)]
mod reverse_tests {
    use super::*;

    macro_rules! test_cases {
        ($($name:ident: $test_case:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (input, expected) = $test_case;
                    assert_eq!(reverse(input), expected);
                }
            )*
        };
    }

    test_cases! {
        test_simple_palindrome: ("racecar", "racecar"),
        test_non_palindrome: ("abcdef", "fedcba"),
        test_sentence_with_spaces: ("step on no pets", "step on no pets"),
        test_empty_string: ("", ""),
        test_single_character: ("a", "a"),
        test_leading_trailing_spaces: (" hello ", " olleh "),
        test_unicode_characters: ("你好", "好你"),
        test_mixed_content: ("a1b2c3!", "!3c2b1a"),
    }
}

// ---- src/string/run_length_encoding.rs (new file) ----

/// Run-length encodes `target` as a sequence of `<count><symbol>` pairs.
///
/// Every maximal run of one repeated character becomes its decimal length
/// followed by the character, e.g. `"aaabb"` becomes `"3a2b"`. Whitespace is
/// encoded like any other character, so whitespace-only input is no longer
/// silently dropped.
///
/// # Returns
///
/// The encoded string; empty for empty input. Input that itself contains
/// ASCII digits cannot be decoded unambiguously.
pub fn run_length_encoding(target: &str) -> String {
    let mut encoded_target = String::new();
    let mut chars = target.chars().peekable();

    while let Some(symbol) = chars.next() {
        // Consume the rest of the run that starts with `symbol`.
        let mut run_length: usize = 1;
        while chars.peek() == Some(&symbol) {
            chars.next();
            run_length += 1;
        }
        encoded_target.push_str(&run_length.to_string());
        encoded_target.push(symbol);
    }

    encoded_target
}

/// Decodes a string produced by [`run_length_encoding`].
///
/// ASCII digits are accumulated as the repeat count for the next non-digit
/// character. Malformed input is handled leniently instead of panicking: a
/// symbol with no preceding count is emitted once, and trailing digits with
/// no symbol are ignored.
///
/// # Panics
///
/// Panics if a count does not fit in `usize`; such a run could not be
/// materialised anyway.
pub fn run_length_decoding(target: &str) -> String {
    let mut character_count = String::new();
    let mut decoded_target = String::new();

    for c in target.chars() {
        if c.is_ascii_digit() {
            character_count.push(c);
            continue;
        }

        let run_length = if character_count.is_empty() {
            1
        } else {
            character_count
                .parse::<usize>()
                .expect("run length must fit in usize")
        };
        decoded_target.extend(std::iter::repeat(c).take(run_length));
        character_count.clear();
    }

    decoded_target
}

#[cfg(test)]
mod run_length_encoding_tests {
    use super::*;

    macro_rules! test_run_length {
        ($($name:ident: $test_case:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (raw_str, encoded) = $test_case;
                    assert_eq!(run_length_encoding(raw_str), encoded);
                    assert_eq!(run_length_decoding(encoded), raw_str);
                }
            )*
        };
    }

    test_run_length! {
        empty_input: ("", ""),
        repeated_char: ("aaaaaaaaaa", "10a"),
        no_repeated: ("abcdefghijk", "1a1b1c1d1e1f1g1h1i1j1k"),
        regular_input: ("aaaaabbbcccccdddddddddd", "5a3b5c10d"),
        two_blocks_with_same_char: ("aaabbaaaa", "3a2b4a"),
        whitespace_only: ("   ", "3 "),
    }

    #[test]
    fn long_input() {
        let raw_str = "a".repeat(200) + "bbbcccccdddddddddd";
        let encoded = "200a3b5c10d";
        assert_eq!(run_length_encoding(&raw_str), encoded);
        assert_eq!(run_length_decoding(encoded), raw_str);
    }

    #[test]
    fn decoding_malformed_input_does_not_panic() {
        assert_eq!(run_length_decoding("abc"), "abc");
        assert_eq!(run_length_decoding("12"), "");
    }
}

// ---- src/string/shortest_palindrome.rs (new file) ----
// This module provides functions for finding the shortest palindrome
// that can be formed by adding characters to the left of a given string.
// References
//
// - [KMP](https://www.scaler.com/topics/data-structures/kmp-algorithm/)
// - [Prefix Functions and KMP](https://oi-wiki.org/string/kmp/)
//
// Finds the shortest palindrome that can be formed by adding characters
// to the left of the given string `s`.
//
// # Arguments
//
// * `s` - A string slice that holds the input string.
/// Finds the shortest palindrome that can be formed by adding characters
/// to the left of the given string `s`.
///
/// # Arguments
///
/// * `s` - A string slice that holds the input string.
///
/// # Returns
///
/// Returns a new string that is the shortest palindrome, formed by adding
/// the necessary characters to the beginning of `s`.
pub fn shortest_palindrome(s: &str) -> String {
    if s.is_empty() {
        return "".to_string();
    }

    let original_chars: Vec<char> = s.chars().collect();
    let suffix_table = compute_suffix(&original_chars);

    let mut reversed_chars: Vec<char> = s.chars().rev().collect();
    // The prefix of the original string matches the suffix of the reversed string.
    let prefix_match = compute_prefix_match(&original_chars, &reversed_chars, &suffix_table);

    // The last entry is the length of the longest palindromic prefix of `s`;
    // only the characters after it still have to follow the reversed copy.
    let palindromic_prefix_len = prefix_match[original_chars.len() - 1];
    reversed_chars.extend_from_slice(&original_chars[palindromic_prefix_len..]);
    reversed_chars.iter().collect()
}

/// Computes the suffix table used for the KMP (Knuth-Morris-Pratt) string
/// matching algorithm.
///
/// # Arguments
///
/// * `chars` - A slice of characters for which the suffix table is computed.
///
/// # Returns
///
/// Returns a vector of `usize` representing the suffix table. Each element
/// at index `i` indicates the longest proper suffix which is also a proper
/// prefix of the substring `chars[0..=i]`. The table is empty for empty input.
pub fn compute_suffix(chars: &[char]) -> Vec<usize> {
    let mut suffix = vec![0; chars.len()];
    for i in 1..chars.len() {
        let mut j = suffix[i - 1];
        // Fall back to shorter borders until one can be extended by chars[i].
        while j > 0 && chars[j] != chars[i] {
            j = suffix[j - 1];
        }
        suffix[i] = j + usize::from(chars[j] == chars[i]);
    }
    suffix
}

/// Computes the prefix matches of the original string against its reversed
/// version using the suffix table.
///
/// # Arguments
///
/// * `original` - A slice of characters representing the original string.
/// * `reversed` - A slice of characters representing the reversed string.
/// * `suffix` - A slice containing the suffix table computed for the original string.
///
/// # Returns
///
/// Returns a vector of `usize` where each element at index `i` indicates the
/// length of the longest prefix of `original` that matches a suffix of
/// `reversed[0..=i]`. The result is empty when `original` is empty.
///
/// # Panics
///
/// Panics if `reversed` or `suffix` is shorter than `original`.
pub fn compute_prefix_match(original: &[char], reversed: &[char], suffix: &[usize]) -> Vec<usize> {
    if original.is_empty() {
        return Vec::new();
    }

    let mut match_table = vec![0; original.len()];
    match_table[0] = usize::from(original[0] == reversed[0]);
    for i in 1..original.len() {
        let mut j = match_table[i - 1];
        while j > 0 && reversed[i] != original[j] {
            j = suffix[j - 1];
        }
        match_table[i] = j + usize::from(reversed[i] == original[j]);
    }
    match_table
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Strict character-level palindrome check used to validate the expected
    /// values, keeping this test module free of cross-module dependencies.
    fn is_palindrome(s: &str) -> bool {
        s.chars().eq(s.chars().rev())
    }

    macro_rules! test_shortest_palindrome {
        ($($name:ident: $inputs:expr,)*) => {
            $(
                #[test]
                fn $name() {
                    let (input, expected) = $inputs;
                    assert!(is_palindrome(expected));
                    assert_eq!(shortest_palindrome(input), expected);
                    assert_eq!(shortest_palindrome(expected), expected);
                }
            )*
        }
    }

    test_shortest_palindrome! {
        empty: ("", ""),
        extend_left_1: ("aacecaaa", "aaacecaaa"),
        extend_left_2: ("abcd", "dcbabcd"),
        unicode_1: ("അ", "അ"),
        unicode_2: ("a牛", "牛a牛"),
        single_char: ("x", "x"),
        already_palindrome: ("racecar", "racecar"),
        extend_left_3: ("abcde", "edcbabcde"),
        extend_left_4: ("abca", "acbabca"),
        long_string: ("abcdefg", "gfedcbabcdefg"),
        repetitive: ("aaaaa", "aaaaa"),
        complex: ("abacdfgdcaba", "abacdgfdcabacdfgdcaba"),
    }

    #[test]
    fn prefix_match_of_empty_input_is_empty() {
        assert!(compute_prefix_match(&[], &[], &[]).is_empty());
    }
}

// ---- src/string/suffix_array.rs (new file) ----
// In computer science, a suffix array is a sorted array of all suffixes of a string.
// It is a data structure used in, among others, full-text indices, data-compression algorithms,
// and the field of bibliometrics.
// Source: https://en.wikipedia.org/wiki/Suffix_array

use std::cmp::Ordering;

/// A suffix identified by its starting byte offset, together with the sort
/// key used by the prefix-doubling rounds.
#[derive(Clone)]
struct Suffix {
    /// Byte offset at which the suffix starts.
    index: usize,
    /// (rank of the first half, rank of the second half); `-1` marks a half
    /// that lies past the end of the text.
    rank: (i32, i32),
}

impl Suffix {
    /// Orders suffixes lexicographically by their rank pair.
    ///
    /// Equal pairs compare as `Ordering::Equal`, which makes this a total
    /// order as `sort_by` requires.
    fn cmp(&self, b: &Self) -> Ordering {
        self.rank.cmp(&b.rank)
    }
}

/// Builds the suffix array of `txt` by prefix doubling.
///
/// The text is processed as UTF-8 bytes, so any input is accepted (not only
/// lowercase ASCII) and the returned values are byte offsets.
///
/// # Returns
///
/// The starting offsets of all suffixes of `txt`, sorted lexicographically by
/// byte value. Empty for empty input.
pub fn generate_suffix_array(txt: &str) -> Vec<usize> {
    let bytes = txt.as_bytes();
    let n = bytes.len();

    // Initial key: the first two bytes of every suffix.
    let mut suffixes: Vec<Suffix> = (0..n)
        .map(|i| Suffix {
            index: i,
            rank: (
                i32::from(bytes[i]),
                bytes.get(i + 1).map_or(-1, |&b| i32::from(b)),
            ),
        })
        .collect();
    suffixes.sort_by(|a, b| a.cmp(b));

    // ind[start offset] = current position of that suffix in `suffixes`.
    let mut ind = vec![0; n];
    let mut k = 4;
    while k < 2 * n {
        // Replace the rank pairs by dense ranks of the first k / 2 bytes.
        let mut rank = 0;
        let mut prev_rank = suffixes[0].rank.0;
        suffixes[0].rank.0 = rank;
        ind[suffixes[0].index] = 0;

        for i in 1..n {
            let same_as_previous =
                suffixes[i].rank.0 == prev_rank && suffixes[i].rank.1 == suffixes[i - 1].rank.1;
            prev_rank = suffixes[i].rank.0;
            if !same_as_previous {
                rank += 1;
            }
            suffixes[i].rank.0 = rank;
            ind[suffixes[i].index] = i;
        }

        // The second half of the key is the rank of the suffix k / 2 bytes on.
        for i in 0..n {
            let next_index = suffixes[i].index + (k / 2);
            suffixes[i].rank.1 = if next_index < n {
                suffixes[ind[next_index]].rank.0
            } else {
                -1
            }
        }
        suffixes.sort_by(|a, b| a.cmp(b));
        k *= 2;
    }

    suffixes.into_iter().map(|suf| suf.index).collect()
}

#[cfg(test)]
mod suffix_array_tests {
    use super::*;

    #[test]
    fn test_suffix_array() {
        let a = generate_suffix_array("banana");
        assert_eq!(a, vec![5, 3, 1, 0, 4, 2]);
    }

    #[test]
    fn test_characters_below_lowercase_a() {
        assert_eq!(generate_suffix_array("Banana$"), vec![6, 0, 5, 3, 1, 4, 2]);
    }

    #[test]
    fn test_empty_and_non_ascii() {
        assert_eq!(generate_suffix_array(""), Vec::<usize>::new());
        assert_eq!(generate_suffix_array("é"), vec![1, 0]);
    }
}

// ---- src/string/suffix_array_manber_myers.rs (new file) ----

/// Builds the suffix array of `input` with Manber-Myers style prefix doubling.
///
/// Suffixes are first ranked by their first byte; every round then sorts them
/// by the pair (rank of the first `k` bytes, rank of the following `k` bytes)
/// and doubles `k` until all ranks are distinct. No suffix strings are
/// compared directly. The text is processed as UTF-8 bytes, so the result
/// contains byte offsets.
///
/// # Returns
///
/// The starting offsets of all suffixes of `input`, sorted lexicographically
/// by byte value. Empty for empty input.
pub fn generate_suffix_array_manber_myers(input: &str) -> Vec<usize> {
    let bytes = input.as_bytes();
    let n = bytes.len();
    if n == 0 {
        return Vec::new();
    }

    let mut suffix_array: Vec<usize> = (0..n).collect();
    // Rank 0 is reserved for "past the end", so a suffix that is a proper
    // prefix of another one sorts first.
    let mut rank: Vec<usize> = bytes.iter().map(|&b| usize::from(b) + 1).collect();
    let mut new_rank: Vec<usize> = vec![0; n];

    let mut k = 1;
    loop {
        {
            let key = |i: usize| (rank[i], rank.get(i + k).copied().unwrap_or(0));
            suffix_array.sort_unstable_by_key(|&i| key(i));

            // Dense ranks 1..=n for the first 2k bytes of every suffix.
            new_rank[suffix_array[0]] = 1;
            for pair in suffix_array.windows(2) {
                let bump = usize::from(key(pair[0]) != key(pair[1]));
                new_rank[pair[1]] = new_rank[pair[0]] + bump;
            }
        }
        std::mem::swap(&mut rank, &mut new_rank);

        // Suffixes have distinct lengths, so all ranks become distinct at the
        // latest once 2k >= n; at that point the order is final.
        if rank[suffix_array[n - 1]] == n {
            break;
        }
        k <<= 1;
    }

    suffix_array
}

#[cfg(test)]
mod manber_myers_tests {
    use super::*;

    #[test]
    fn test_suffix_array() {
        let input = "banana";
        let expected_result = vec![5, 3, 1, 0, 4, 2];
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_empty_string() {
        let input = "";
        let expected_result: Vec<usize> = Vec::new();
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_single_character() {
        let input = "a";
        let expected_result = vec![0];
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_repeating_characters() {
        let input = "zzzzzz";
        let expected_result = vec![5, 4, 3, 2, 1, 0];
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_long_string() {
        let input = "abcdefghijklmnopqrstuvwxyz";
        let expected_result: Vec<usize> = (0..26).collect();
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_mix_of_characters() {
        let input = "abracadabra!";
        let expected_result = vec![11, 10, 7, 0, 3, 5, 8, 1, 4, 6, 9, 2];
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_whitespace_characters() {
        let input = " hello world ";
        let expected_result = vec![12, 0, 6, 11, 2, 1, 10, 3, 4, 5, 8, 9, 7];
        assert_eq!(generate_suffix_array_manber_myers(input), expected_result);
    }

    #[test]
    fn test_non_ascii() {
        assert_eq!(generate_suffix_array_manber_myers("é"), vec![1, 0]);
    }
}

// ---- src/string/suffix_tree.rs (new file) ----
// In computer science, a suffix tree (also called PAT tree or, in an earlier form, position tree)
// is a compressed trie containing all the suffixes of the given text as their keys and positions
// in the text as their values. Suffix trees allow particularly fast implementations of many
// important string operations.
// Source: https://en.wikipedia.org/wiki/Suffix_tree

/// A node of the suffix tree; nodes refer to each other by their index in
/// [`SuffixTree::nodes`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Node {
    pub sub: String,    // substring of input string
    pub ch: Vec<usize>, // vector of child nodes
}

impl Node {
    /// Creates a node labelled `sub` with the given child indices.
    fn new(sub: String, children: Vec<usize>) -> Self {
        Node { sub, ch: children }
    }

    /// Creates a node with an empty label and no children (used as the root).
    pub fn empty() -> Self {
        Node {
            sub: "".to_string(),
            ch: vec![],
        }
    }
}

/// A suffix tree stored as a flat list of nodes; node `0` is the root.
pub struct SuffixTree {
    pub nodes: Vec<Node>,
}

impl SuffixTree {
    /// Builds the tree by inserting every suffix of `s`, longest first.
    ///
    /// Suffixes start at character boundaries, so non-ASCII input is handled
    /// without panicking.
    pub fn new(s: &str) -> Self {
        let mut suf_tree = SuffixTree {
            nodes: vec![Node::empty()],
        };
        for (i, _) in s.char_indices() {
            suf_tree.add_suffix(&s[i..]);
        }
        suf_tree
    }

    /// Inserts one suffix, splitting an existing edge where the suffix
    /// diverges from it.
    fn add_suffix(&mut self, suf: &str) {
        // Work on characters so that every index below counts the same unit.
        let suf: Vec<char> = suf.chars().collect();
        let mut n = 0; // node the walk is currently at
        let mut i = 0; // number of characters of `suf` already matched
        while i < suf.len() {
            let b = suf[i];

            // Look for the child of `n` whose label starts with `b`.
            let found = self.nodes[n]
                .ch
                .iter()
                .copied()
                .enumerate()
                .find(|&(_, child)| self.nodes[child].sub.starts_with(b));
            let (x2, mut n2) = match found {
                Some(hit) => hit,
                None => {
                    // No such child: the rest of the suffix becomes a new leaf.
                    let leaf = self.nodes.len();
                    self.nodes
                        .push(Node::new(suf[i..].iter().collect(), vec![]));
                    self.nodes[n].ch.push(leaf);
                    return;
                }
            };

            // Walk along the edge label until it ends or the suffix diverges.
            let sub2: Vec<char> = self.nodes[n2].sub.chars().collect();
            let mut j = 0;
            while j < sub2.len() {
                if suf.get(i + j) != Some(&sub2[j]) {
                    // Split the edge: a new inner node keeps the common part,
                    // the old node keeps the remainder.
                    let n3 = n2;
                    n2 = self.nodes.len();
                    self.nodes
                        .push(Node::new(sub2[..j].iter().collect(), vec![n3]));
                    self.nodes[n3].sub = sub2[j..].iter().collect();
                    self.nodes[n].ch[x2] = n2;
                    break;
                }
                j += 1;
            }
            i += j;
            n = n2;
        }
    }
}

#[cfg(test)]
mod suffix_tree_tests {
    use super::*;

    #[test]
    fn test_suffix_tree() {
        let suf_tree = SuffixTree::new("banana$");
        assert_eq!(
            suf_tree.nodes,
            vec![
                Node {
                    sub: "".to_string(),
                    ch: vec![1, 8, 6, 10]
                },
                Node {
                    sub: "banana$".to_string(),
                    ch: vec![]
                },
                Node {
                    sub: "na$".to_string(),
                    ch: vec![]
                },
                Node {
                    sub: "na$".to_string(),
                    ch: vec![]
                },
                Node {
                    sub: "na".to_string(),
                    ch: vec![2, 5]
                },
                Node {
                    sub: "$".to_string(),
                    ch: vec![]
                },
                Node {
                    sub: "na".to_string(),
                    ch: vec![3, 7]
                },
                Node {
                    sub: "$".to_string(),
                    ch: vec![]
                },
                Node {
                    sub: "a".to_string(),
                    ch: vec![4, 9]
                },
                Node {
                    sub: "$".to_string(),
                    ch: vec![]
                },
                Node {
                    sub: "$".to_string(),
                    ch: vec![]
                }
            ]
        );
    }

    #[test]
    fn test_non_ascii_input() {
        let suf_tree = SuffixTree::new("é$");
        assert_eq!(suf_tree.nodes[0].ch, vec![1, 2]);
        assert_eq!(suf_tree.nodes[1].sub, "é$");
        assert_eq!(suf_tree.nodes[2].sub, "$");
    }
}

// ---- src/string/z_algorithm.rs (new file) ----
// This module provides functionalities to match patterns in strings
// and compute the Z-array for a given input string.

/// Calculates the Z-value for a given substring of the input string
/// based on a specified pattern.
///
/// # Parameters
/// - `input_string`: A slice of elements that represents the input string.
/// - `pattern`: A slice of elements representing the pattern to match.
/// - `start_index`: The index in the input string to start checking for matches.
/// - `z_value`: The number of leading elements already known to match.
///
/// # Returns
/// The computed Z-value indicating the length of the matching prefix.
fn calculate_z_value<T: Eq>(
    input_string: &[T],
    pattern: &[T],
    start_index: usize,
    mut z_value: usize,
) -> usize {
    let size = input_string.len();
    let pattern_size = pattern.len();

    while (start_index + z_value) < size && z_value < pattern_size {
        if input_string[start_index + z_value] != pattern[z_value] {
            break;
        }
        z_value += 1;
    }
    z_value
}

/// Initializes the Z-array value based on a previous match and updates
/// it to optimize further calculations.
///
/// # Parameters
/// - `z_array`: The Z-values of the pattern matched against itself.
/// - `i`: The current index in the input string.
/// - `match_end`: The index of the last character of the previous match.
/// - `last_match`: The index at which the previous match starts.
///
/// # Returns
/// The initialized Z-array value for the current index.
fn initialize_z_array_from_previous_match(
    z_array: &[usize],
    i: usize,
    match_end: usize,
    last_match: usize,
) -> usize {
    std::cmp::min(z_array[i - last_match], match_end - i + 1)
}

/// Finds the starting indices of all full matches of the pattern
/// in the Z-array.
///
/// # Parameters
/// - `z_array`: A slice of the Z-array containing computed Z-values.
/// - `pattern_size`: The length of the pattern to find in the Z-array.
///
/// # Returns
/// A vector containing the starting indices of full matches.
fn find_full_matches(z_array: &[usize], pattern_size: usize) -> Vec<usize> {
    z_array
        .iter()
        .enumerate()
        .filter_map(|(idx, &z_value)| (z_value == pattern_size).then_some(idx))
        .collect()
}

/// Computes, for every index from `start_index` on, how many elements of
/// `input_string` starting there match the beginning of `pattern`.
///
/// # Parameters
/// - `input_string`: A slice of elements to search within.
/// - `pattern`: A slice of elements that represents the pattern to match.
/// - `start_index`: The first index to compute; earlier entries stay `0`.
/// - `pattern_z`: The Z-values of `pattern` against itself, or `None` when
///   `input_string` *is* the pattern and the array under construction serves
///   as that table.
///
/// # Returns
/// One value per element of `input_string`.
fn compute_z_values<T: Eq>(
    input_string: &[T],
    pattern: &[T],
    start_index: usize,
    pattern_z: Option<&[usize]>,
) -> Vec<usize> {
    let size = input_string.len();
    let mut last_match: usize = 0;
    let mut match_end: usize = 0;
    let mut z_values = vec![0usize; size];

    for i in start_index..size {
        // Inside the previous match the input equals a stretch of the pattern,
        // so the pattern's own Z-values bound how much is already known.
        if last_match < i && i <= match_end {
            let reuse_table = pattern_z.unwrap_or(&z_values);
            let seed =
                initialize_z_array_from_previous_match(reuse_table, i, match_end, last_match);
            z_values[i] = seed;
        }

        z_values[i] = calculate_z_value(input_string, pattern, i, z_values[i]);

        if i + z_values[i] > match_end + 1 {
            match_end = i + z_values[i] - 1;
            last_match = i;
        }
    }

    z_values
}

/// Matches the occurrences of a pattern in an input string starting
/// from a specified index.
///
/// The reuse step is driven by the Z-array of the pattern itself, which is
/// what makes it valid when the pattern differs from the input.
///
/// # Parameters
/// - `input_string`: A slice of elements to search within.
/// - `pattern`: A slice of elements that represents the pattern to match.
/// - `start_index`: The index in the input string to start the search.
/// - `only_full_matches`: If true, only full matches of the pattern will be returned.
///
/// # Returns
/// The starting indices of the full matches, or the raw Z-values when
/// `only_full_matches` is false.
fn match_with_z_array<T: Eq>(
    input_string: &[T],
    pattern: &[T],
    start_index: usize,
    only_full_matches: bool,
) -> Vec<usize> {
    let pattern_z = compute_z_values(pattern, pattern, 1, None);
    let z_values = compute_z_values(input_string, pattern, start_index, Some(&pattern_z));

    if !only_full_matches {
        z_values
    } else {
        find_full_matches(&z_values, pattern.len())
    }
}

/// Constructs the Z-array for the given input string.
///
/// The Z-array is an array where the i-th element is the length of the longest
/// substring starting from s[i] that is also a prefix of s. By this module's
/// convention the first element is `0`.
///
/// # Parameters
/// - `input`: A slice of the input string for which the Z-array is to be constructed.
///
/// # Returns
/// A vector representing the Z-array of the input string.
pub fn z_array<T: Eq>(input: &[T]) -> Vec<usize> {
    compute_z_values(input, input, 1, None)
}

/// Matches the occurrences of a given pattern in an input string.
///
/// This function acts as a wrapper around `match_with_z_array` to provide a simpler
/// interface for pattern matching, returning only full matches.
///
/// # Parameters
/// - `input`: A slice of the input string where the pattern will be searched.
/// - `pattern`: A slice of the pattern to search for in the input string.
///
/// # Returns
/// A vector of indices where the pattern matches the input string. An empty
/// pattern matches at every index of the input.
pub fn match_pattern<T: Eq>(input: &[T], pattern: &[T]) -> Vec<usize> {
    match_with_z_array(input, pattern, 0, true)
}

#[cfg(test)]
mod z_algorithm_tests {
    use super::*;

    macro_rules! test_match_pattern {
        ($($name:ident: ($input:expr, $pattern:expr, $expected:expr),)*) => {
            $(
                #[test]
                fn $name() {
                    let (input, pattern, expected) = ($input, $pattern, $expected);
                    assert_eq!(match_pattern(input.as_bytes(), pattern.as_bytes()), expected);
                }
            )*
        };
    }

    macro_rules! test_z_array_cases {
        ($($name:ident: ($input:expr, $expected:expr),)*) => {
            $(
                #[test]
                fn $name() {
                    let (input, expected) = ($input, $expected);
                    assert_eq!(z_array(input.as_bytes()), expected);
                }
            )*
        };
    }

    test_match_pattern! {
        simple_match: ("abcabcabc", "abc", vec![0, 3, 6]),
        no_match: ("abcdef", "xyz", vec![]),
        single_char_match: ("aaaaaa", "a", vec![0, 1, 2, 3, 4, 5]),
        overlapping_match: ("abababa", "aba", vec![0, 2, 4]),
        full_string_match: ("pattern", "pattern", vec![0]),
        empty_pattern: ("nonempty", " ", vec![]),
        pattern_larger_than_text: ("small", "largerpattern", vec![]),
        repeated_pattern_in_text: (
            "aaaaaaaa",
            "aaa",
            vec![0, 1, 2, 3, 4, 5]
        ),
        pattern_not_in_lipsum: (
            concat!(
                "lorem ipsum dolor sit amet, consectetur ",
                "adipiscing elit, sed do eiusmod tempor ",
                "incididunt ut labore et dolore magna aliqua"
            ),
            ";alksdjfoiwer",
            vec![]
        ),
        pattern_in_lipsum: (
            concat!(
                "lorem ipsum dolor sit amet, consectetur ",
                "adipiscing elit, sed do eiusmod tempor ",
                "incididunt ut labore et dolore magna aliqua"
            ),
            "m",
            vec![4, 10, 23, 68, 74, 110]
        ),
        match_preceded_by_other_char: ("xabcc", "abc", vec![1]),
    }

    test_z_array_cases! {
        basic_z_array: ("aabaabab", vec![0, 1, 0, 4, 1, 0, 1, 0]),
        empty_string: ("", vec![]),
        single_char_z_array: ("a", vec![0]),
        repeated_char_z_array: ("aaaaaa", vec![0, 5, 4, 3, 2, 1]),
    }
}