diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 42b3ffdd415..26e15bdf0fe 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @siriak @imp2002 +* @imp2002 @vil02 diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000000..7abaea9b883 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,13 @@ +--- +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/.github/workflows/" + schedule: + interval: "weekly" + + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "daily" +... diff --git a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md b/.github/pull_request_template.md similarity index 80% rename from .github/PULL_REQUEST_TEMPLATE/pull_request_template.md rename to .github/pull_request_template.md index 6bcffa81bc0..3623445a8c7 100644 --- a/.github/PULL_REQUEST_TEMPLATE/pull_request_template.md +++ b/.github/pull_request_template.md @@ -17,9 +17,11 @@ Please delete options that are not relevant. ## Checklist: - [ ] I ran bellow commands using the latest version of **rust nightly**. -- [ ] I ran `cargo clippy --all -- -D warning` just before my last commit and fixed any issue that was found. +- [ ] I ran `cargo clippy --all -- -D warnings` just before my last commit and fixed any issue that was found. - [ ] I ran `cargo fmt` just before my last commit. - [ ] I ran `cargo test` just before my last commit and all tests passed. +- [ ] I added my algorithm to the corresponding `mod.rs` file within its own folder, and in any parent folder(s). +- [ ] I added my algorithm to `DIRECTORY.md` with the correct link. - [ ] I checked `COUNTRIBUTING.md` and my code follows its guidelines. Please make sure that if there is a test that takes too long to run ( > 300ms), you `#[ignore]` that or diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ff792c9da23..1ab85c40554 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,28 +1,35 @@ name: build -on: pull_request +'on': + pull_request: + workflow_dispatch: + schedule: + - cron: '51 2 * * 4' + +permissions: + contents: read jobs: fmt: name: cargo fmt runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: cargo fmt run: cargo fmt --all -- --check - + clippy: name: cargo clippy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: cargo clippy - run: cargo clippy --all -- -D warnings - + run: cargo clippy --all --all-targets -- -D warnings + test: name: cargo test runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: cargo test run: cargo test diff --git a/.github/workflows/code_ql.yml b/.github/workflows/code_ql.yml new file mode 100644 index 00000000000..707822d15a3 --- /dev/null +++ b/.github/workflows/code_ql.yml @@ -0,0 +1,35 @@ +--- +name: code_ql + +'on': + workflow_dispatch: + push: + branches: + - master + pull_request: + schedule: + - cron: '10 7 * * 1' + +jobs: + analyze_actions: + name: Analyze Actions + runs-on: 'ubuntu-latest' + permissions: + actions: read + contents: read + security-events: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + languages: 'actions' + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:actions" +... diff --git a/.github/workflows/directory_workflow.yml b/.github/workflows/directory_workflow.yml index 51b2942e8c7..9595c7ad8cb 100644 --- a/.github/workflows/directory_workflow.yml +++ b/.github/workflows/directory_workflow.yml @@ -1,24 +1,30 @@ name: build_directory_md -on: [push, pull_request] +on: + push: + branches: [master] + +permissions: + contents: read jobs: MainSequence: name: DIRECTORY.md runs-on: ubuntu-latest steps: - - uses: actions/checkout@v1 # v2 is broken for git diff - - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 - name: Setup Git Specs run: | - git config --global user.name github-actions - git config --global user.email '${GITHUB_ACTOR}@users.noreply.github.com' + git config --global user.name "$GITHUB_ACTOR" + git config --global user.email "$GITHUB_ACTOR@users.noreply.github.com" git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY - name: Update DIRECTORY.md run: | - python .github/workflows/scripts/build_directory_md.py + cargo run --manifest-path=.github/workflows/scripts/build_directory/Cargo.toml - name: Commit DIRECTORY.md run: | git add DIRECTORY.md - git commit -m "updating DIRECTORY.md" || true - git diff DIRECTORY.md - git push --force origin HEAD:$GITHUB_REF || true + git commit -m "Update DIRECTORY.md [skip actions]" || true + git push origin HEAD:$GITHUB_REF || true diff --git a/.github/workflows/scripts/build_directory/Cargo.toml b/.github/workflows/scripts/build_directory/Cargo.toml new file mode 100644 index 00000000000..fd8a7760e45 --- /dev/null +++ b/.github/workflows/scripts/build_directory/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "build_directory" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/.github/workflows/scripts/build_directory/src/lib.rs b/.github/workflows/scripts/build_directory/src/lib.rs new file mode 100644 index 00000000000..2a5aba9700e --- /dev/null +++ b/.github/workflows/scripts/build_directory/src/lib.rs @@ -0,0 +1,124 @@ +use std::{ + error::Error, + fs, + path::{Path, PathBuf}, +}; + +static URL_BASE: &str = "https://github.com/TheAlgorithms/Rust/blob/master"; + +fn good_filepaths(top_dir: &Path) -> Result, Box> { + let mut good_fs = Vec::new(); + if top_dir.is_dir() { + for entry in fs::read_dir(top_dir)? { + let entry = entry?; + let path = entry.path(); + if entry.file_name().to_str().unwrap().starts_with('.') + || entry.file_name().to_str().unwrap().starts_with('_') + { + continue; + } + if path.is_dir() { + let mut other = good_filepaths(&path)?; + good_fs.append(&mut other); + } else if entry.file_name().to_str().unwrap().ends_with(".rs") + && entry.file_name().to_str().unwrap() != "mod.rs" + { + good_fs.push( + path.into_os_string() + .into_string() + .unwrap() + .split_at(2) + .1 + .to_string(), + ); + } + } + } + good_fs.sort(); + Ok(good_fs) +} + +fn md_prefix(indent_count: usize) -> String { + if indent_count > 0 { + format!("{}*", " ".repeat(indent_count)) + } else { + "\n##".to_string() + } +} + +fn print_path(old_path: String, new_path: String) -> (String, String) { + let old_parts = old_path + .split(std::path::MAIN_SEPARATOR) + .collect::>(); + let mut result = String::new(); + for (count, new_part) in new_path.split(std::path::MAIN_SEPARATOR).enumerate() { + if count + 1 > old_parts.len() || old_parts[count] != new_part { + println!("{} {}", md_prefix(count), to_title(new_part)); + result.push_str(format!("{} {}\n", md_prefix(count), to_title(new_part)).as_str()); + } + } + (new_path, result) +} + +pub fn build_directory_md(top_dir: &Path) -> Result> { + let mut old_path = String::from(""); + let mut result = String::new(); + for filepath in good_filepaths(top_dir)? { + let mut filepath = PathBuf::from(filepath); + let filename = filepath.file_name().unwrap().to_owned(); + filepath.pop(); + let filepath = filepath.into_os_string().into_string().unwrap(); + if filepath != old_path { + let path_res = print_path(old_path, filepath); + old_path = path_res.0; + result.push_str(path_res.1.as_str()); + } + let url = format!("{}/{}", old_path, filename.to_string_lossy()); + let url = get_addr(&url); + let indent = old_path.matches(std::path::MAIN_SEPARATOR).count() + 1; + let filename = to_title(filename.to_str().unwrap().split('.').collect::>()[0]); + println!("{} [{}]({})", md_prefix(indent), filename, url); + result.push_str(format!("{} [{}]({})\n", md_prefix(indent), filename, url).as_str()); + } + Ok(result) +} + +fn to_title(name: &str) -> String { + let mut change = true; + name.chars() + .map(move |letter| { + if change && !letter.is_numeric() { + change = false; + letter.to_uppercase().next().unwrap() + } else if letter == '_' { + change = true; + ' ' + } else { + if letter.is_numeric() || !letter.is_alphanumeric() { + change = true; + } + letter + } + }) + .collect::() +} + +fn get_addr(addr: &str) -> String { + if cfg!(windows) { + format!("{}/{}", URL_BASE, switch_backslash(addr)) + } else { + format!("{}/{}", URL_BASE, addr) + } +} + +// Function that changes '\' to '/' (for Windows builds only) +fn switch_backslash(addr: &str) -> String { + addr.chars() + .map(|mut symbol| { + if symbol == '\\' { + symbol = '/'; + } + symbol + }) + .collect::() +} diff --git a/.github/workflows/scripts/build_directory/src/main.rs b/.github/workflows/scripts/build_directory/src/main.rs new file mode 100644 index 00000000000..9a54f0c0e3e --- /dev/null +++ b/.github/workflows/scripts/build_directory/src/main.rs @@ -0,0 +1,17 @@ +use std::{fs::File, io::Write, path::Path}; + +use build_directory::build_directory_md; +fn main() -> Result<(), std::io::Error> { + let mut file = File::create("DIRECTORY.md").unwrap(); // unwrap for panic + + match build_directory_md(Path::new(".")) { + Ok(buf) => { + file.write_all("# List of all files\n".as_bytes())?; + file.write_all(buf.as_bytes())?; + } + Err(err) => { + panic!("Error while creating string: {err}"); + } + } + Ok(()) +} diff --git a/.github/workflows/scripts/build_directory_md.py b/.github/workflows/scripts/build_directory_md.py deleted file mode 100644 index 52cfeab153a..00000000000 --- a/.github/workflows/scripts/build_directory_md.py +++ /dev/null @@ -1,51 +0,0 @@ -import os - -from typing import Iterator - -URL_BASE = "https://github.com/TheAlgorithms/Rust/blob/master" - -g_output = [] - - -def good_filepaths(top_dir: str = ".") -> Iterator[str]: - fs_exts = tuple(".rs".split()) - for dirpath, dirnames, filenames in os.walk(top_dir): - dirnames[:] = [d for d in dirnames if d[0] not in "._"] - for filename in filenames: - if filename != "mod.rs" and os.path.splitext(filename)[1].lower() in fs_exts: - yield os.path.join(dirpath, filename).lstrip("./") - - -def md_prefix(i): - return f"{i * ' '}*" if i else "\n##" - - -def print_path(old_path: str, new_path: str) -> str: - global g_output - old_parts = old_path.split(os.sep) - for i, new_part in enumerate(new_path.split(os.sep)): - if i + 1 > len(old_parts) or old_parts[i] != new_part: - if new_part: - print(f"{md_prefix(i)} {new_part.replace('_', ' ').title()}") - g_output.append(f"{md_prefix(i)} {new_part.replace('_', ' ').title()}") - return new_path - - -def build_directory_md(top_dir: str = ".") -> str: - global g_output - old_path = "" - for filepath in sorted(good_filepaths(), key=str.lower): - filepath, filename = os.path.split(filepath) - if filepath != old_path: - old_path = print_path(old_path, filepath) - indent = (filepath.count(os.sep) + 1) if filepath else 0 - url = "/".join((URL_BASE, filepath, filename)).replace(" ", "%20") - filename = os.path.splitext(filename.replace("_", " ").title())[0] - print((f"{md_prefix(indent)} [{filename}]({url})")) - g_output.append(f"{md_prefix(indent)} [{filename}]({url})") - - return "# List of all files\n" + "\n".join(g_output) - - -with open("DIRECTORY.md", "w") as out_file: - out_file.write(build_directory_md(".") + "\n") diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index bbd94300d5c..3e99d1d726d 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -2,11 +2,17 @@ name: 'Close stale issues and PRs' on: schedule: - cron: '0 0 * * *' +permissions: + contents: read + jobs: stale: + permissions: + issues: write + pull-requests: write runs-on: ubuntu-latest steps: - - uses: actions/stale@v4 + - uses: actions/stale@v9 with: stale-issue-message: 'This issue has been automatically marked as abandoned because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.' close-issue-message: 'Please ping one of the maintainers once you add more information and updates here. If this is not the case and you need some help, feel free to ask for help in our [Gitter](https://gitter.im/TheAlgorithms) channel. Thank you for your contributions!' diff --git a/.github/workflows/upload_coverage_report.yml b/.github/workflows/upload_coverage_report.yml new file mode 100644 index 00000000000..ebe347c99e4 --- /dev/null +++ b/.github/workflows/upload_coverage_report.yml @@ -0,0 +1,48 @@ +--- +name: upload_coverage_report + +# yamllint disable-line rule:truthy +on: + workflow_dispatch: + push: + branches: + - master + pull_request: + +permissions: + contents: read + +env: + REPORT_NAME: "lcov.info" + +jobs: + upload_coverage_report: + runs-on: ubuntu-latest + env: + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - uses: taiki-e/install-action@cargo-llvm-cov + - name: Generate code coverage + run: > + cargo llvm-cov + --all-features + --workspace + --lcov + --output-path "${{ env.REPORT_NAME }}" + - name: Upload coverage to codecov (tokenless) + if: >- + github.event_name == 'pull_request' && + github.event.pull_request.head.repo.full_name != github.repository + uses: codecov/codecov-action@v5 + with: + files: "${{ env.REPORT_NAME }}" + fail_ci_if_error: true + - name: Upload coverage to codecov (with token) + if: "! github.event.pull_request.head.repo.fork " + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: "${{ env.REPORT_NAME }}" + fail_ci_if_error: true +... diff --git a/.gitpod.Dockerfile b/.gitpod.Dockerfile new file mode 100644 index 00000000000..7faaafa9c0c --- /dev/null +++ b/.gitpod.Dockerfile @@ -0,0 +1,3 @@ +FROM gitpod/workspace-rust:2024-06-05-14-45-28 + +USER gitpod diff --git a/.gitpod.yml b/.gitpod.yml new file mode 100644 index 00000000000..26a402b692b --- /dev/null +++ b/.gitpod.yml @@ -0,0 +1,11 @@ +--- +image: + file: .gitpod.Dockerfile + +tasks: + - init: cargo build + +vscode: + extensions: + - rust-lang.rust-analyzer +... diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index db5320c1708..00000000000 --- a/.travis.yml +++ /dev/null @@ -1,9 +0,0 @@ -language: rust - -before_script: - - rustup component add rustfmt-preview - - rustup component add clippy -script: - - cargo fmt --all -- --check - - cargo clippy --all -- -D warnings - - cargo test diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ba7e8e190e5..e6c0aecd597 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -39,7 +39,7 @@ mod tests { } ``` -## Before submitting you PR +## Before submitting your PR Do **not** use acronyms: `DFS` should be `depth_first_search`. diff --git a/Cargo.toml b/Cargo.toml index eabc8f02107..71d16d2eafb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,171 @@ [package] name = "the_algorithms_rust" +edition = "2021" version = "0.1.0" authors = ["Anshul Malik "] [dependencies] +num-bigint = { version = "0.4", optional = true } +num-traits = { version = "0.2", optional = true } +rand = "0.9" +nalgebra = "0.33.0" + +[dev-dependencies] +quickcheck = "1.0" +quickcheck_macros = "1.0" + +[features] +default = ["big-math"] +big-math = ["dep:num-bigint", "dep:num-traits"] + +[lints.clippy] +pedantic = "warn" +restriction = "warn" +nursery = "warn" +cargo = "warn" +# pedantic-lints: +cast_lossless = { level = "allow", priority = 1 } +cast_possible_truncation = { level = "allow", priority = 1 } +cast_possible_wrap = { level = "allow", priority = 1 } +cast_precision_loss = { level = "allow", priority = 1 } +cast_sign_loss = { level = "allow", priority = 1 } +cloned_instead_of_copied = { level = "allow", priority = 1 } +doc_markdown = { level = "allow", priority = 1 } +explicit_deref_methods = { level = "allow", priority = 1 } +explicit_iter_loop = { level = "allow", priority = 1 } +float_cmp = { level = "allow", priority = 1 } +if_not_else = { level = "allow", priority = 1 } +implicit_clone = { level = "allow", priority = 1 } +implicit_hasher = { level = "allow", priority = 1 } +items_after_statements = { level = "allow", priority = 1 } +iter_without_into_iter = { level = "allow", priority = 1 } +linkedlist = { level = "allow", priority = 1 } +manual_assert = { level = "allow", priority = 1 } +manual_let_else = { level = "allow", priority = 1 } +manual_string_new = { level = "allow", priority = 1 } +many_single_char_names = { level = "allow", priority = 1 } +match_on_vec_items = { level = "allow", priority = 1 } +match_wildcard_for_single_variants = { level = "allow", priority = 1 } +missing_errors_doc = { level = "allow", priority = 1 } +missing_fields_in_debug = { level = "allow", priority = 1 } +missing_panics_doc = { level = "allow", priority = 1 } +module_name_repetitions = { level = "allow", priority = 1 } +must_use_candidate = { level = "allow", priority = 1 } +needless_pass_by_value = { level = "allow", priority = 1 } +redundant_closure_for_method_calls = { level = "allow", priority = 1 } +return_self_not_must_use = { level = "allow", priority = 1 } +semicolon_if_nothing_returned = { level = "allow", priority = 1 } +should_panic_without_expect = { level = "allow", priority = 1 } +similar_names = { level = "allow", priority = 1 } +single_match_else = { level = "allow", priority = 1 } +stable_sort_primitive = { level = "allow", priority = 1 } +too_many_lines = { level = "allow", priority = 1 } +trivially_copy_pass_by_ref = { level = "allow", priority = 1 } +unnecessary_box_returns = { level = "allow", priority = 1 } +unnested_or_patterns = { level = "allow", priority = 1 } +unreadable_literal = { level = "allow", priority = 1 } +unused_self = { level = "allow", priority = 1 } +used_underscore_binding = { level = "allow", priority = 1 } +ref_option = { level = "allow", priority = 1 } +unnecessary_semicolon = { level = "allow", priority = 1 } +ignore_without_reason = { level = "allow", priority = 1 } +# restriction-lints: +absolute_paths = { level = "allow", priority = 1 } +arithmetic_side_effects = { level = "allow", priority = 1 } +as_conversions = { level = "allow", priority = 1 } +assertions_on_result_states = { level = "allow", priority = 1 } +blanket_clippy_restriction_lints = { level = "allow", priority = 1 } +clone_on_ref_ptr = { level = "allow", priority = 1 } +dbg_macro = { level = "allow", priority = 1 } +decimal_literal_representation = { level = "allow", priority = 1 } +default_numeric_fallback = { level = "allow", priority = 1 } +deref_by_slicing = { level = "allow", priority = 1 } +else_if_without_else = { level = "allow", priority = 1 } +exhaustive_enums = { level = "allow", priority = 1 } +exhaustive_structs = { level = "allow", priority = 1 } +expect_used = { level = "allow", priority = 1 } +float_arithmetic = { level = "allow", priority = 1 } +float_cmp_const = { level = "allow", priority = 1 } +if_then_some_else_none = { level = "allow", priority = 1 } +impl_trait_in_params = { level = "allow", priority = 1 } +implicit_return = { level = "allow", priority = 1 } +indexing_slicing = { level = "allow", priority = 1 } +integer_division = { level = "allow", priority = 1 } +integer_division_remainder_used = { level = "allow", priority = 1 } +iter_over_hash_type = { level = "allow", priority = 1 } +little_endian_bytes = { level = "allow", priority = 1 } +map_err_ignore = { level = "allow", priority = 1 } +min_ident_chars = { level = "allow", priority = 1 } +missing_assert_message = { level = "allow", priority = 1 } +missing_asserts_for_indexing = { level = "allow", priority = 1 } +missing_docs_in_private_items = { level = "allow", priority = 1 } +missing_inline_in_public_items = { level = "allow", priority = 1 } +missing_trait_methods = { level = "allow", priority = 1 } +mod_module_files = { level = "allow", priority = 1 } +modulo_arithmetic = { level = "allow", priority = 1 } +multiple_unsafe_ops_per_block = { level = "allow", priority = 1 } +non_ascii_literal = { level = "allow", priority = 1 } +panic = { level = "allow", priority = 1 } +partial_pub_fields = { level = "allow", priority = 1 } +pattern_type_mismatch = { level = "allow", priority = 1 } +print_stderr = { level = "allow", priority = 1 } +print_stdout = { level = "allow", priority = 1 } +pub_use = { level = "allow", priority = 1 } +pub_with_shorthand = { level = "allow", priority = 1 } +question_mark_used = { level = "allow", priority = 1 } +same_name_method = { level = "allow", priority = 1 } +semicolon_outside_block = { level = "allow", priority = 1 } +separated_literal_suffix = { level = "allow", priority = 1 } +shadow_reuse = { level = "allow", priority = 1 } +shadow_same = { level = "allow", priority = 1 } +shadow_unrelated = { level = "allow", priority = 1 } +single_call_fn = { level = "allow", priority = 1 } +single_char_lifetime_names = { level = "allow", priority = 1 } +std_instead_of_alloc = { level = "allow", priority = 1 } +std_instead_of_core = { level = "allow", priority = 1 } +str_to_string = { level = "allow", priority = 1 } +string_add = { level = "allow", priority = 1 } +string_slice = { level = "allow", priority = 1 } +undocumented_unsafe_blocks = { level = "allow", priority = 1 } +unnecessary_safety_comment = { level = "allow", priority = 1 } +unreachable = { level = "allow", priority = 1 } +unseparated_literal_suffix = { level = "allow", priority = 1 } +unwrap_in_result = { level = "allow", priority = 1 } +unwrap_used = { level = "allow", priority = 1 } +use_debug = { level = "allow", priority = 1 } +wildcard_enum_match_arm = { level = "allow", priority = 1 } +renamed_function_params = { level = "allow", priority = 1 } +allow_attributes_without_reason = { level = "allow", priority = 1 } +allow_attributes = { level = "allow", priority = 1 } +cfg_not_test = { level = "allow", priority = 1 } +field_scoped_visibility_modifiers = { level = "allow", priority = 1 } +unused_trait_names = { level = "allow", priority = 1 } +used_underscore_items = { level = "allow", priority = 1 } +arbitrary_source_item_ordering = { level = "allow", priority = 1 } +map_with_unused_argument_over_ranges = { level = "allow", priority = 1 } +precedence_bits = { level = "allow", priority = 1 } +redundant_test_prefix = { level = "allow", priority = 1 } +# nursery-lints: +branches_sharing_code = { level = "allow", priority = 1 } +cognitive_complexity = { level = "allow", priority = 1 } +derive_partial_eq_without_eq = { level = "allow", priority = 1 } +empty_line_after_doc_comments = { level = "allow", priority = 1 } +fallible_impl_from = { level = "allow", priority = 1 } +imprecise_flops = { level = "allow", priority = 1 } +missing_const_for_fn = { level = "allow", priority = 1 } +nonstandard_macro_braces = { level = "allow", priority = 1 } +option_if_let_else = { level = "allow", priority = 1 } +suboptimal_flops = { level = "allow", priority = 1 } +suspicious_operation_groupings = { level = "allow", priority = 1 } +use_self = { level = "allow", priority = 1 } +while_float = { level = "allow", priority = 1 } +too_long_first_doc_paragraph = { level = "allow", priority = 1 } +# cargo-lints: +cargo_common_metadata = { level = "allow", priority = 1 } +# style-lints: +doc_lazy_continuation = { level = "allow", priority = 1 } +needless_return = { level = "allow", priority = 1 } +doc_overindented_list_items = { level = "allow", priority = 1 } +# complexity-lints +precedence = { level = "allow", priority = 1 } +manual_div_ceil = { level = "allow", priority = 1 } diff --git a/DIRECTORY.md b/DIRECTORY.md index a053308fa8a..564a7813807 100644 --- a/DIRECTORY.md +++ b/DIRECTORY.md @@ -1,117 +1,344 @@ # List of all files ## Src + * Backtracking + * [All Combination Of Size K](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/all_combination_of_size_k.rs) + * [Graph Coloring](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/graph_coloring.rs) + * [Hamiltonian Cycle](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/hamiltonian_cycle.rs) + * [Knight Tour](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/knight_tour.rs) + * [N Queens](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/n_queens.rs) + * [Parentheses Generator](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/parentheses_generator.rs) + * [Permutations](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/permutations.rs) + * [Rat In Maze](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/rat_in_maze.rs) + * [Subset Sum](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/subset_sum.rs) + * [Sudoku](https://github.com/TheAlgorithms/Rust/blob/master/src/backtracking/sudoku.rs) + * Big Integer + * [Fast Factorial](https://github.com/TheAlgorithms/Rust/blob/master/src/big_integer/fast_factorial.rs) + * [Multiply](https://github.com/TheAlgorithms/Rust/blob/master/src/big_integer/multiply.rs) + * [Poly1305](https://github.com/TheAlgorithms/Rust/blob/master/src/big_integer/poly1305.rs) + * Bit Manipulation + * [Counting Bits](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/counting_bits.rs) + * [Highest Set Bit](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/highest_set_bit.rs) + * [N Bits Gray Code](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/n_bits_gray_code.rs) + * [Sum Of Two Integers](https://github.com/TheAlgorithms/Rust/blob/master/src/bit_manipulation/sum_of_two_integers.rs) * Ciphers + * [Aes](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/aes.rs) * [Another Rot13](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/another_rot13.rs) + * [Baconian Cipher](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/baconian_cipher.rs) + * [Base64](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/base64.rs) + * [Blake2B](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/blake2b.rs) * [Caesar](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/caesar.rs) + * [Chacha](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/chacha.rs) + * [Diffie Hellman](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/diffie_hellman.rs) + * [Hashing Traits](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/hashing_traits.rs) + * [Kerninghan](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/kerninghan.rs) * [Morse Code](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/morse_code.rs) * [Polybius](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/polybius.rs) + * [Rail Fence](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/rail_fence.rs) * [Rot13](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/rot13.rs) + * [Salsa](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/salsa.rs) * [Sha256](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/sha256.rs) - * [TEA](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/tea.rs) + * [Sha3](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/sha3.rs) + * [Tea](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/tea.rs) + * [Theoretical Rot13](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/theoretical_rot13.rs) * [Transposition](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/transposition.rs) * [Vigenere](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/vigenere.rs) * [Xor](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/xor.rs) - * [Salsa20](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/salsa.rs) - * [HMAC](https://github.com/TheAlgorithms/Rust/blob/master/src/ciphers/hashing_traits.rs) + * Compression + * [Move To Front](https://github.com/TheAlgorithms/Rust/blob/master/src/compression/move_to_front.rs) + * [Run Length Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/compression/run_length_encoding.rs) + * Conversions + * [Binary To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/binary_to_decimal.rs) + * [Binary To Hexadecimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/binary_to_hexadecimal.rs) + * [Decimal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/decimal_to_binary.rs) + * [Decimal To Hexadecimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/decimal_to_hexadecimal.rs) + * [Hexadecimal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/hexadecimal_to_binary.rs) + * [Hexadecimal To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/hexadecimal_to_decimal.rs) + * [Length Conversion](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/length_conversion.rs) + * [Octal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/octal_to_binary.rs) + * [Octal To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/octal_to_decimal.rs) + * [Rgb Cmyk Conversion](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/rgb_cmyk_conversion.rs) * Data Structures * [Avl Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/avl_tree.rs) * [B Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/b_tree.rs) * [Binary Search Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/binary_search_tree.rs) + * [Fenwick Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/fenwick_tree.rs) + * [Floyds Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/floyds_algorithm.rs) * [Graph](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/graph.rs) + * [Hash Table](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/hash_table.rs) * [Heap](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/heap.rs) + * [Lazy Segment Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/lazy_segment_tree.rs) * [Linked List](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/linked_list.rs) + * Probabilistic + * [Bloom Filter](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/probabilistic/bloom_filter.rs) + * [Count Min Sketch](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/probabilistic/count_min_sketch.rs) * [Queue](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/queue.rs) + * [Range Minimum Query](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/range_minimum_query.rs) + * [Rb Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/rb_tree.rs) + * [Segment Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/segment_tree.rs) + * [Segment Tree Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/segment_tree_recursive.rs) * [Stack Using Singly Linked List](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/stack_using_singly_linked_list.rs) + * [Treap](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/treap.rs) * [Trie](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/trie.rs) + * [Union Find](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/union_find.rs) + * [Veb Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/veb_tree.rs) * Dynamic Programming * [Coin Change](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/coin_change.rs) - * [Edit Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/edit_distance.rs) * [Egg Dropping](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/egg_dropping.rs) * [Fibonacci](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/fibonacci.rs) + * [Fractional Knapsack](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/fractional_knapsack.rs) * [Is Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/is_subsequence.rs) * [Knapsack](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/knapsack.rs) * [Longest Common Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_common_subsequence.rs) + * [Longest Common Substring](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_common_substring.rs) * [Longest Continuous Increasing Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_continuous_increasing_subsequence.rs) * [Longest Increasing Subsequence](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/longest_increasing_subsequence.rs) + * [Matrix Chain Multiply](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/matrix_chain_multiply.rs) * [Maximal Square](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/maximal_square.rs) * [Maximum Subarray](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/maximum_subarray.rs) + * [Minimum Cost Path](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/minimum_cost_path.rs) + * [Optimal Bst](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/optimal_bst.rs) * [Rod Cutting](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/rod_cutting.rs) + * [Snail](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/snail.rs) + * [Subset Generation](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/subset_generation.rs) + * [Trapped Rainwater](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/trapped_rainwater.rs) + * [Word Break](https://github.com/TheAlgorithms/Rust/blob/master/src/dynamic_programming/word_break.rs) + * Financial + * [Present Value](https://github.com/TheAlgorithms/Rust/blob/master/src/financial/present_value.rs) * General * [Convex Hull](https://github.com/TheAlgorithms/Rust/blob/master/src/general/convex_hull.rs) + * [Fisher Yates Shuffle](https://github.com/TheAlgorithms/Rust/blob/master/src/general/fisher_yates_shuffle.rs) + * [Genetic](https://github.com/TheAlgorithms/Rust/blob/master/src/general/genetic.rs) * [Hanoi](https://github.com/TheAlgorithms/Rust/blob/master/src/general/hanoi.rs) + * [Huffman Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/general/huffman_encoding.rs) + * [Kadane Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/general/kadane_algorithm.rs) * [Kmeans](https://github.com/TheAlgorithms/Rust/blob/master/src/general/kmeans.rs) - * [Nqueens](https://github.com/TheAlgorithms/Rust/blob/master/src/general/nqueens.rs) + * [Mex](https://github.com/TheAlgorithms/Rust/blob/master/src/general/mex.rs) + * Permutations + * [Heap](https://github.com/TheAlgorithms/Rust/blob/master/src/general/permutations/heap.rs) + * [Naive](https://github.com/TheAlgorithms/Rust/blob/master/src/general/permutations/naive.rs) + * [Steinhaus Johnson Trotter](https://github.com/TheAlgorithms/Rust/blob/master/src/general/permutations/steinhaus_johnson_trotter.rs) * [Two Sum](https://github.com/TheAlgorithms/Rust/blob/master/src/general/two_sum.rs) - * [Huffman Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/general/huffman_encoding.rs) * Geometry * [Closest Points](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/closest_points.rs) + * [Graham Scan](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/graham_scan.rs) + * [Jarvis Scan](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/jarvis_scan.rs) + * [Point](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/point.rs) + * [Polygon Points](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/polygon_points.rs) + * [Ramer Douglas Peucker](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/ramer_douglas_peucker.rs) + * [Segment](https://github.com/TheAlgorithms/Rust/blob/master/src/geometry/segment.rs) * Graph + * [Astar](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/astar.rs) * [Bellman Ford](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/bellman_ford.rs) + * [Bipartite Matching](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/bipartite_matching.rs) * [Breadth First Search](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/breadth_first_search.rs) + * [Centroid Decomposition](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/centroid_decomposition.rs) + * [Decremental Connectivity](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/decremental_connectivity.rs) * [Depth First Search](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/depth_first_search.rs) * [Depth First Search Tic Tac Toe](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/depth_first_search_tic_tac_toe.rs) + * [Detect Cycle](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/detect_cycle.rs) * [Dijkstra](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/dijkstra.rs) + * [Dinic Maxflow](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/dinic_maxflow.rs) + * [Disjoint Set Union](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/disjoint_set_union.rs) + * [Eulerian Path](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/eulerian_path.rs) + * [Floyd Warshall](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/floyd_warshall.rs) + * [Ford Fulkerson](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/ford_fulkerson.rs) + * [Graph Enumeration](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/graph_enumeration.rs) + * [Heavy Light Decomposition](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/heavy_light_decomposition.rs) + * [Kosaraju](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/kosaraju.rs) + * [Lee Breadth First Search](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/lee_breadth_first_search.rs) + * [Lowest Common Ancestor](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/lowest_common_ancestor.rs) * [Minimum Spanning Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/minimum_spanning_tree.rs) * [Prim](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/prim.rs) * [Prufer Code](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/prufer_code.rs) - * [Lowest Common Ancestor](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/lowest_common_ancestor.rs) - * [Disjoint Set Union](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/disjoint_set_union.rs) - * [Heavy Light Decomposition](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/heavy_light_decomposition.rs) - * [Tarjan's Strongly Connected Components](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/strongly_connected_components.rs) - * [Centroid Decomposition](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/centroid_decomposition.rs) - * [Dinic's Max Flow](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/dinic_maxflow.rs) - * [2-SAT Problem](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/two_satisfiability.rs) + * [Strongly Connected Components](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/strongly_connected_components.rs) + * [Tarjans Ssc](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/tarjans_ssc.rs) + * [Topological Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/topological_sort.rs) + * [Two Satisfiability](https://github.com/TheAlgorithms/Rust/blob/master/src/graph/two_satisfiability.rs) + * Greedy + * [Stable Matching](https://github.com/TheAlgorithms/Rust/blob/master/src/greedy/stable_matching.rs) * [Lib](https://github.com/TheAlgorithms/Rust/blob/master/src/lib.rs) + * Machine Learning + * [Cholesky](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/cholesky.rs) + * [K Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs) + * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs) + * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs) + * Loss Function + * [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs) + * [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs) + * [Huber Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/huber_loss.rs) + * [Kl Divergence Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/kl_divergence_loss.rs) + * [Mean Absolute Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_absolute_error_loss.rs) + * [Mean Squared Error Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/mean_squared_error_loss.rs) + * [Negative Log Likelihood](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/negative_log_likelihood.rs) + * Optimization + * [Adam](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/adam.rs) + * [Gradient Descent](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/optimization/gradient_descent.rs) * Math - * [Baby-Step Giant-Step Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/baby_step_giant_step.rs) + * [Abs](https://github.com/TheAlgorithms/Rust/blob/master/src/math/abs.rs) + * [Aliquot Sum](https://github.com/TheAlgorithms/Rust/blob/master/src/math/aliquot_sum.rs) + * [Amicable Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/amicable_numbers.rs) + * [Area Of Polygon](https://github.com/TheAlgorithms/Rust/blob/master/src/math/area_of_polygon.rs) + * [Area Under Curve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/area_under_curve.rs) + * [Armstrong Number](https://github.com/TheAlgorithms/Rust/blob/master/src/math/armstrong_number.rs) + * [Average](https://github.com/TheAlgorithms/Rust/blob/master/src/math/average.rs) + * [Baby Step Giant Step](https://github.com/TheAlgorithms/Rust/blob/master/src/math/baby_step_giant_step.rs) + * [Bell Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/bell_numbers.rs) + * [Binary Exponentiation](https://github.com/TheAlgorithms/Rust/blob/master/src/math/binary_exponentiation.rs) + * [Binomial Coefficient](https://github.com/TheAlgorithms/Rust/blob/master/src/math/binomial_coefficient.rs) + * [Catalan Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/catalan_numbers.rs) + * [Ceil](https://github.com/TheAlgorithms/Rust/blob/master/src/math/ceil.rs) + * [Chinese Remainder Theorem](https://github.com/TheAlgorithms/Rust/blob/master/src/math/chinese_remainder_theorem.rs) + * [Collatz Sequence](https://github.com/TheAlgorithms/Rust/blob/master/src/math/collatz_sequence.rs) + * [Combinations](https://github.com/TheAlgorithms/Rust/blob/master/src/math/combinations.rs) + * [Cross Entropy Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/math/cross_entropy_loss.rs) + * [Decimal To Fraction](https://github.com/TheAlgorithms/Rust/blob/master/src/math/decimal_to_fraction.rs) + * [Doomsday](https://github.com/TheAlgorithms/Rust/blob/master/src/math/doomsday.rs) + * [Elliptic Curve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/elliptic_curve.rs) + * [Euclidean Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/math/euclidean_distance.rs) + * [Exponential Linear Unit](https://github.com/TheAlgorithms/Rust/blob/master/src/math/exponential_linear_unit.rs) * [Extended Euclidean Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/extended_euclidean_algorithm.rs) + * [Factorial](https://github.com/TheAlgorithms/Rust/blob/master/src/math/factorial.rs) + * [Factors](https://github.com/TheAlgorithms/Rust/blob/master/src/math/factors.rs) + * [Fast Fourier Transform](https://github.com/TheAlgorithms/Rust/blob/master/src/math/fast_fourier_transform.rs) + * [Fast Power](https://github.com/TheAlgorithms/Rust/blob/master/src/math/fast_power.rs) + * [Faster Perfect Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/faster_perfect_numbers.rs) + * [Field](https://github.com/TheAlgorithms/Rust/blob/master/src/math/field.rs) + * [Frizzy Number](https://github.com/TheAlgorithms/Rust/blob/master/src/math/frizzy_number.rs) + * [Gaussian Elimination](https://github.com/TheAlgorithms/Rust/blob/master/src/math/gaussian_elimination.rs) + * [Gaussian Error Linear Unit](https://github.com/TheAlgorithms/Rust/blob/master/src/math/gaussian_error_linear_unit.rs) + * [Gcd Of N Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/gcd_of_n_numbers.rs) + * [Geometric Series](https://github.com/TheAlgorithms/Rust/blob/master/src/math/geometric_series.rs) * [Greatest Common Divisor](https://github.com/TheAlgorithms/Rust/blob/master/src/math/greatest_common_divisor.rs) + * [Huber Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/math/huber_loss.rs) + * [Infix To Postfix](https://github.com/TheAlgorithms/Rust/blob/master/src/math/infix_to_postfix.rs) + * [Interest](https://github.com/TheAlgorithms/Rust/blob/master/src/math/interest.rs) + * [Interpolation](https://github.com/TheAlgorithms/Rust/blob/master/src/math/interpolation.rs) + * [Interquartile Range](https://github.com/TheAlgorithms/Rust/blob/master/src/math/interquartile_range.rs) + * [Karatsuba Multiplication](https://github.com/TheAlgorithms/Rust/blob/master/src/math/karatsuba_multiplication.rs) + * [Lcm Of N Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/lcm_of_n_numbers.rs) + * [Leaky Relu](https://github.com/TheAlgorithms/Rust/blob/master/src/math/leaky_relu.rs) + * [Least Square Approx](https://github.com/TheAlgorithms/Rust/blob/master/src/math/least_square_approx.rs) + * [Linear Sieve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/linear_sieve.rs) + * [Logarithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/logarithm.rs) + * [Lucas Series](https://github.com/TheAlgorithms/Rust/blob/master/src/math/lucas_series.rs) + * [Matrix Ops](https://github.com/TheAlgorithms/Rust/blob/master/src/math/matrix_ops.rs) + * [Mersenne Primes](https://github.com/TheAlgorithms/Rust/blob/master/src/math/mersenne_primes.rs) + * [Miller Rabin](https://github.com/TheAlgorithms/Rust/blob/master/src/math/miller_rabin.rs) + * [Modular Exponential](https://github.com/TheAlgorithms/Rust/blob/master/src/math/modular_exponential.rs) + * [Newton Raphson](https://github.com/TheAlgorithms/Rust/blob/master/src/math/newton_raphson.rs) + * [Nthprime](https://github.com/TheAlgorithms/Rust/blob/master/src/math/nthprime.rs) * [Pascal Triangle](https://github.com/TheAlgorithms/Rust/blob/master/src/math/pascal_triangle.rs) + * [Perfect Cube](https://github.com/TheAlgorithms/Rust/blob/master/src/math/perfect_cube.rs) * [Perfect Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/perfect_numbers.rs) + * [Perfect Square](https://github.com/TheAlgorithms/Rust/blob/master/src/math/perfect_square.rs) + * [Pollard Rho](https://github.com/TheAlgorithms/Rust/blob/master/src/math/pollard_rho.rs) + * [Postfix Evaluation](https://github.com/TheAlgorithms/Rust/blob/master/src/math/postfix_evaluation.rs) * [Prime Check](https://github.com/TheAlgorithms/Rust/blob/master/src/math/prime_check.rs) + * [Prime Factors](https://github.com/TheAlgorithms/Rust/blob/master/src/math/prime_factors.rs) * [Prime Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/prime_numbers.rs) - * [Trial Division](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trial_division.rs) - * [Miller Rabin](https://github.com/TheAlgorithms/Rust/blob/master/src/math/miller_rabin.rs) - * [Linear Sieve](https://github.com/TheAlgorithms/Rust/blob/master/src/math/linear_sieve.rs) - * [Pollard's Rho algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/pollard_rho.rs) * [Quadratic Residue](https://github.com/TheAlgorithms/Rust/blob/master/src/math/quadratic_residue.rs) - * [Simpson's Rule](https://github.com/TheAlgorithms/Rust/blob/master/src/math/simpson_integration.rs) - * [Fast Fourier Transform](https://github.com/TheAlgorithms/Rust/blob/master/src/math/fast_fourier_transform.rs) - * [Armstrong Number](https://github.com/TheAlgorithms/Rust/blob/master/src/math/armstrong_number.rs) - * [Permuted Congruential Random Number Generator](https://github.com/TheAlgorithms/Rust/blob/master/src/math/random.rs) + * [Random](https://github.com/TheAlgorithms/Rust/blob/master/src/math/random.rs) + * [Relu](https://github.com/TheAlgorithms/Rust/blob/master/src/math/relu.rs) + * [Sieve Of Eratosthenes](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sieve_of_eratosthenes.rs) + * [Sigmoid](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sigmoid.rs) + * [Signum](https://github.com/TheAlgorithms/Rust/blob/master/src/math/signum.rs) + * [Simpsons Integration](https://github.com/TheAlgorithms/Rust/blob/master/src/math/simpsons_integration.rs) + * [Softmax](https://github.com/TheAlgorithms/Rust/blob/master/src/math/softmax.rs) + * [Sprague Grundy Theorem](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sprague_grundy_theorem.rs) + * [Square Pyramidal Numbers](https://github.com/TheAlgorithms/Rust/blob/master/src/math/square_pyramidal_numbers.rs) + * [Square Root](https://github.com/TheAlgorithms/Rust/blob/master/src/math/square_root.rs) + * [Sum Of Digits](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sum_of_digits.rs) + * [Sum Of Geometric Progression](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sum_of_geometric_progression.rs) + * [Sum Of Harmonic Series](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sum_of_harmonic_series.rs) + * [Sylvester Sequence](https://github.com/TheAlgorithms/Rust/blob/master/src/math/sylvester_sequence.rs) + * [Tanh](https://github.com/TheAlgorithms/Rust/blob/master/src/math/tanh.rs) + * [Trapezoidal Integration](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trapezoidal_integration.rs) + * [Trial Division](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trial_division.rs) + * [Trig Functions](https://github.com/TheAlgorithms/Rust/blob/master/src/math/trig_functions.rs) + * [Vector Cross Product](https://github.com/TheAlgorithms/Rust/blob/master/src/math/vector_cross_product.rs) + * [Zellers Congruence Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/math/zellers_congruence_algorithm.rs) + * Navigation + * [Bearing](https://github.com/TheAlgorithms/Rust/blob/master/src/navigation/bearing.rs) + * [Haversine](https://github.com/TheAlgorithms/Rust/blob/master/src/navigation/haversine.rs) + * Number Theory + * [Compute Totient](https://github.com/TheAlgorithms/Rust/blob/master/src/number_theory/compute_totient.rs) + * [Euler Totient](https://github.com/TheAlgorithms/Rust/blob/master/src/number_theory/euler_totient.rs) + * [Kth Factor](https://github.com/TheAlgorithms/Rust/blob/master/src/number_theory/kth_factor.rs) * Searching * [Binary Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/binary_search.rs) * [Binary Search Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/binary_search_recursive.rs) - * [Ternary Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search.rs) - * [Ternary Search Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_recursive.rs) - * [Ternary Minimum Maximum Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_min_max.rs) - * [Ternary Minimum Maximum Search Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_min_max_recursive.rs) + * [Exponential Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/exponential_search.rs) + * [Fibonacci Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/fibonacci_search.rs) + * [Interpolation Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/interpolation_search.rs) + * [Jump Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/jump_search.rs) * [Kth Smallest](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/kth_smallest.rs) * [Kth Smallest Heap](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/kth_smallest_heap.rs) * [Linear Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/linear_search.rs) + * [Moore Voting](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/moore_voting.rs) + * [Quick Select](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/quick_select.rs) + * [Saddleback Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/saddleback_search.rs) + * [Ternary Search](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search.rs) + * [Ternary Search Min Max](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_min_max.rs) + * [Ternary Search Min Max Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_min_max_recursive.rs) + * [Ternary Search Recursive](https://github.com/TheAlgorithms/Rust/blob/master/src/searching/ternary_search_recursive.rs) * Sorting + * [Bead Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bead_sort.rs) + * [Binary Insertion Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/binary_insertion_sort.rs) + * [Bingo Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bingo_sort.rs) + * [Bitonic Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bitonic_sort.rs) + * [Bogo Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bogo_sort.rs) * [Bubble Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bubble_sort.rs) * [Bucket Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/bucket_sort.rs) * [Cocktail Shaker Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/cocktail_shaker_sort.rs) * [Comb Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/comb_sort.rs) * [Counting Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/counting_sort.rs) + * [Cycle Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/cycle_sort.rs) + * [Dutch National Flag Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/dutch_national_flag_sort.rs) + * [Exchange Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/exchange_sort.rs) * [Gnome Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/gnome_sort.rs) * [Heap Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/heap_sort.rs) * [Insertion Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/insertion_sort.rs) + * [Intro Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/intro_sort.rs) * [Merge Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/merge_sort.rs) * [Odd Even Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/odd_even_sort.rs) + * [Pancake Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/pancake_sort.rs) + * [Patience Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/patience_sort.rs) + * [Pigeonhole Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/pigeonhole_sort.rs) * [Quick Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/quick_sort.rs) + * [Quick Sort 3_ways](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/quick_sort_3_ways.rs) * [Radix Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/radix_sort.rs) * [Selection Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/selection_sort.rs) * [Shell Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/shell_sort.rs) + * [Sleep Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/sleep_sort.rs) + * [Sort Utils](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/sort_utils.rs) * [Stooge Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/stooge_sort.rs) + * [Tim Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/tim_sort.rs) + * [Tree Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/tree_sort.rs) + * [Wave Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/wave_sort.rs) + * [Wiggle Sort](https://github.com/TheAlgorithms/Rust/blob/master/src/sorting/wiggle_sort.rs) * String * [Aho Corasick](https://github.com/TheAlgorithms/Rust/blob/master/src/string/aho_corasick.rs) + * [Anagram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/anagram.rs) + * [Autocomplete Using Trie](https://github.com/TheAlgorithms/Rust/blob/master/src/string/autocomplete_using_trie.rs) + * [Boyer Moore Search](https://github.com/TheAlgorithms/Rust/blob/master/src/string/boyer_moore_search.rs) * [Burrows Wheeler Transform](https://github.com/TheAlgorithms/Rust/blob/master/src/string/burrows_wheeler_transform.rs) + * [Duval Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/string/duval_algorithm.rs) + * [Hamming Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/hamming_distance.rs) + * [Isogram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/isogram.rs) + * [Isomorphism](https://github.com/TheAlgorithms/Rust/blob/master/src/string/isomorphism.rs) + * [Jaro Winkler Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/jaro_winkler_distance.rs) * [Knuth Morris Pratt](https://github.com/TheAlgorithms/Rust/blob/master/src/string/knuth_morris_pratt.rs) + * [Levenshtein Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/levenshtein_distance.rs) + * [Lipogram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/lipogram.rs) * [Manacher](https://github.com/TheAlgorithms/Rust/blob/master/src/string/manacher.rs) + * [Palindrome](https://github.com/TheAlgorithms/Rust/blob/master/src/string/palindrome.rs) + * [Pangram](https://github.com/TheAlgorithms/Rust/blob/master/src/string/pangram.rs) * [Rabin Karp](https://github.com/TheAlgorithms/Rust/blob/master/src/string/rabin_karp.rs) * [Reverse](https://github.com/TheAlgorithms/Rust/blob/master/src/string/reverse.rs) + * [Run Length Encoding](https://github.com/TheAlgorithms/Rust/blob/master/src/string/run_length_encoding.rs) + * [Shortest Palindrome](https://github.com/TheAlgorithms/Rust/blob/master/src/string/shortest_palindrome.rs) + * [Suffix Array](https://github.com/TheAlgorithms/Rust/blob/master/src/string/suffix_array.rs) + * [Suffix Array Manber Myers](https://github.com/TheAlgorithms/Rust/blob/master/src/string/suffix_array_manber_myers.rs) + * [Suffix Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/string/suffix_tree.rs) * [Z Algorithm](https://github.com/TheAlgorithms/Rust/blob/master/src/string/z_algorithm.rs) - * [Hamming Distance](https://github.com/TheAlgorithms/Rust/blob/master/src/string/hamming_distance.rs) diff --git a/README.md b/README.md index 335c9200ebe..fcab70096eb 100644 --- a/README.md +++ b/README.md @@ -1,166 +1,32 @@ -# The Algorithms - Rust [![Gitter](https://img.shields.io/gitter/room/the-algorithms/rust.svg?style=flat-square)](https://gitter.im/the-algorithms/rust) [![build](https://github.com/TheAlgorithms/Rust/actions/workflows/build.yml/badge.svg)](https://github.com/TheAlgorithms/Rust/actions/workflows/build.yml) - - - -### All algorithms implemented in Rust - -These are for demonstration purposes only. - -## [Sort Algorithms](./src/sorting) - -- [x] [Bubble](./src/sorting/bubble_sort.rs) -- [X] [Bucket](./src/sorting/bucket_sort.rs) -- [x] [Cocktail-Shaker](./src/sorting/cocktail_shaker_sort.rs) -- [x] [Counting](./src/sorting/counting_sort.rs) -- [x] [Cycle](./src/sorting/cycle_sort.rs) -- [x] [Exchange](./src/sorting/exchange_sort.rs) -- [x] [Heap](./src/sorting/heap_sort.rs) -- [x] [Insertion](./src/sorting/insertion_sort.rs) -- [x] [Gnome](./src/sorting/gnome_sort.rs) -- [x] [Merge](./src/sorting/merge_sort.rs) -- [x] [Odd-even](./src/sorting/odd_even_sort.rs) -- [x] [Pancake](./src/sorting/pancake_sort.rs) -- [x] [Pigeonhole](./src/sorting/pigeonhole_sort.rs) -- [x] [Quick](./src/sorting/quick_sort.rs) -- [x] [Radix](./src/sorting/radix_sort.rs) -- [x] [Selection](./src/sorting/selection_sort.rs) -- [x] [Shell](./src/sorting/shell_sort.rs) -- [x] [Stooge](./src/sorting/stooge_sort.rs) -- [x] [Comb](./src/sorting/comb_sort.rs) -- [x] [Bucket](./src/sorting/bucket_sort.rs) -- [x] [Timsort](./src/sorting/tim_sort.rs) - -## [Graphs](./src/graph) - -- [x] [Dijkstra](./src/graph/dijkstra.rs) -- [x] [Kruskal's Minimum Spanning Tree](./src/graph/minimum_spanning_tree.rs) -- [x] [Prim's Minimum Spanning Tree](./src/graph/prim.rs) -- [x] [Breadth-First Search (BFS)](./src/graph/breadth_first_search.rs) -- [x] [Depth First Search (DFS)](./src/graph/depth_first_search.rs) -- [x] [Bellman-Ford](./src/graph/bellman_ford.rs) -- [x] [Prufer Code](./src/graph/prufer_code.rs) -- [x] [Lowest Common Ancestor](./src/graph/lowest_common_ancestor.rs) -- [x] [Heavy Light Decomposition](./src/graph/heavy_light_decomposition.rs) -- [x] [Tarjan's Strongly Connected Components](./src/graph/strongly_connected_components.rs) -- [x] [Topological sorting](./src/graph/topological_sort.rs) -- [x] [Centroid Decomposition](./src/graph/centroid_decomposition.rs) -- [x] [Dinic's Max Flow](./src/graph/dinic_maxflow.rs) -- [x] [2-SAT Problem](./src/graph/two_satisfiability.rs) - -## [Math](./src/math) -- [x] [Baby-Step Giant-Step Algorithm](./src/math/baby_step_giant_step.rs) -- [x] [Extended euclidean algorithm](./src/math/extended_euclidean_algorithm.rs) -- [x] [Gaussian Elimination](./src/math/gaussian_elimination.rs) -- [x] [Greatest common divisor](./src/math/greatest_common_divisor.rs) -- [x] [Greatest common divisor of n numbers](./src/math/gcd_of_n_numbers.rs) -- [x] [Least common multiple of n numbers](./src/math/lcm_of_n_numbers.rs) -- [x] [Miller Rabin primality test](./src/math/miller_rabin.rs) -- [x] [Pascal's triangle](./src/math/pascal_triangle.rs) -- [x] [Square root with Newton's method](./src/math/square_root.rs) -- [x] [Fast power algorithm](./src/math/fast_power.rs) -- [X] [Perfect number](./src/math/perfect_numbers.rs) -- [X] [Prime factors](./src/math/prime_factors.rs) -- [X] [Prime number](./src/math/prime_numbers.rs) -- [x] [Linear Sieve](./src/math/linear_sieve.rs) -- [x] [Pollard's Rho algorithm](./src/math/pollard_rho.rs) -- [x] [Quadratic Residue](./src/math/quadratic_residue.rs) -- [x] [Simpson's Rule for Integration](./src/math/simpson_integration.rs) -- [x] [Fast Fourier Transform](./src/math/fast_fourier_transform.rs) -- [x] [Armstrong Number](./src/math/armstrong_number.rs) -- [x] [Permuted Congruential Random Number Generator](./src/math/random.rs) -- [x] [Zeller's Congruence Algorithm](./src/math/zellers_congruence_algorithm.rs) -- [x] [Karatsuba Multiplication Algorithm](./src/math/karatsuba_multiplication.rs) - -## [Dynamic Programming](./src/dynamic_programming) - -- [x] [0-1 Knapsack](./src/dynamic_programming/knapsack.rs) -- [x] [Edit Distance](./src/dynamic_programming/edit_distance.rs) -- [x] [Longest common subsequence](./src/dynamic_programming/longest_common_subsequence.rs) -- [x] [Longest continuous increasing subsequence](./src/dynamic_programming/longest_continuous_increasing_subsequence.rs) -- [x] [Longest increasing subsequence](./src/dynamic_programming/longest_increasing_subsequence.rs) -- [x] [K-Means Clustering](./src/general/kmeans.rs) -- [x] [Coin Change](./src/dynamic_programming/coin_change.rs) -- [x] [Rod Cutting](./src/dynamic_programming/rod_cutting.rs) -- [x] [Egg Dropping Puzzle](./src/dynamic_programming/egg_dropping.rs) -- [x] [Maximum Subarray](./src/dynamic_programming/maximum_subarray.rs) -- [x] [Is Subsequence](./src/dynamic_programming/is_subsequence.rs) -- [x] [Maximal Square](./src/dynamic_programming/maximal_square.rs) - -## [Data Structures](./src/data_structures) - -- [x] [Queue](./src/data_structures/queue.rs) -- [x] [Heap](./src/data_structures/heap.rs) -- [x] [Linked List](./src/data_structures/linked_list.rs) -- [x] [Graph](./src/data_structures/graph.rs) - - [x] [Directed](./src/data_structures/graph.rs) - - [x] [Undirected](./src/data_structures/graph.rs) -- [x] [Trie](./src/data_structures/trie.rs) -- [x] [Binary Search Tree](./src/data_structures/binary_search_tree.rs) -- [x] [B-Tree](./src/data_structures/b_tree.rs) -- [x] [AVL Tree](./src/data_structures/avl_tree.rs) -- [x] [RB Tree](./src/data_structures/rb_tree.rs) -- [X] [Stack using Linked List](./src/data_structures/stack_using_singly_linked_list.rs) -- [x] [Segment Tree](./src/data_structures/segment_tree.rs) -- [x] [Fenwick Tree](./src/data_structures/fenwick_tree.rs) -- [x] [Union-find](./src/data_structures/union_find.rs) - -## [Strings](./src/string) - -- [x] [Aho-Corasick Algorithm](./src/string/aho_corasick.rs) -- [x] [Burrows-Wheeler transform](./src/string/burrows_wheeler_transform.rs) -- [x] [Knuth Morris Pratt](./src/string/knuth_morris_pratt.rs) -- [x] [Manacher](./src/string/manacher.rs) -- [x] [Rabin Carp](./src/string/rabin_karp.rs) -- [x] [Reverse](./src/string/reverse.rs) -- [x] [Hamming Distance](./src/string/hamming_distance.rs) - -## [General](./src/general) - -- [x] [Convex Hull: Graham Scan](./src/general/convex_hull.rs) -- [x] [N-Queens Problem](./src/general/nqueens.rs) -- [ ] Graph Coloring -- [x] [Tower of Hanoi](./src/general/hanoi.rs) -- [x] [Kmeans](./src/general/kmeans.rs) -- [x] [Two Sum](./src/general/two_sum.rs) -- [x] [Huffman Encoding](./src/general/huffman_encoding.rs) - -## [Search Algorithms](./src/searching) - -- [x] [Linear](./src/searching/linear_search.rs) -- [x] [Binary](./src/searching/binary_search.rs) -- [x] [Recursive Binary](./src/searching/binary_search_recursive.rs) -- [x] [Kth Smallest](./src/searching/kth_smallest.rs) -- [x] [Exponential](./src/searching/exponential_search.rs) -- [x] [Jump](./src/searching/jump_search.rs) -- [x] [Fibonacci](./src/searching/fibonacci_search.rs) -- [x] [Quick Select](./src/searching/quick_select.rs) - -## [Geometry](./src/geometry) - -- [x] [Closest pair of 2D points](./src/geometry/closest_points.rs) - -## [Ciphers](./src/ciphers) - -- [x] [Caesar](./src/ciphers/caesar.rs) -- [x] [Morse Code](./src/ciphers/morse_code.rs) -- [x] [Polybius](./src/ciphers/polybius.rs) -- [x] [SHA-2](./src/ciphers/sha256.rs) -- [x] [TEA](./src/ciphers/tea.rs) -- [x] [Transposition](./src/ciphers/transposition.rs) -- [x] [Vigenère](./src/ciphers/vigenere.rs) -- [x] [XOR](./src/ciphers/xor.rs) -- [x] [Salsa20](./src/ciphers/salsa.rs) -- [x] [HMAC](./src/ciphers/hashing_traits.rs) -- Rot13 - - [x] [Another Rot13](./src/ciphers/another_rot13.rs) - - [x] [Rot13](./src/ciphers/rot13.rs) - ---- - -### All implemented Algos - -See [DIRECTORY.md](./DIRECTORY.md) +
+ + + +

The Algorithms - Rust

+ + + + Gitpod Ready-to-Code + + + Build workflow + + + + + + Discord community + + + Gitter chat + + + +

All algorithms implemented in Rust - for education

+
+ +### List of Algorithms +See our [directory](DIRECTORY.md) for easier navigation and a better overview of the project. ### Contributing - -See [CONTRIBUTING.md](CONTRIBUTING.md) +Read through our [Contribution Guidelines](CONTRIBUTING.md) before you contribute. diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 00000000000..1b3dd21fbf7 --- /dev/null +++ b/clippy.toml @@ -0,0 +1,4 @@ +allowed-duplicate-crates = [ + "zerocopy", + "zerocopy-derive", +] diff --git a/src/backtracking/all_combination_of_size_k.rs b/src/backtracking/all_combination_of_size_k.rs new file mode 100644 index 00000000000..65b6b643b97 --- /dev/null +++ b/src/backtracking/all_combination_of_size_k.rs @@ -0,0 +1,123 @@ +//! This module provides a function to generate all possible combinations +//! of `k` numbers out of `0...n-1` using a backtracking algorithm. + +/// Custom error type for combination generation. +#[derive(Debug, PartialEq)] +pub enum CombinationError { + KGreaterThanN, + InvalidZeroRange, +} + +/// Generates all possible combinations of `k` numbers out of `0...n-1`. +/// +/// # Arguments +/// +/// * `n` - The upper limit of the range (`0` to `n-1`). +/// * `k` - The number of elements in each combination. +/// +/// # Returns +/// +/// A `Result` containing a vector with all possible combinations of `k` numbers out of `0...n-1`, +/// or a `CombinationError` if the input is invalid. +pub fn generate_all_combinations(n: usize, k: usize) -> Result>, CombinationError> { + if n == 0 && k > 0 { + return Err(CombinationError::InvalidZeroRange); + } + + if k > n { + return Err(CombinationError::KGreaterThanN); + } + + let mut combinations = vec![]; + let mut current = vec![0; k]; + backtrack(0, n, k, 0, &mut current, &mut combinations); + Ok(combinations) +} + +/// Helper function to generate combinations recursively. +/// +/// # Arguments +/// +/// * `start` - The current number to start the combination with. +/// * `n` - The upper limit of the range (`0` to `n-1`). +/// * `k` - The number of elements left to complete the combination. +/// * `index` - The current index being filled in the combination. +/// * `current` - A mutable reference to the current combination being constructed. +/// * `combinations` - A mutable reference to the vector holding all combinations. +fn backtrack( + start: usize, + n: usize, + k: usize, + index: usize, + current: &mut Vec, + combinations: &mut Vec>, +) { + if index == k { + combinations.push(current.clone()); + return; + } + + for num in start..=(n - k + index) { + current[index] = num; + backtrack(num + 1, n, k, index + 1, current, combinations); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! combination_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (n, k, expected) = $test_case; + assert_eq!(generate_all_combinations(n, k), expected); + } + )* + } + } + + combination_tests! { + test_generate_4_2: (4, 2, Ok(vec![ + vec![0, 1], + vec![0, 2], + vec![0, 3], + vec![1, 2], + vec![1, 3], + vec![2, 3], + ])), + test_generate_4_3: (4, 3, Ok(vec![ + vec![0, 1, 2], + vec![0, 1, 3], + vec![0, 2, 3], + vec![1, 2, 3], + ])), + test_generate_5_3: (5, 3, Ok(vec![ + vec![0, 1, 2], + vec![0, 1, 3], + vec![0, 1, 4], + vec![0, 2, 3], + vec![0, 2, 4], + vec![0, 3, 4], + vec![1, 2, 3], + vec![1, 2, 4], + vec![1, 3, 4], + vec![2, 3, 4], + ])), + test_generate_5_1: (5, 1, Ok(vec![ + vec![0], + vec![1], + vec![2], + vec![3], + vec![4], + ])), + test_empty: (0, 0, Ok(vec![vec![]])), + test_generate_n_eq_k: (3, 3, Ok(vec![ + vec![0, 1, 2], + ])), + test_generate_k_greater_than_n: (3, 4, Err(CombinationError::KGreaterThanN)), + test_zero_range_with_nonzero_k: (0, 1, Err(CombinationError::InvalidZeroRange)), + } +} diff --git a/src/backtracking/graph_coloring.rs b/src/backtracking/graph_coloring.rs new file mode 100644 index 00000000000..40bdf398e91 --- /dev/null +++ b/src/backtracking/graph_coloring.rs @@ -0,0 +1,370 @@ +//! This module provides functionality for generating all possible colorings of a undirected (or directed) graph +//! given a certain number of colors. It includes the GraphColoring struct and methods +//! for validating color assignments and finding all valid colorings. + +/// Represents potential errors when coloring on an adjacency matrix. +#[derive(Debug, PartialEq, Eq)] +pub enum GraphColoringError { + // Indicates that the adjacency matrix is empty + EmptyAdjacencyMatrix, + // Indicates that the adjacency matrix is not squared + ImproperAdjacencyMatrix, +} + +/// Generates all possible valid colorings of a graph. +/// +/// # Arguments +/// +/// * `adjacency_matrix` - A 2D vector representing the adjacency matrix of the graph. +/// * `num_colors` - The number of colors available for coloring the graph. +/// +/// # Returns +/// +/// * A `Result` containing an `Option` with a vector of solutions or a `GraphColoringError` if +/// there is an issue with the matrix. +pub fn generate_colorings( + adjacency_matrix: Vec>, + num_colors: usize, +) -> Result>>, GraphColoringError> { + Ok(GraphColoring::new(adjacency_matrix)?.find_solutions(num_colors)) +} + +/// A struct representing a graph coloring problem. +struct GraphColoring { + // The adjacency matrix of the graph + adjacency_matrix: Vec>, + // The current colors assigned to each vertex + vertex_colors: Vec, + // Vector of all valid color assignments for the vertices found during the search + solutions: Vec>, +} + +impl GraphColoring { + /// Creates a new GraphColoring instance. + /// + /// # Arguments + /// + /// * `adjacency_matrix` - A 2D vector representing the adjacency matrix of the graph. + /// + /// # Returns + /// + /// * A new instance of GraphColoring or a `GraphColoringError` if the matrix is empty or non-square. + fn new(adjacency_matrix: Vec>) -> Result { + let num_vertices = adjacency_matrix.len(); + + // Check if the adjacency matrix is empty + if num_vertices == 0 { + return Err(GraphColoringError::EmptyAdjacencyMatrix); + } + + // Check if the adjacency matrix is square + if adjacency_matrix.iter().any(|row| row.len() != num_vertices) { + return Err(GraphColoringError::ImproperAdjacencyMatrix); + } + + Ok(GraphColoring { + adjacency_matrix, + vertex_colors: vec![usize::MAX; num_vertices], + solutions: Vec::new(), + }) + } + + /// Returns the number of vertices in the graph. + fn num_vertices(&self) -> usize { + self.adjacency_matrix.len() + } + + /// Checks if a given color can be assigned to a vertex without causing a conflict. + /// + /// # Arguments + /// + /// * `vertex` - The index of the vertex to be colored. + /// * `color` - The color to be assigned to the vertex. + /// + /// # Returns + /// + /// * `true` if the color can be assigned to the vertex, `false` otherwise. + fn is_color_valid(&self, vertex: usize, color: usize) -> bool { + for neighbor in 0..self.num_vertices() { + // Check outgoing edges from vertex and incoming edges to vertex + if (self.adjacency_matrix[vertex][neighbor] || self.adjacency_matrix[neighbor][vertex]) + && self.vertex_colors[neighbor] == color + { + return false; + } + } + true + } + + /// Recursively finds all valid colorings for the graph. + /// + /// # Arguments + /// + /// * `vertex` - The current vertex to be colored. + /// * `num_colors` - The number of colors available for coloring the graph. + fn find_colorings(&mut self, vertex: usize, num_colors: usize) { + if vertex == self.num_vertices() { + self.solutions.push(self.vertex_colors.clone()); + return; + } + + for color in 0..num_colors { + if self.is_color_valid(vertex, color) { + self.vertex_colors[vertex] = color; + self.find_colorings(vertex + 1, num_colors); + self.vertex_colors[vertex] = usize::MAX; + } + } + } + + /// Finds all solutions for the graph coloring problem. + /// + /// # Arguments + /// + /// * `num_colors` - The number of colors available for coloring the graph. + /// + /// # Returns + /// + /// * A `Result` containing an `Option` with a vector of solutions or a `GraphColoringError`. + fn find_solutions(&mut self, num_colors: usize) -> Option>> { + self.find_colorings(0, num_colors); + if self.solutions.is_empty() { + None + } else { + Some(std::mem::take(&mut self.solutions)) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_graph_coloring { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (adjacency_matrix, num_colors, expected) = $test_case; + let actual = generate_colorings(adjacency_matrix, num_colors); + assert_eq!(actual, expected); + } + )* + }; + } + + test_graph_coloring! { + test_complete_graph_with_3_colors: ( + vec![ + vec![false, true, true, true], + vec![true, false, true, false], + vec![true, true, false, true], + vec![true, false, true, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 1], + vec![0, 2, 1, 2], + vec![1, 0, 2, 0], + vec![1, 2, 0, 2], + vec![2, 0, 1, 0], + vec![2, 1, 0, 1], + ])) + ), + test_linear_graph_with_2_colors: ( + vec![ + vec![false, true, false, false], + vec![true, false, true, false], + vec![false, true, false, true], + vec![false, false, true, false], + ], + 2, + Ok(Some(vec![ + vec![0, 1, 0, 1], + vec![1, 0, 1, 0], + ])) + ), + test_incomplete_graph_with_insufficient_colors: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 1, + Ok(None::>>) + ), + test_empty_graph: ( + vec![], + 1, + Err(GraphColoringError::EmptyAdjacencyMatrix) + ), + test_non_square_matrix: ( + vec![ + vec![false, true, true], + vec![true, false, true], + ], + 3, + Err(GraphColoringError::ImproperAdjacencyMatrix) + ), + test_single_vertex_graph: ( + vec![ + vec![false], + ], + 1, + Ok(Some(vec![ + vec![0], + ])) + ), + test_bipartite_graph_with_2_colors: ( + vec![ + vec![false, true, false, true], + vec![true, false, true, false], + vec![false, true, false, true], + vec![true, false, true, false], + ], + 2, + Ok(Some(vec![ + vec![0, 1, 0, 1], + vec![1, 0, 1, 0], + ])) + ), + test_large_graph_with_3_colors: ( + vec![ + vec![false, true, true, false, true, true, false, true, true, false], + vec![true, false, true, true, false, true, true, false, true, true], + vec![true, true, false, true, true, false, true, true, false, true], + vec![false, true, true, false, true, true, false, true, true, false], + vec![true, false, true, true, false, true, true, false, true, true], + vec![true, true, false, true, true, false, true, true, false, true], + vec![false, true, true, false, true, true, false, true, true, false], + vec![true, false, true, true, false, true, true, false, true, true], + vec![true, true, false, true, true, false, true, true, false, true], + vec![false, true, true, false, true, true, false, true, true, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 0, 1, 2, 0, 1, 2, 0], + vec![0, 2, 1, 0, 2, 1, 0, 2, 1, 0], + vec![1, 0, 2, 1, 0, 2, 1, 0, 2, 1], + vec![1, 2, 0, 1, 2, 0, 1, 2, 0, 1], + vec![2, 0, 1, 2, 0, 1, 2, 0, 1, 2], + vec![2, 1, 0, 2, 1, 0, 2, 1, 0, 2], + ])) + ), + test_disconnected_graph: ( + vec![ + vec![false, false, false], + vec![false, false, false], + vec![false, false, false], + ], + 2, + Ok(Some(vec![ + vec![0, 0, 0], + vec![0, 0, 1], + vec![0, 1, 0], + vec![0, 1, 1], + vec![1, 0, 0], + vec![1, 0, 1], + vec![1, 1, 0], + vec![1, 1, 1], + ])) + ), + test_no_valid_coloring: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 2, + Ok(None::>>) + ), + test_more_colors_than_nodes: ( + vec![ + vec![true, true], + vec![true, true], + ], + 3, + Ok(Some(vec![ + vec![0, 1], + vec![0, 2], + vec![1, 0], + vec![1, 2], + vec![2, 0], + vec![2, 1], + ])) + ), + test_no_coloring_with_zero_colors: ( + vec![ + vec![true], + ], + 0, + Ok(None::>>) + ), + test_complete_graph_with_3_vertices_and_3_colors: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2], + vec![0, 2, 1], + vec![1, 0, 2], + vec![1, 2, 0], + vec![2, 0, 1], + vec![2, 1, 0], + ])) + ), + test_directed_graph_with_3_colors: ( + vec![ + vec![false, true, false, true], + vec![false, false, true, false], + vec![true, false, false, true], + vec![true, false, false, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 1], + vec![0, 2, 1, 2], + vec![1, 0, 2, 0], + vec![1, 2, 0, 2], + vec![2, 0, 1, 0], + vec![2, 1, 0, 1], + ])) + ), + test_directed_graph_no_valid_coloring: ( + vec![ + vec![false, true, false, true], + vec![false, false, true, true], + vec![true, false, false, true], + vec![true, false, false, false], + ], + 3, + Ok(None::>>) + ), + test_large_directed_graph_with_3_colors: ( + vec![ + vec![false, true, false, false, true, false, false, true, false, false], + vec![false, false, true, false, false, true, false, false, true, false], + vec![false, false, false, true, false, false, true, false, false, true], + vec![true, false, false, false, true, false, false, true, false, false], + vec![false, true, false, false, false, true, false, false, true, false], + vec![false, false, true, false, false, false, true, false, false, true], + vec![true, false, false, false, true, false, false, true, false, false], + vec![false, true, false, false, false, true, false, false, true, false], + vec![false, false, true, false, false, false, true, false, false, true], + vec![true, false, false, false, true, false, false, true, false, false], + ], + 3, + Ok(Some(vec![ + vec![0, 1, 2, 1, 2, 0, 1, 2, 0, 1], + vec![0, 2, 1, 2, 1, 0, 2, 1, 0, 2], + vec![1, 0, 2, 0, 2, 1, 0, 2, 1, 0], + vec![1, 2, 0, 2, 0, 1, 2, 0, 1, 2], + vec![2, 0, 1, 0, 1, 2, 0, 1, 2, 0], + vec![2, 1, 0, 1, 0, 2, 1, 0, 2, 1] + ])) + ), + } +} diff --git a/src/backtracking/hamiltonian_cycle.rs b/src/backtracking/hamiltonian_cycle.rs new file mode 100644 index 00000000000..2eacd5feb3c --- /dev/null +++ b/src/backtracking/hamiltonian_cycle.rs @@ -0,0 +1,310 @@ +//! This module provides functionality to find a Hamiltonian cycle in a directed or undirected graph. +//! Source: [Wikipedia](https://en.wikipedia.org/wiki/Hamiltonian_path_problem) + +/// Represents potential errors when finding hamiltonian cycle on an adjacency matrix. +#[derive(Debug, PartialEq, Eq)] +pub enum FindHamiltonianCycleError { + /// Indicates that the adjacency matrix is empty. + EmptyAdjacencyMatrix, + /// Indicates that the adjacency matrix is not square. + ImproperAdjacencyMatrix, + /// Indicates that the starting vertex is out of bounds. + StartOutOfBound, +} + +/// Represents a graph using an adjacency matrix. +struct Graph { + /// The adjacency matrix representing the graph. + adjacency_matrix: Vec>, +} + +impl Graph { + /// Creates a new graph with the provided adjacency matrix. + /// + /// # Arguments + /// + /// * `adjacency_matrix` - A square matrix where each element indicates + /// the presence (`true`) or absence (`false`) of an edge + /// between two vertices. + /// + /// # Returns + /// + /// A `Result` containing the graph if successful, or an `FindHamiltonianCycleError` if there is an issue with the matrix. + fn new(adjacency_matrix: Vec>) -> Result { + // Check if the adjacency matrix is empty. + if adjacency_matrix.is_empty() { + return Err(FindHamiltonianCycleError::EmptyAdjacencyMatrix); + } + + // Validate that the adjacency matrix is square. + if adjacency_matrix + .iter() + .any(|row| row.len() != adjacency_matrix.len()) + { + return Err(FindHamiltonianCycleError::ImproperAdjacencyMatrix); + } + + Ok(Self { adjacency_matrix }) + } + + /// Returns the number of vertices in the graph. + fn num_vertices(&self) -> usize { + self.adjacency_matrix.len() + } + + /// Determines if it is safe to include vertex `v` in the Hamiltonian cycle path. + /// + /// # Arguments + /// + /// * `v` - The index of the vertex being considered. + /// * `visited` - A reference to the vector representing the visited vertices. + /// * `path` - A reference to the current path being explored. + /// * `pos` - The position of the current vertex being considered. + /// + /// # Returns + /// + /// `true` if it is safe to include `v` in the path, `false` otherwise. + fn is_safe(&self, v: usize, visited: &[bool], path: &[Option], pos: usize) -> bool { + // Check if the current vertex and the last vertex in the path are adjacent. + if !self.adjacency_matrix[path[pos - 1].unwrap()][v] { + return false; + } + + // Check if the vertex has already been included in the path. + !visited[v] + } + + /// Recursively searches for a Hamiltonian cycle. + /// + /// This function is called by `find_hamiltonian_cycle`. + /// + /// # Arguments + /// + /// * `path` - A mutable vector representing the current path being explored. + /// * `visited` - A mutable vector representing the visited vertices. + /// * `pos` - The position of the current vertex being considered. + /// + /// # Returns + /// + /// `true` if a Hamiltonian cycle is found, `false` otherwise. + fn hamiltonian_cycle_util( + &self, + path: &mut [Option], + visited: &mut [bool], + pos: usize, + ) -> bool { + if pos == self.num_vertices() { + // Check if there is an edge from the last included vertex to the first vertex. + return self.adjacency_matrix[path[pos - 1].unwrap()][path[0].unwrap()]; + } + + for v in 0..self.num_vertices() { + if self.is_safe(v, visited, path, pos) { + path[pos] = Some(v); + visited[v] = true; + if self.hamiltonian_cycle_util(path, visited, pos + 1) { + return true; + } + path[pos] = None; + visited[v] = false; + } + } + + false + } + + /// Attempts to find a Hamiltonian cycle in the graph, starting from the specified vertex. + /// + /// A Hamiltonian cycle visits every vertex exactly once and returns to the starting vertex. + /// + /// # Note + /// This implementation may not find all possible Hamiltonian cycles. + /// It stops as soon as it finds one valid cycle. If multiple Hamiltonian cycles exist, + /// only one will be returned. + /// + /// # Returns + /// + /// `Ok(Some(path))` if a Hamiltonian cycle is found, where `path` is a vector + /// containing the indices of vertices in the cycle, starting and ending with the same vertex. + /// + /// `Ok(None)` if no Hamiltonian cycle exists. + fn find_hamiltonian_cycle( + &self, + start_vertex: usize, + ) -> Result>, FindHamiltonianCycleError> { + // Validate the start vertex. + if start_vertex >= self.num_vertices() { + return Err(FindHamiltonianCycleError::StartOutOfBound); + } + + // Initialize the path. + let mut path = vec![None; self.num_vertices()]; + // Start at the specified vertex. + path[0] = Some(start_vertex); + + // Initialize the visited vector. + let mut visited = vec![false; self.num_vertices()]; + visited[start_vertex] = true; + + if self.hamiltonian_cycle_util(&mut path, &mut visited, 1) { + // Complete the cycle by returning to the starting vertex. + path.push(Some(start_vertex)); + Ok(Some(path.into_iter().map(Option::unwrap).collect())) + } else { + Ok(None) + } + } +} + +/// Attempts to find a Hamiltonian cycle in a graph represented by an adjacency matrix, starting from a specified vertex. +pub fn find_hamiltonian_cycle( + adjacency_matrix: Vec>, + start_vertex: usize, +) -> Result>, FindHamiltonianCycleError> { + Graph::new(adjacency_matrix)?.find_hamiltonian_cycle(start_vertex) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! hamiltonian_cycle_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (adjacency_matrix, start_vertex, expected) = $test_case; + let result = find_hamiltonian_cycle(adjacency_matrix, start_vertex); + assert_eq!(result, expected); + } + )* + }; + } + + hamiltonian_cycle_tests! { + test_complete_graph: ( + vec![ + vec![false, true, true, true], + vec![true, false, true, true], + vec![true, true, false, true], + vec![true, true, true, false], + ], + 0, + Ok(Some(vec![0, 1, 2, 3, 0])) + ), + test_directed_graph_with_cycle: ( + vec![ + vec![false, true, false, false, false], + vec![false, false, true, true, false], + vec![true, false, false, true, true], + vec![false, false, true, false, true], + vec![true, true, false, false, false], + ], + 2, + Ok(Some(vec![2, 3, 4, 0, 1, 2])) + ), + test_undirected_graph_with_cycle: ( + vec![ + vec![false, true, false, false, true], + vec![true, false, true, false, false], + vec![false, true, false, true, false], + vec![false, false, true, false, true], + vec![true, false, false, true, false], + ], + 2, + Ok(Some(vec![2, 1, 0, 4, 3, 2])) + ), + test_directed_graph_no_cycle: ( + vec![ + vec![false, true, false, true, false], + vec![false, false, true, true, false], + vec![false, false, false, true, false], + vec![false, false, false, false, true], + vec![false, false, true, false, false], + ], + 0, + Ok(None::>) + ), + test_undirected_graph_no_cycle: ( + vec![ + vec![false, true, false, false, false], + vec![true, false, true, true, false], + vec![false, true, false, true, true], + vec![false, true, true, false, true], + vec![false, false, true, true, false], + ], + 0, + Ok(None::>) + ), + test_triangle_graph: ( + vec![ + vec![false, true, false], + vec![false, false, true], + vec![true, false, false], + ], + 1, + Ok(Some(vec![1, 2, 0, 1])) + ), + test_tree_graph: ( + vec![ + vec![false, true, false, true, false], + vec![true, false, true, true, false], + vec![false, true, false, false, false], + vec![true, true, false, false, true], + vec![false, false, false, true, false], + ], + 0, + Ok(None::>) + ), + test_empty_graph: ( + vec![], + 0, + Err(FindHamiltonianCycleError::EmptyAdjacencyMatrix) + ), + test_improper_graph: ( + vec![ + vec![false, true], + vec![true], + vec![false, true, true], + vec![true, true, true, false] + ], + 0, + Err(FindHamiltonianCycleError::ImproperAdjacencyMatrix) + ), + test_start_out_of_bound: ( + vec![ + vec![false, true, true], + vec![true, false, true], + vec![true, true, false], + ], + 3, + Err(FindHamiltonianCycleError::StartOutOfBound) + ), + test_complex_directed_graph: ( + vec![ + vec![false, true, false, true, false, false], + vec![false, false, true, false, true, false], + vec![false, false, false, true, false, false], + vec![false, true, false, false, true, false], + vec![false, false, true, false, false, true], + vec![true, false, false, false, false, false], + ], + 0, + Ok(Some(vec![0, 1, 2, 3, 4, 5, 0])) + ), + single_node_self_loop: ( + vec![ + vec![true], + ], + 0, + Ok(Some(vec![0, 0])) + ), + single_node: ( + vec![ + vec![false], + ], + 0, + Ok(None) + ), + } +} diff --git a/src/backtracking/knight_tour.rs b/src/backtracking/knight_tour.rs new file mode 100644 index 00000000000..26e9e36d682 --- /dev/null +++ b/src/backtracking/knight_tour.rs @@ -0,0 +1,195 @@ +//! This module contains the implementation of the Knight's Tour problem. +//! +//! The Knight's Tour is a classic chess problem where the objective is to move a knight to every square on a chessboard exactly once. + +/// Finds the Knight's Tour starting from the specified position. +/// +/// # Arguments +/// +/// * `size_x` - The width of the chessboard. +/// * `size_y` - The height of the chessboard. +/// * `start_x` - The x-coordinate of the starting position. +/// * `start_y` - The y-coordinate of the starting position. +/// +/// # Returns +/// +/// A tour matrix if the tour was found or None if not found. +/// The tour matrix returned is essentially the board field of the `KnightTour` +/// struct `Vec>`. It represents the sequence of moves made by the +/// knight on the chessboard, with each cell containing the order in which the knight visited that square. +pub fn find_knight_tour( + size_x: usize, + size_y: usize, + start_x: usize, + start_y: usize, +) -> Option>> { + let mut tour = KnightTour::new(size_x, size_y); + tour.find_tour(start_x, start_y) +} + +/// Represents the KnightTour struct which implements the Knight's Tour problem. +struct KnightTour { + board: Vec>, +} + +impl KnightTour { + /// Possible moves of the knight on the board + const MOVES: [(isize, isize); 8] = [ + (2, 1), + (1, 2), + (-1, 2), + (-2, 1), + (-2, -1), + (-1, -2), + (1, -2), + (2, -1), + ]; + + /// Constructs a new KnightTour instance with the given board size. + /// # Arguments + /// + /// * `size_x` - The width of the chessboard. + /// * `size_y` - The height of the chessboard. + /// + /// # Returns + /// + /// A new KnightTour instance. + fn new(size_x: usize, size_y: usize) -> Self { + let board = vec![vec![0; size_x]; size_y]; + KnightTour { board } + } + + /// Returns the width of the chessboard. + fn size_x(&self) -> usize { + self.board.len() + } + + /// Returns the height of the chessboard. + fn size_y(&self) -> usize { + self.board[0].len() + } + + /// Checks if the given position is safe to move to. + /// + /// # Arguments + /// + /// * `x` - The x-coordinate of the position. + /// * `y` - The y-coordinate of the position. + /// + /// # Returns + /// + /// A boolean indicating whether the position is safe to move to. + fn is_safe(&self, x: isize, y: isize) -> bool { + x >= 0 + && y >= 0 + && x < self.size_x() as isize + && y < self.size_y() as isize + && self.board[x as usize][y as usize] == 0 + } + + /// Recursively solves the Knight's Tour problem. + /// + /// # Arguments + /// + /// * `x` - The current x-coordinate of the knight. + /// * `y` - The current y-coordinate of the knight. + /// * `move_count` - The current move count. + /// + /// # Returns + /// + /// A boolean indicating whether a solution was found. + fn solve_tour(&mut self, x: isize, y: isize, move_count: usize) -> bool { + if move_count == self.size_x() * self.size_y() { + return true; + } + for &(dx, dy) in &Self::MOVES { + let next_x = x + dx; + let next_y = y + dy; + + if self.is_safe(next_x, next_y) { + self.board[next_x as usize][next_y as usize] = move_count + 1; + + if self.solve_tour(next_x, next_y, move_count + 1) { + return true; + } + // Backtrack + self.board[next_x as usize][next_y as usize] = 0; + } + } + + false + } + + /// Finds the Knight's Tour starting from the specified position. + /// + /// # Arguments + /// + /// * `start_x` - The x-coordinate of the starting position. + /// * `start_y` - The y-coordinate of the starting position. + /// + /// # Returns + /// + /// A tour matrix if the tour was found or None if not found. + fn find_tour(&mut self, start_x: usize, start_y: usize) -> Option>> { + if !self.is_safe(start_x as isize, start_y as isize) { + return None; + } + + self.board[start_x][start_y] = 1; + + if !self.solve_tour(start_x as isize, start_y as isize, 1) { + return None; + } + + Some(self.board.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_find_knight_tour { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (size_x, size_y, start_x, start_y, expected) = $tc; + if expected.is_some() { + assert_eq!(expected.clone().unwrap()[start_x][start_y], 1) + } + assert_eq!(find_knight_tour(size_x, size_y, start_x, start_y), expected); + } + )* + } + } + test_find_knight_tour! { + test_knight_tour_5x5: (5, 5, 0, 0, Some(vec![ + vec![1, 6, 15, 10, 21], + vec![14, 9, 20, 5, 16], + vec![19, 2, 7, 22, 11], + vec![8, 13, 24, 17, 4], + vec![25, 18, 3, 12, 23], + ])), + test_knight_tour_6x6: (6, 6, 0, 0, Some(vec![ + vec![1, 16, 7, 26, 11, 14], + vec![34, 25, 12, 15, 6, 27], + vec![17, 2, 33, 8, 13, 10], + vec![32, 35, 24, 21, 28, 5], + vec![23, 18, 3, 30, 9, 20], + vec![36, 31, 22, 19, 4, 29], + ])), + test_knight_tour_8x8: (8, 8, 0, 0, Some(vec![ + vec![1, 60, 39, 34, 31, 18, 9, 64], + vec![38, 35, 32, 61, 10, 63, 30, 17], + vec![59, 2, 37, 40, 33, 28, 19, 8], + vec![36, 49, 42, 27, 62, 11, 16, 29], + vec![43, 58, 3, 50, 41, 24, 7, 20], + vec![48, 51, 46, 55, 26, 21, 12, 15], + vec![57, 44, 53, 4, 23, 14, 25, 6], + vec![52, 47, 56, 45, 54, 5, 22, 13], + ])), + test_no_solution: (5, 5, 2, 1, None::>>), + test_invalid_start_position: (8, 8, 10, 10, None::>>), + } +} diff --git a/src/backtracking/mod.rs b/src/backtracking/mod.rs new file mode 100644 index 00000000000..182c6fbbc01 --- /dev/null +++ b/src/backtracking/mod.rs @@ -0,0 +1,21 @@ +mod all_combination_of_size_k; +mod graph_coloring; +mod hamiltonian_cycle; +mod knight_tour; +mod n_queens; +mod parentheses_generator; +mod permutations; +mod rat_in_maze; +mod subset_sum; +mod sudoku; + +pub use all_combination_of_size_k::generate_all_combinations; +pub use graph_coloring::generate_colorings; +pub use hamiltonian_cycle::find_hamiltonian_cycle; +pub use knight_tour::find_knight_tour; +pub use n_queens::n_queens_solver; +pub use parentheses_generator::generate_parentheses; +pub use permutations::permute; +pub use rat_in_maze::find_path_in_maze; +pub use subset_sum::has_subset_with_sum; +pub use sudoku::sudoku_solver; diff --git a/src/backtracking/n_queens.rs b/src/backtracking/n_queens.rs new file mode 100644 index 00000000000..234195e97ea --- /dev/null +++ b/src/backtracking/n_queens.rs @@ -0,0 +1,221 @@ +//! This module provides functionality to solve the N-Queens problem. +//! +//! The N-Queens problem is a classic chessboard puzzle where the goal is to +//! place N queens on an NxN chessboard so that no two queens threaten each +//! other. Queens can attack each other if they share the same row, column, or +//! diagonal. +//! +//! This implementation solves the N-Queens problem using a backtracking algorithm. +//! It starts with an empty chessboard and iteratively tries to place queens in +//! different rows, ensuring they do not conflict with each other. If a valid +//! solution is found, it's added to the list of solutions. + +/// Solves the N-Queens problem for a given size and returns a vector of solutions. +/// +/// # Arguments +/// +/// * `n` - The size of the chessboard (NxN). +/// +/// # Returns +/// +/// A vector containing all solutions to the N-Queens problem. +pub fn n_queens_solver(n: usize) -> Vec> { + let mut solver = NQueensSolver::new(n); + solver.solve() +} + +/// Represents a solver for the N-Queens problem. +struct NQueensSolver { + // The size of the chessboard + size: usize, + // A 2D vector representing the chessboard where '.' denotes an empty space and 'Q' denotes a queen + board: Vec>, + // A vector to store all valid solutions + solutions: Vec>, +} + +impl NQueensSolver { + /// Creates a new `NQueensSolver` instance with the given size. + /// + /// # Arguments + /// + /// * `size` - The size of the chessboard (N×N). + /// + /// # Returns + /// + /// A new `NQueensSolver` instance. + fn new(size: usize) -> Self { + NQueensSolver { + size, + board: vec![vec!['.'; size]; size], + solutions: Vec::new(), + } + } + + /// Solves the N-Queens problem and returns a vector of solutions. + /// + /// # Returns + /// + /// A vector containing all solutions to the N-Queens problem. + fn solve(&mut self) -> Vec> { + self.solve_helper(0); + std::mem::take(&mut self.solutions) + } + + /// Checks if it's safe to place a queen at the specified position (row, col). + /// + /// # Arguments + /// + /// * `row` - The row index of the position to check. + /// * `col` - The column index of the position to check. + /// + /// # Returns + /// + /// `true` if it's safe to place a queen at the specified position, `false` otherwise. + fn is_safe(&self, row: usize, col: usize) -> bool { + // Check column and diagonals + for i in 0..row { + if self.board[i][col] == 'Q' + || (col >= row - i && self.board[i][col - (row - i)] == 'Q') + || (col + row - i < self.size && self.board[i][col + (row - i)] == 'Q') + { + return false; + } + } + true + } + + /// Recursive helper function to solve the N-Queens problem. + /// + /// # Arguments + /// + /// * `row` - The current row being processed. + fn solve_helper(&mut self, row: usize) { + if row == self.size { + self.solutions + .push(self.board.iter().map(|row| row.iter().collect()).collect()); + return; + } + + for col in 0..self.size { + if self.is_safe(row, col) { + self.board[row][col] = 'Q'; + self.solve_helper(row + 1); + self.board[row][col] = '.'; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_n_queens_solver { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected_solutions) = $tc; + let solutions = n_queens_solver(n); + assert_eq!(solutions, expected_solutions); + } + )* + }; + } + + test_n_queens_solver! { + test_0_queens: (0, vec![Vec::::new()]), + test_1_queen: (1, vec![vec!["Q"]]), + test_2_queens:(2, Vec::>::new()), + test_3_queens:(3, Vec::>::new()), + test_4_queens: (4, vec![ + vec![".Q..", + "...Q", + "Q...", + "..Q."], + vec!["..Q.", + "Q...", + "...Q", + ".Q.."], + ]), + test_5_queens:(5, vec![ + vec!["Q....", + "..Q..", + "....Q", + ".Q...", + "...Q."], + vec!["Q....", + "...Q.", + ".Q...", + "....Q", + "..Q.."], + vec![".Q...", + "...Q.", + "Q....", + "..Q..", + "....Q"], + vec![".Q...", + "....Q", + "..Q..", + "Q....", + "...Q."], + vec!["..Q..", + "Q....", + "...Q.", + ".Q...", + "....Q"], + vec!["..Q..", + "....Q", + ".Q...", + "...Q.", + "Q...."], + vec!["...Q.", + "Q....", + "..Q..", + "....Q", + ".Q..."], + vec!["...Q.", + ".Q...", + "....Q", + "..Q..", + "Q...."], + vec!["....Q", + ".Q...", + "...Q.", + "Q....", + "..Q.."], + vec!["....Q", + "..Q..", + "Q....", + "...Q.", + ".Q..."], + ]), + test_6_queens: (6, vec![ + vec![".Q....", + "...Q..", + ".....Q", + "Q.....", + "..Q...", + "....Q."], + vec!["..Q...", + ".....Q", + ".Q....", + "....Q.", + "Q.....", + "...Q.."], + vec!["...Q..", + "Q.....", + "....Q.", + ".Q....", + ".....Q", + "..Q..."], + vec!["....Q.", + "..Q...", + "Q.....", + ".....Q", + "...Q..", + ".Q...."], + ]), + } +} diff --git a/src/backtracking/parentheses_generator.rs b/src/backtracking/parentheses_generator.rs new file mode 100644 index 00000000000..9aafe81fa7a --- /dev/null +++ b/src/backtracking/parentheses_generator.rs @@ -0,0 +1,76 @@ +/// Generates all combinations of well-formed parentheses given a non-negative integer `n`. +/// +/// This function uses backtracking to generate all possible combinations of well-formed +/// parentheses. The resulting combinations are returned as a vector of strings. +/// +/// # Arguments +/// +/// * `n` - A non-negative integer representing the number of pairs of parentheses. +pub fn generate_parentheses(n: usize) -> Vec { + let mut result = Vec::new(); + if n > 0 { + generate("", 0, 0, n, &mut result); + } + result +} + +/// Helper function for generating parentheses recursively. +/// +/// This function is called recursively to build combinations of well-formed parentheses. +/// It tracks the number of open and close parentheses added so far and adds a new parenthesis +/// if it's valid to do so. +/// +/// # Arguments +/// +/// * `current` - The current string of parentheses being built. +/// * `open_count` - The count of open parentheses in the current string. +/// * `close_count` - The count of close parentheses in the current string. +/// * `n` - The total number of pairs of parentheses to be generated. +/// * `result` - A mutable reference to the vector storing the generated combinations. +fn generate( + current: &str, + open_count: usize, + close_count: usize, + n: usize, + result: &mut Vec, +) { + if current.len() == (n * 2) { + result.push(current.to_string()); + return; + } + + if open_count < n { + let new_str = current.to_string() + "("; + generate(&new_str, open_count + 1, close_count, n, result); + } + + if close_count < open_count { + let new_str = current.to_string() + ")"; + generate(&new_str, open_count, close_count + 1, n, result); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! generate_parentheses_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected_result) = $test_case; + assert_eq!(generate_parentheses(n), expected_result); + } + )* + }; + } + + generate_parentheses_tests! { + test_generate_parentheses_0: (0, Vec::::new()), + test_generate_parentheses_1: (1, vec!["()"]), + test_generate_parentheses_2: (2, vec!["(())", "()()"]), + test_generate_parentheses_3: (3, vec!["((()))", "(()())", "(())()", "()(())", "()()()"]), + test_generate_parentheses_4: (4, vec!["(((())))", "((()()))", "((())())", "((()))()", "(()(()))", "(()()())", "(()())()", "(())(())", "(())()()", "()((()))", "()(()())", "()(())()", "()()(())", "()()()()"]), + } +} diff --git a/src/backtracking/permutations.rs b/src/backtracking/permutations.rs new file mode 100644 index 00000000000..8859a633310 --- /dev/null +++ b/src/backtracking/permutations.rs @@ -0,0 +1,141 @@ +//! This module provides a function to generate all possible distinct permutations +//! of a given collection of integers using a backtracking algorithm. + +/// Generates all possible distinct permutations of a given vector of integers. +/// +/// # Arguments +/// +/// * `nums` - A vector of integers. The input vector is sorted before generating +/// permutations to handle duplicates effectively. +/// +/// # Returns +/// +/// A vector containing all possible distinct permutations of the input vector. +pub fn permute(mut nums: Vec) -> Vec> { + let mut permutations = Vec::new(); + let mut current = Vec::new(); + let mut used = vec![false; nums.len()]; + + nums.sort(); + generate(&nums, &mut current, &mut used, &mut permutations); + + permutations +} + +/// Helper function for the `permute` function to generate distinct permutations recursively. +/// +/// # Arguments +/// +/// * `nums` - A reference to the sorted slice of integers. +/// * `current` - A mutable reference to the vector holding the current permutation. +/// * `used` - A mutable reference to a vector tracking which elements are used. +/// * `permutations` - A mutable reference to the vector holding all generated distinct permutations. +fn generate( + nums: &[isize], + current: &mut Vec, + used: &mut Vec, + permutations: &mut Vec>, +) { + if current.len() == nums.len() { + permutations.push(current.clone()); + return; + } + + for idx in 0..nums.len() { + if used[idx] { + continue; + } + + if idx > 0 && nums[idx] == nums[idx - 1] && !used[idx - 1] { + continue; + } + + current.push(nums[idx]); + used[idx] = true; + + generate(nums, current, used, permutations); + + current.pop(); + used[idx] = false; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! permute_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(permute(input), expected); + } + )* + } + } + + permute_tests! { + test_permute_basic: (vec![1, 2, 3], vec![ + vec![1, 2, 3], + vec![1, 3, 2], + vec![2, 1, 3], + vec![2, 3, 1], + vec![3, 1, 2], + vec![3, 2, 1], + ]), + test_permute_empty: (Vec::::new(), vec![vec![]]), + test_permute_single: (vec![1], vec![vec![1]]), + test_permute_duplicates: (vec![1, 1, 2], vec![ + vec![1, 1, 2], + vec![1, 2, 1], + vec![2, 1, 1], + ]), + test_permute_all_duplicates: (vec![1, 1, 1, 1], vec![ + vec![1, 1, 1, 1], + ]), + test_permute_negative: (vec![-1, -2, -3], vec![ + vec![-3, -2, -1], + vec![-3, -1, -2], + vec![-2, -3, -1], + vec![-2, -1, -3], + vec![-1, -3, -2], + vec![-1, -2, -3], + ]), + test_permute_mixed: (vec![-1, 0, 1], vec![ + vec![-1, 0, 1], + vec![-1, 1, 0], + vec![0, -1, 1], + vec![0, 1, -1], + vec![1, -1, 0], + vec![1, 0, -1], + ]), + test_permute_larger: (vec![1, 2, 3, 4], vec![ + vec![1, 2, 3, 4], + vec![1, 2, 4, 3], + vec![1, 3, 2, 4], + vec![1, 3, 4, 2], + vec![1, 4, 2, 3], + vec![1, 4, 3, 2], + vec![2, 1, 3, 4], + vec![2, 1, 4, 3], + vec![2, 3, 1, 4], + vec![2, 3, 4, 1], + vec![2, 4, 1, 3], + vec![2, 4, 3, 1], + vec![3, 1, 2, 4], + vec![3, 1, 4, 2], + vec![3, 2, 1, 4], + vec![3, 2, 4, 1], + vec![3, 4, 1, 2], + vec![3, 4, 2, 1], + vec![4, 1, 2, 3], + vec![4, 1, 3, 2], + vec![4, 2, 1, 3], + vec![4, 2, 3, 1], + vec![4, 3, 1, 2], + vec![4, 3, 2, 1], + ]), + } +} diff --git a/src/backtracking/rat_in_maze.rs b/src/backtracking/rat_in_maze.rs new file mode 100644 index 00000000000..fb658697b39 --- /dev/null +++ b/src/backtracking/rat_in_maze.rs @@ -0,0 +1,327 @@ +//! This module contains the implementation of the Rat in Maze problem. +//! +//! The Rat in Maze problem is a classic algorithmic problem where the +//! objective is to find a path from the starting position to the exit +//! position in a maze. + +/// Enum representing various errors that can occur while working with mazes. +#[derive(Debug, PartialEq, Eq)] +pub enum MazeError { + /// Indicates that the maze is empty (zero rows). + EmptyMaze, + /// Indicates that the starting position is out of bounds. + OutOfBoundPos, + /// Indicates an improper representation of the maze (e.g., non-rectangular maze). + ImproperMazeRepr, +} + +/// Finds a path through the maze starting from the specified position. +/// +/// # Arguments +/// +/// * `maze` - The maze represented as a vector of vectors where each +/// inner vector represents a row in the maze grid. +/// * `start_x` - The x-coordinate of the starting position. +/// * `start_y` - The y-coordinate of the starting position. +/// +/// # Returns +/// +/// A `Result` where: +/// - `Ok(Some(solution))` if a path is found and contains the solution matrix. +/// - `Ok(None)` if no path is found. +/// - `Err(MazeError)` for various error conditions such as out-of-bound start position or improper maze representation. +/// +/// # Solution Selection +/// +/// The function returns the first successful path it discovers based on the predefined order of moves. +/// The order of moves is defined in the `MOVES` constant of the `Maze` struct. +/// +/// The backtracking algorithm explores each direction in this order. If multiple solutions exist, +/// the algorithm returns the first path it finds according to this sequence. It recursively explores +/// each direction, marks valid moves, and backtracks if necessary, ensuring that the solution is found +/// efficiently and consistently. +pub fn find_path_in_maze( + maze: &[Vec], + start_x: usize, + start_y: usize, +) -> Result>>, MazeError> { + if maze.is_empty() { + return Err(MazeError::EmptyMaze); + } + + // Validate start position + if start_x >= maze.len() || start_y >= maze[0].len() { + return Err(MazeError::OutOfBoundPos); + } + + // Validate maze representation (if necessary) + if maze.iter().any(|row| row.len() != maze[0].len()) { + return Err(MazeError::ImproperMazeRepr); + } + + // If validations pass, proceed with finding the path + let maze_instance = Maze::new(maze.to_owned()); + Ok(maze_instance.find_path(start_x, start_y)) +} + +/// Represents a maze. +struct Maze { + maze: Vec>, +} + +impl Maze { + /// Represents possible moves in the maze. + const MOVES: [(isize, isize); 4] = [(0, 1), (1, 0), (0, -1), (-1, 0)]; + + /// Constructs a new Maze instance. + /// # Arguments + /// + /// * `maze` - The maze represented as a vector of vectors where each + /// inner vector represents a row in the maze grid. + /// + /// # Returns + /// + /// A new Maze instance. + fn new(maze: Vec>) -> Self { + Maze { maze } + } + + /// Returns the width of the maze. + /// + /// # Returns + /// + /// The width of the maze. + fn width(&self) -> usize { + self.maze[0].len() + } + + /// Returns the height of the maze. + /// + /// # Returns + /// + /// The height of the maze. + fn height(&self) -> usize { + self.maze.len() + } + + /// Finds a path through the maze starting from the specified position. + /// + /// # Arguments + /// + /// * `start_x` - The x-coordinate of the starting position. + /// * `start_y` - The y-coordinate of the starting position. + /// + /// # Returns + /// + /// A solution matrix if a path is found or None if not found. + fn find_path(&self, start_x: usize, start_y: usize) -> Option>> { + let mut solution = vec![vec![false; self.width()]; self.height()]; + if self.solve(start_x as isize, start_y as isize, &mut solution) { + Some(solution) + } else { + None + } + } + + /// Recursively solves the Rat in Maze problem using backtracking. + /// + /// # Arguments + /// + /// * `x` - The current x-coordinate. + /// * `y` - The current y-coordinate. + /// * `solution` - The current solution matrix. + /// + /// # Returns + /// + /// A boolean indicating whether a solution was found. + fn solve(&self, x: isize, y: isize, solution: &mut [Vec]) -> bool { + if x == (self.height() as isize - 1) && y == (self.width() as isize - 1) { + solution[x as usize][y as usize] = true; + return true; + } + + if self.is_valid(x, y, solution) { + solution[x as usize][y as usize] = true; + + for &(dx, dy) in &Self::MOVES { + if self.solve(x + dx, y + dy, solution) { + return true; + } + } + + // If none of the directions lead to the solution, backtrack + solution[x as usize][y as usize] = false; + return false; + } + false + } + + /// Checks if a given position is valid in the maze. + /// + /// # Arguments + /// + /// * `x` - The x-coordinate of the position. + /// * `y` - The y-coordinate of the position. + /// * `solution` - The current solution matrix. + /// + /// # Returns + /// + /// A boolean indicating whether the position is valid. + fn is_valid(&self, x: isize, y: isize, solution: &[Vec]) -> bool { + x >= 0 + && y >= 0 + && x < self.height() as isize + && y < self.width() as isize + && self.maze[x as usize][y as usize] + && !solution[x as usize][y as usize] + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_find_path_in_maze { + ($($name:ident: $start_x:expr, $start_y:expr, $maze:expr, $expected:expr,)*) => { + $( + #[test] + fn $name() { + let solution = find_path_in_maze($maze, $start_x, $start_y); + assert_eq!(solution, $expected); + if let Ok(Some(expected_solution)) = &solution { + assert_eq!(expected_solution[$start_x][$start_y], true); + } + } + )* + } + } + + test_find_path_in_maze! { + maze_with_solution_5x5: 0, 0, &[ + vec![true, false, true, false, false], + vec![true, true, false, true, false], + vec![false, true, true, true, false], + vec![false, false, false, true, true], + vec![false, true, false, false, true], + ], Ok(Some(vec![ + vec![true, false, false, false, false], + vec![true, true, false, false, false], + vec![false, true, true, true, false], + vec![false, false, false, true, true], + vec![false, false, false, false, true], + ])), + maze_with_solution_6x6: 0, 0, &[ + vec![true, false, true, false, true, false], + vec![true, true, false, true, false, true], + vec![false, true, true, true, true, false], + vec![false, false, false, true, true, true], + vec![false, true, false, false, true, false], + vec![true, true, true, true, true, true], + ], Ok(Some(vec![ + vec![true, false, false, false, false, false], + vec![true, true, false, false, false, false], + vec![false, true, true, true, true, false], + vec![false, false, false, false, true, false], + vec![false, false, false, false, true, false], + vec![false, false, false, false, true, true], + ])), + maze_with_solution_8x8: 0, 0, &[ + vec![true, false, false, false, false, false, false, true], + vec![true, true, false, true, true, true, false, false], + vec![false, true, true, true, false, false, false, false], + vec![false, false, false, true, false, true, true, false], + vec![false, true, false, true, true, true, false, true], + vec![true, false, true, false, false, true, true, true], + vec![false, false, true, true, true, false, true, true], + vec![true, true, true, false, true, true, true, true], + ], Ok(Some(vec![ + vec![true, false, false, false, false, false, false, false], + vec![true, true, false, false, false, false, false, false], + vec![false, true, true, true, false, false, false, false], + vec![false, false, false, true, false, false, false, false], + vec![false, false, false, true, true, true, false, false], + vec![false, false, false, false, false, true, true, true], + vec![false, false, false, false, false, false, false, true], + vec![false, false, false, false, false, false, false, true], + ])), + maze_without_solution_4x4: 0, 0, &[ + vec![true, false, false, false], + vec![true, true, false, false], + vec![false, false, true, false], + vec![false, false, false, true], + ], Ok(None::>>), + maze_with_solution_3x4: 0, 0, &[ + vec![true, false, true, true], + vec![true, true, true, false], + vec![false, true, true, true], + ], Ok(Some(vec![ + vec![true, false, false, false], + vec![true, true, true, false], + vec![false, false, true, true], + ])), + maze_without_solution_3x4: 0, 0, &[ + vec![true, false, true, true], + vec![true, false, true, false], + vec![false, true, false, true], + ], Ok(None::>>), + improper_maze_representation: 0, 0, &[ + vec![true], + vec![true, true], + vec![true, true, true], + vec![true, true, true, true] + ], Err(MazeError::ImproperMazeRepr), + out_of_bound_start: 0, 3, &[ + vec![true, false, true], + vec![true, true], + vec![false, true, true], + ], Err(MazeError::OutOfBoundPos), + empty_maze: 0, 0, &[], Err(MazeError::EmptyMaze), + maze_with_single_cell: 0, 0, &[ + vec![true], + ], Ok(Some(vec![ + vec![true] + ])), + maze_with_one_row_and_multiple_columns: 0, 0, &[ + vec![true, false, true, true, false] + ], Ok(None::>>), + maze_with_multiple_rows_and_one_column: 0, 0, &[ + vec![true], + vec![true], + vec![false], + vec![true], + ], Ok(None::>>), + maze_with_walls_surrounding_border: 0, 0, &[ + vec![false, false, false], + vec![false, true, false], + vec![false, false, false], + ], Ok(None::>>), + maze_with_no_walls: 0, 0, &[ + vec![true, true, true], + vec![true, true, true], + vec![true, true, true], + ], Ok(Some(vec![ + vec![true, true, true], + vec![false, false, true], + vec![false, false, true], + ])), + maze_with_going_back: 0, 0, &[ + vec![true, true, true, true, true, true], + vec![false, false, false, true, false, true], + vec![true, true, true, true, false, false], + vec![true, false, false, false, false, false], + vec![true, false, false, false, true, true], + vec![true, false, true, true, true, false], + vec![true, false, true , false, true, false], + vec![true, true, true, false, true, true], + ], Ok(Some(vec![ + vec![true, true, true, true, false, false], + vec![false, false, false, true, false, false], + vec![true, true, true, true, false, false], + vec![true, false, false, false, false, false], + vec![true, false, false, false, false, false], + vec![true, false, true, true, true, false], + vec![true, false, true , false, true, false], + vec![true, true, true, false, true, true], + ])), + } +} diff --git a/src/backtracking/subset_sum.rs b/src/backtracking/subset_sum.rs new file mode 100644 index 00000000000..3e69b380b58 --- /dev/null +++ b/src/backtracking/subset_sum.rs @@ -0,0 +1,55 @@ +//! This module provides functionality to check if there exists a subset of a given set of integers +//! that sums to a target value. The implementation uses a recursive backtracking approach. + +/// Checks if there exists a subset of the given set that sums to the target value. +pub fn has_subset_with_sum(set: &[isize], target: isize) -> bool { + backtrack(set, set.len(), target) +} + +fn backtrack(set: &[isize], remaining_items: usize, target: isize) -> bool { + // Found a subset with the required sum + if target == 0 { + return true; + } + // No more elements to process + if remaining_items == 0 { + return false; + } + // Check if we can find a subset including or excluding the last element + backtrack(set, remaining_items - 1, target) + || backtrack(set, remaining_items - 1, target - set[remaining_items - 1]) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! has_subset_with_sum_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (set, target, expected) = $test_case; + assert_eq!(has_subset_with_sum(set, target), expected); + } + )* + } + } + + has_subset_with_sum_tests! { + test_small_set_with_sum: (&[3, 34, 4, 12, 5, 2], 9, true), + test_small_set_without_sum: (&[3, 34, 4, 12, 5, 2], 30, false), + test_consecutive_set_with_sum: (&[1, 2, 3, 4, 5, 6], 10, true), + test_consecutive_set_without_sum: (&[1, 2, 3, 4, 5, 6], 22, false), + test_large_set_with_sum: (&[5, 10, 12, 13, 15, 18, -1, 10, 50, -2, 3, 4], 30, true), + test_empty_set: (&[], 0, true), + test_empty_set_with_nonzero_sum: (&[], 10, false), + test_single_element_equal_to_sum: (&[10], 10, true), + test_single_element_not_equal_to_sum: (&[5], 10, false), + test_negative_set_with_sum: (&[-7, -3, -2, 5, 8], 0, true), + test_negative_sum: (&[1, 2, 3, 4, 5], -1, false), + test_negative_sum_with_negatives: (&[-7, -3, -2, 5, 8], -4, true), + test_negative_sum_with_negatives_no_solution: (&[-7, -3, -2, 5, 8], -14, false), + test_even_inputs_odd_target: (&[2, 4, 6, 2, 8, -2, 10, 12, -24, 8, 12, 18], 3, false), + } +} diff --git a/src/backtracking/sudoku.rs b/src/backtracking/sudoku.rs new file mode 100644 index 00000000000..bb6c13cbde6 --- /dev/null +++ b/src/backtracking/sudoku.rs @@ -0,0 +1,164 @@ +//! A Rust implementation of Sudoku solver using Backtracking. +//! +//! This module provides functionality to solve Sudoku puzzles using the backtracking algorithm. +//! +//! GeeksForGeeks: [Sudoku Backtracking](https://www.geeksforgeeks.org/sudoku-backtracking-7/) + +/// Solves a Sudoku puzzle. +/// +/// Given a partially filled Sudoku puzzle represented by a 9x9 grid, this function attempts to +/// solve the puzzle using the backtracking algorithm. +/// +/// Returns the solved Sudoku board if a solution exists, or `None` if no solution is found. +pub fn sudoku_solver(board: &[[u8; 9]; 9]) -> Option<[[u8; 9]; 9]> { + let mut solver = SudokuSolver::new(*board); + if solver.solve() { + Some(solver.board) + } else { + None + } +} + +/// Represents a Sudoku puzzle solver. +struct SudokuSolver { + /// The Sudoku board represented by a 9x9 grid. + board: [[u8; 9]; 9], +} + +impl SudokuSolver { + /// Creates a new Sudoku puzzle solver with the given board. + fn new(board: [[u8; 9]; 9]) -> SudokuSolver { + SudokuSolver { board } + } + + /// Finds an empty cell in the Sudoku board. + /// + /// Returns the coordinates of an empty cell `(row, column)` if found, or `None` if all cells are filled. + fn find_empty_cell(&self) -> Option<(usize, usize)> { + // Find an empty cell in the board (returns None if all cells are filled) + for row in 0..9 { + for column in 0..9 { + if self.board[row][column] == 0 { + return Some((row, column)); + } + } + } + + None + } + + /// Checks whether a given value can be placed in a specific cell according to Sudoku rules. + /// + /// Returns `true` if the value can be placed in the cell, otherwise `false`. + fn is_value_valid(&self, coordinates: (usize, usize), value: u8) -> bool { + let (row, column) = coordinates; + + // Checks if the value to be added in the board is an acceptable value for the cell + // Checking through the row + for current_column in 0..9 { + if self.board[row][current_column] == value { + return false; + } + } + + // Checking through the column + for current_row in 0..9 { + if self.board[current_row][column] == value { + return false; + } + } + + // Checking through the 3x3 block of the cell + let start_row = row / 3 * 3; + let start_column = column / 3 * 3; + + for current_row in start_row..start_row + 3 { + for current_column in start_column..start_column + 3 { + if self.board[current_row][current_column] == value { + return false; + } + } + } + + true + } + + /// Solves the Sudoku puzzle recursively using backtracking. + /// + /// Returns `true` if a solution is found, otherwise `false`. + fn solve(&mut self) -> bool { + let empty_cell = self.find_empty_cell(); + + if let Some((row, column)) = empty_cell { + for value in 1..=9 { + if self.is_value_valid((row, column), value) { + self.board[row][column] = value; + if self.solve() { + return true; + } + // Backtracking if the board cannot be solved using the current configuration + self.board[row][column] = 0; + } + } + } else { + // If the board is complete + return true; + } + + // Returning false if the board cannot be solved using the current configuration + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_sudoku_solver { + ($($name:ident: $board:expr, $expected:expr,)*) => { + $( + #[test] + fn $name() { + let result = sudoku_solver(&$board); + assert_eq!(result, $expected); + } + )* + }; + } + + test_sudoku_solver! { + test_sudoku_correct: [ + [3, 0, 6, 5, 0, 8, 4, 0, 0], + [5, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 8, 7, 0, 0, 0, 0, 3, 1], + [0, 0, 3, 0, 1, 0, 0, 8, 0], + [9, 0, 0, 8, 6, 3, 0, 0, 5], + [0, 5, 0, 0, 9, 0, 6, 0, 0], + [1, 3, 0, 0, 0, 0, 2, 5, 0], + [0, 0, 0, 0, 0, 0, 0, 7, 4], + [0, 0, 5, 2, 0, 6, 3, 0, 0], + ], Some([ + [3, 1, 6, 5, 7, 8, 4, 9, 2], + [5, 2, 9, 1, 3, 4, 7, 6, 8], + [4, 8, 7, 6, 2, 9, 5, 3, 1], + [2, 6, 3, 4, 1, 5, 9, 8, 7], + [9, 7, 4, 8, 6, 3, 1, 2, 5], + [8, 5, 1, 7, 9, 2, 6, 4, 3], + [1, 3, 8, 9, 4, 7, 2, 5, 6], + [6, 9, 2, 3, 5, 1, 8, 7, 4], + [7, 4, 5, 2, 8, 6, 3, 1, 9], + ]), + + test_sudoku_incorrect: [ + [6, 0, 3, 5, 0, 8, 4, 0, 0], + [5, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 8, 7, 0, 0, 0, 0, 3, 1], + [0, 0, 3, 0, 1, 0, 0, 8, 0], + [9, 0, 0, 8, 6, 3, 0, 0, 5], + [0, 5, 0, 0, 9, 0, 6, 0, 0], + [1, 3, 0, 0, 0, 0, 2, 5, 0], + [0, 0, 0, 0, 0, 0, 0, 7, 4], + [0, 0, 5, 2, 0, 6, 3, 0, 0], + ], None::<[[u8; 9]; 9]>, + } +} diff --git a/src/big_integer/fast_factorial.rs b/src/big_integer/fast_factorial.rs new file mode 100644 index 00000000000..8272f9ee100 --- /dev/null +++ b/src/big_integer/fast_factorial.rs @@ -0,0 +1,87 @@ +// Algorithm created by Peter Borwein in 1985 +// https://doi.org/10.1016/0196-6774(85)90006-9 + +use crate::math::sieve_of_eratosthenes; +use num_bigint::BigUint; +use num_traits::One; +use std::collections::BTreeMap; + +/// Calculate the sum of n / p^i with integer division for all values of i +fn index(p: usize, n: usize) -> usize { + let mut index = 0; + let mut i = 1; + let mut quot = n / p; + + while quot > 0 { + index += quot; + i += 1; + quot = n / p.pow(i); + } + + index +} + +/// Calculate the factorial with time complexity O(log(log(n)) * M(n * log(n))) where M(n) is the time complexity of multiplying two n-digit numbers together. +pub fn fast_factorial(n: usize) -> BigUint { + if n < 2 { + return BigUint::one(); + } + + // get list of primes that will be factors of n! + let primes = sieve_of_eratosthenes(n); + + // Map the primes with their index + let p_indices = primes + .into_iter() + .map(|p| (p, index(p, n))) + .collect::>(); + + let max_bits = p_indices[&2].next_power_of_two().ilog2() + 1; + + // Create a Vec of 1's + let mut a = vec![BigUint::one(); max_bits as usize]; + + // For every prime p, multiply a[i] by p if the ith bit of p's index is 1 + for (p, i) in p_indices { + let mut bit = 1usize; + while bit.ilog2() < max_bits { + if (bit & i) > 0 { + a[bit.ilog2() as usize] *= p; + } + + bit <<= 1; + } + } + + a.into_iter() + .enumerate() + .map(|(i, a_i)| a_i.pow(2u32.pow(i as u32))) // raise every a[i] to the 2^ith power + .product() // we get our answer by multiplying the result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::math::factorial::factorial_bigmath; + + #[test] + fn fact() { + assert_eq!(fast_factorial(0), BigUint::one()); + assert_eq!(fast_factorial(1), BigUint::one()); + assert_eq!(fast_factorial(2), factorial_bigmath(2)); + assert_eq!(fast_factorial(3), factorial_bigmath(3)); + assert_eq!(fast_factorial(6), factorial_bigmath(6)); + assert_eq!(fast_factorial(7), factorial_bigmath(7)); + assert_eq!(fast_factorial(10), factorial_bigmath(10)); + assert_eq!(fast_factorial(11), factorial_bigmath(11)); + assert_eq!(fast_factorial(18), factorial_bigmath(18)); + assert_eq!(fast_factorial(19), factorial_bigmath(19)); + assert_eq!(fast_factorial(30), factorial_bigmath(30)); + assert_eq!(fast_factorial(34), factorial_bigmath(34)); + assert_eq!(fast_factorial(35), factorial_bigmath(35)); + assert_eq!(fast_factorial(52), factorial_bigmath(52)); + assert_eq!(fast_factorial(100), factorial_bigmath(100)); + assert_eq!(fast_factorial(1000), factorial_bigmath(1000)); + assert_eq!(fast_factorial(5000), factorial_bigmath(5000)); + } +} diff --git a/src/big_integer/mod.rs b/src/big_integer/mod.rs new file mode 100644 index 00000000000..13c6767b36b --- /dev/null +++ b/src/big_integer/mod.rs @@ -0,0 +1,9 @@ +#![cfg(feature = "big-math")] + +mod fast_factorial; +mod multiply; +mod poly1305; + +pub use self::fast_factorial::fast_factorial; +pub use self::multiply::multiply; +pub use self::poly1305::Poly1305; diff --git a/src/big_integer/multiply.rs b/src/big_integer/multiply.rs new file mode 100644 index 00000000000..1f7d1a57de3 --- /dev/null +++ b/src/big_integer/multiply.rs @@ -0,0 +1,77 @@ +/// Performs long multiplication on string representations of non-negative numbers. +pub fn multiply(num1: &str, num2: &str) -> String { + if !is_valid_nonnegative(num1) || !is_valid_nonnegative(num2) { + panic!("String does not conform to specification") + } + + if num1 == "0" || num2 == "0" { + return "0".to_string(); + } + let output_size = num1.len() + num2.len(); + + let mut mult = vec![0; output_size]; + for (i, c1) in num1.chars().rev().enumerate() { + for (j, c2) in num2.chars().rev().enumerate() { + let mul = c1.to_digit(10).unwrap() * c2.to_digit(10).unwrap(); + // It could be a two-digit number here. + mult[i + j + 1] += (mult[i + j] + mul) / 10; + // Handling rounding. Here's a single digit. + mult[i + j] = (mult[i + j] + mul) % 10; + } + } + if mult[output_size - 1] == 0 { + mult.pop(); + } + mult.iter().rev().map(|&n| n.to_string()).collect() +} + +pub fn is_valid_nonnegative(num: &str) -> bool { + num.chars().all(char::is_numeric) && !num.is_empty() && (!num.starts_with('0') || num == "0") +} + +#[cfg(test)] +mod tests { + use super::*; + macro_rules! test_multiply { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $inputs; + assert_eq!(multiply(s, t), expected); + assert_eq!(multiply(t, s), expected); + } + )* + } + } + + test_multiply! { + multiply0: ("2", "3", "6"), + multiply1: ("123", "456", "56088"), + multiply_zero: ("0", "222", "0"), + other_1: ("99", "99", "9801"), + other_2: ("999", "99", "98901"), + other_3: ("9999", "99", "989901"), + other_4: ("192939", "9499596", "1832842552644"), + } + + macro_rules! test_multiply_with_wrong_input { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + #[should_panic] + fn $name() { + let (s, t) = $inputs; + multiply(s, t); + } + )* + } + } + test_multiply_with_wrong_input! { + empty_input: ("", "121"), + leading_zero: ("01", "3"), + wrong_characters: ("2", "12d4"), + wrong_input_and_zero_1: ("0", "x"), + wrong_input_and_zero_2: ("y", "0"), + } +} diff --git a/src/big_integer/poly1305.rs b/src/big_integer/poly1305.rs new file mode 100644 index 00000000000..eae3c016eaa --- /dev/null +++ b/src/big_integer/poly1305.rs @@ -0,0 +1,98 @@ +use num_bigint::BigUint; +use num_traits::Num; +use num_traits::Zero; + +macro_rules! hex_uint { + ($a:literal) => { + BigUint::from_str_radix($a, 16).unwrap() + }; +} + +/** + * Poly1305 Message Authentication Code: + * This implementation is based on RFC8439. + * Note that the Big Integer library we are using may not be suitable for + * cryptographic applications due to non constant time operations. +*/ +pub struct Poly1305 { + p: BigUint, + r: BigUint, + s: BigUint, + /// The accumulator + pub acc: BigUint, +} + +impl Default for Poly1305 { + fn default() -> Self { + Self::new() + } +} + +impl Poly1305 { + pub fn new() -> Self { + Poly1305 { + p: hex_uint!("3fffffffffffffffffffffffffffffffb"), // 2^130 - 5 + r: Zero::zero(), + s: Zero::zero(), + acc: Zero::zero(), + } + } + pub fn clamp_r(&mut self) { + self.r &= hex_uint!("0ffffffc0ffffffc0ffffffc0fffffff"); + } + pub fn set_key(&mut self, key: &[u8; 32]) { + self.r = BigUint::from_bytes_le(&key[..16]); + self.s = BigUint::from_bytes_le(&key[16..]); + self.clamp_r(); + } + /// process a 16-byte-long message block. If message is not long enough, + /// fill the `msg` array with zeros, but set `msg_bytes` to the original + /// chunk length in bytes. See `basic_tv1` for example usage. + pub fn add_msg(&mut self, msg: &[u8; 16], msg_bytes: u64) { + let mut n = BigUint::from_bytes_le(msg); + n.set_bit(msg_bytes * 8, true); + self.acc += n; + self.acc *= &self.r; + self.acc %= &self.p; + } + /// The result is guaranteed to be 16 bytes long + pub fn get_tag(&self) -> Vec { + let result = &self.acc + &self.s; + let mut bytes = result.to_bytes_le(); + bytes.resize(16, 0); + bytes + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fmt::Write; + fn get_tag_hex(tag: &[u8]) -> String { + let mut result = String::new(); + for &x in tag { + write!(result, "{x:02x}").unwrap(); + } + result + } + #[test] + fn basic_tv1() { + let mut mac = Poly1305::default(); + let key: [u8; 32] = [ + 0x85, 0xd6, 0xbe, 0x78, 0x57, 0x55, 0x6d, 0x33, 0x7f, 0x44, 0x52, 0xfe, 0x42, 0xd5, + 0x06, 0xa8, 0x01, 0x03, 0x80, 0x8a, 0xfb, 0x0d, 0xb2, 0xfd, 0x4a, 0xbf, 0xf6, 0xaf, + 0x41, 0x49, 0xf5, 0x1b, + ]; + let mut tmp_buffer = [0_u8; 16]; + mac.set_key(&key); + mac.add_msg(b"Cryptographic Fo", 16); + mac.add_msg(b"rum Research Gro", 16); + tmp_buffer[..2].copy_from_slice(b"up"); + mac.add_msg(&tmp_buffer, 2); + let result = mac.get_tag(); + assert_eq!( + get_tag_hex(result.as_slice()), + "a8061dc1305136c6c22b8baf0c0127a9" + ); + } +} diff --git a/src/bit_manipulation/counting_bits.rs b/src/bit_manipulation/counting_bits.rs new file mode 100644 index 00000000000..9357ca3080c --- /dev/null +++ b/src/bit_manipulation/counting_bits.rs @@ -0,0 +1,54 @@ +//! This module implements a function to count the number of set bits (1s) +//! in the binary representation of an unsigned integer. +//! It uses Brian Kernighan's algorithm, which efficiently clears the least significant +//! set bit in each iteration until all bits are cleared. +//! The algorithm runs in O(k), where k is the number of set bits. + +/// Counts the number of set bits in an unsigned integer. +/// +/// # Arguments +/// +/// * `n` - An unsigned 32-bit integer whose set bits will be counted. +/// +/// # Returns +/// +/// * `usize` - The number of set bits (1s) in the binary representation of the input number. +pub fn count_set_bits(mut n: usize) -> usize { + // Initialize a variable to keep track of the count of set bits + let mut count = 0; + while n > 0 { + // Clear the least significant set bit by + // performing a bitwise AND operation with (n - 1) + n &= n - 1; + + // Increment the count for each set bit found + count += 1; + } + + count +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_count_set_bits { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(count_set_bits(input), expected); + } + )* + }; + } + test_count_set_bits! { + test_count_set_bits_zero: (0, 0), + test_count_set_bits_one: (1, 1), + test_count_set_bits_power_of_two: (16, 1), + test_count_set_bits_all_set_bits: (usize::MAX, std::mem::size_of::() * 8), + test_count_set_bits_alternating_bits: (0b10101010, 4), + test_count_set_bits_mixed_bits: (0b11011011, 6), + } +} diff --git a/src/bit_manipulation/highest_set_bit.rs b/src/bit_manipulation/highest_set_bit.rs new file mode 100644 index 00000000000..3488f49a7d9 --- /dev/null +++ b/src/bit_manipulation/highest_set_bit.rs @@ -0,0 +1,52 @@ +//! This module provides a function to find the position of the most significant bit (MSB) +//! set to 1 in a given positive integer. + +/// Finds the position of the highest (most significant) set bit in a positive integer. +/// +/// # Arguments +/// +/// * `num` - An integer value for which the highest set bit will be determined. +/// +/// # Returns +/// +/// * Returns `Some(position)` if a set bit exists or `None` if no bit is set. +pub fn find_highest_set_bit(num: usize) -> Option { + if num == 0 { + return None; + } + + let mut position = 0; + let mut n = num; + + while n > 0 { + n >>= 1; + position += 1; + } + + Some(position - 1) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_find_highest_set_bit { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(find_highest_set_bit(input), expected); + } + )* + }; + } + + test_find_highest_set_bit! { + test_positive_number: (18, Some(4)), + test_0: (0, None), + test_1: (1, Some(0)), + test_2: (2, Some(1)), + test_3: (3, Some(1)), + } +} diff --git a/src/bit_manipulation/mod.rs b/src/bit_manipulation/mod.rs new file mode 100644 index 00000000000..027c4b81817 --- /dev/null +++ b/src/bit_manipulation/mod.rs @@ -0,0 +1,9 @@ +mod counting_bits; +mod highest_set_bit; +mod n_bits_gray_code; +mod sum_of_two_integers; + +pub use counting_bits::count_set_bits; +pub use highest_set_bit::find_highest_set_bit; +pub use n_bits_gray_code::generate_gray_code; +pub use sum_of_two_integers::add_two_integers; diff --git a/src/bit_manipulation/n_bits_gray_code.rs b/src/bit_manipulation/n_bits_gray_code.rs new file mode 100644 index 00000000000..64c717bc761 --- /dev/null +++ b/src/bit_manipulation/n_bits_gray_code.rs @@ -0,0 +1,75 @@ +/// Custom error type for Gray code generation. +#[derive(Debug, PartialEq)] +pub enum GrayCodeError { + ZeroBitCount, +} + +/// Generates an n-bit Gray code sequence using the direct Gray code formula. +/// +/// # Arguments +/// +/// * `n` - The number of bits for the Gray code. +/// +/// # Returns +/// +/// A vector of Gray code sequences as strings. +pub fn generate_gray_code(n: usize) -> Result, GrayCodeError> { + if n == 0 { + return Err(GrayCodeError::ZeroBitCount); + } + + let num_codes = 1 << n; + let mut result = Vec::with_capacity(num_codes); + + for i in 0..num_codes { + let gray = i ^ (i >> 1); + let gray_code = (0..n) + .rev() + .map(|bit| if gray & (1 << bit) != 0 { '1' } else { '0' }) + .collect::(); + result.push(gray_code); + } + + Ok(result) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! gray_code_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(generate_gray_code(input), expected); + } + )* + }; + } + + gray_code_tests! { + zero_bit_count: (0, Err(GrayCodeError::ZeroBitCount)), + gray_code_1_bit: (1, Ok(vec![ + "0".to_string(), + "1".to_string(), + ])), + gray_code_2_bit: (2, Ok(vec![ + "00".to_string(), + "01".to_string(), + "11".to_string(), + "10".to_string(), + ])), + gray_code_3_bit: (3, Ok(vec![ + "000".to_string(), + "001".to_string(), + "011".to_string(), + "010".to_string(), + "110".to_string(), + "111".to_string(), + "101".to_string(), + "100".to_string(), + ])), + } +} diff --git a/src/bit_manipulation/sum_of_two_integers.rs b/src/bit_manipulation/sum_of_two_integers.rs new file mode 100644 index 00000000000..45d3532b173 --- /dev/null +++ b/src/bit_manipulation/sum_of_two_integers.rs @@ -0,0 +1,55 @@ +//! This module provides a function to add two integers without using the `+` operator. +//! It relies on bitwise operations (XOR and AND) to compute the sum, simulating the addition process. + +/// Adds two integers using bitwise operations. +/// +/// # Arguments +/// +/// * `a` - The first integer to be added. +/// * `b` - The second integer to be added. +/// +/// # Returns +/// +/// * `isize` - The result of adding the two integers. +pub fn add_two_integers(mut a: isize, mut b: isize) -> isize { + let mut carry; + + while b != 0 { + let sum = a ^ b; + carry = (a & b) << 1; + a = sum; + b = carry; + } + + a +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_add_two_integers { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (a, b) = $test_case; + assert_eq!(add_two_integers(a, b), a + b); + assert_eq!(add_two_integers(b, a), a + b); + } + )* + }; + } + + test_add_two_integers! { + test_add_two_integers_positive: (3, 5), + test_add_two_integers_large_positive: (100, 200), + test_add_two_integers_edge_positive: (65535, 1), + test_add_two_integers_negative: (-10, 6), + test_add_two_integers_both_negative: (-50, -30), + test_add_two_integers_edge_negative: (-1, -1), + test_add_two_integers_zero: (0, 0), + test_add_two_integers_zero_with_positive: (0, 42), + test_add_two_integers_zero_with_negative: (0, -42), + } +} diff --git a/src/ciphers/README.md b/src/ciphers/README.md index 3afa92212ef..fb54477c191 100644 --- a/src/ciphers/README.md +++ b/src/ciphers/README.md @@ -34,7 +34,7 @@ Many people have tried to implement encryption schemes that are essentially Vige SHA-2 (Secure Hash Algorithm 2) is a set of cryptographic hash functions designed by the United States National Security Agency (NSA) and first published in 2001. They are built using the Merkle–Damgård structure, from a one-way compression function itself built using the Davies–Meyer structure from a (classified) specialized block cipher. ###### Source: [Wikipedia](https://en.wikipedia.org/wiki/SHA-2) -### Transposition _(Not implemented yet)_ +### [Transposition](./transposition.rs) In cryptography, a **transposition cipher** is a method of encryption by which the positions held by units of plaintext (which are commonly characters or groups of characters) are shifted according to a regular system, so that the ciphertext constitutes a permutation of the plaintext. That is, the order of the units is changed (the plaintext is reordered).
Mathematically a bijective function is used on the characters' positions to encrypt and an inverse function to decrypt. ###### Source: [Wikipedia](https://en.wikipedia.org/wiki/Transposition_cipher) diff --git a/src/ciphers/aes.rs b/src/ciphers/aes.rs index 3c154b52402..5d2eb98ece0 100644 --- a/src/ciphers/aes.rs +++ b/src/ciphers/aes.rs @@ -538,7 +538,7 @@ mod tests { let decrypted = aes_decrypt(&encrypted, AesKey::AesKey128(key)); assert_eq!( str, - String::from_utf8(decrypted).unwrap().trim_end_matches("\0") + String::from_utf8(decrypted).unwrap().trim_end_matches('\0') ); } } diff --git a/src/ciphers/baconian_cipher.rs b/src/ciphers/baconian_cipher.rs new file mode 100644 index 00000000000..0ae71cab2cf --- /dev/null +++ b/src/ciphers/baconian_cipher.rs @@ -0,0 +1,70 @@ +// Author : cyrixninja +//Program to encode and decode Baconian or Bacon's Cipher +//Wikipedia reference : https://en.wikipedia.org/wiki/Bacon%27s_cipher +// Bacon's cipher or the Baconian cipher is a method of steganographic message encoding devised by Francis Bacon in 1605. +// A message is concealed in the presentation of text, rather than its content. Bacon cipher is categorized as both a substitution cipher (in plain code) and a concealment cipher (using the two typefaces). + +// Encode Baconian Cipher +pub fn baconian_encode(message: &str) -> String { + let alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + let baconian = [ + "AAAAA", "AAAAB", "AAABA", "AAABB", "AABAA", "AABAB", "AABBA", "AABBB", "ABAAA", "ABAAB", + "ABABA", "ABABB", "ABBAA", "ABBAB", "ABBBA", "ABBBB", "BAAAA", "BAAAB", "BAABA", "BAABB", + "BABAA", "BABAB", "BABBA", "BABBB", + ]; + + message + .chars() + .map(|c| { + if let Some(index) = alphabet.find(c.to_ascii_uppercase()) { + baconian[index].to_string() + } else { + c.to_string() + } + }) + .collect() +} + +// Decode Baconian Cipher +pub fn baconian_decode(encoded: &str) -> String { + let baconian = [ + "AAAAA", "AAAAB", "AAABA", "AAABB", "AABAA", "AABAB", "AABBA", "AABBB", "ABAAA", "ABAAB", + "ABABA", "ABABB", "ABBAA", "ABBAB", "ABBBA", "ABBBB", "BAAAA", "BAAAB", "BAABA", "BAABB", + "BABAA", "BABAB", "BABBA", "BABBB", + ]; + let alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + encoded + .as_bytes() + .chunks(5) + .map(|chunk| { + if let Some(index) = baconian + .iter() + .position(|&x| x == String::from_utf8_lossy(chunk)) + { + alphabet.chars().nth(index).unwrap() + } else { + ' ' + } + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_baconian_encoding() { + let message = "HELLO"; + let encoded = baconian_encode(message); + assert_eq!(encoded, "AABBBAABAAABABBABABBABBBA"); + } + + #[test] + fn test_baconian_decoding() { + let message = "AABBBAABAAABABBABABBABBBA"; + let decoded = baconian_decode(message); + assert_eq!(decoded, "HELLO"); + } +} diff --git a/src/ciphers/base64.rs b/src/ciphers/base64.rs new file mode 100644 index 00000000000..81d4ac5dd6a --- /dev/null +++ b/src/ciphers/base64.rs @@ -0,0 +1,272 @@ +/* + A Rust implementation of a base64 encoder and decoder. + Written from scratch. +*/ + +// The charset and padding used for en- and decoding. +const CHARSET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; +const PADDING: char = '='; + +/* + Combines the two provided bytes into an u16, + and collects 6 bits from it using an AND mask: + + Example: + Bytes: X and Y + (Bits of those bytes will be signified using the names of their byte) + Offset: 4 + + `combined` = 0bXXXXXXXXYYYYYYYY + AND mask: + 0b1111110000000000 >> offset (4) = 0b0000111111000000 + `combined` with mask applied: + 0b0000XXYYYY000000 + Shift the value right by (16 bit number) - (6 bit mask) - (4 offset) = 6: + 0b0000000000XXYYYY + And then turn it into an u8: + 0b00XXYYYY (Return value) +*/ +fn collect_six_bits(from: (u8, u8), offset: u8) -> u8 { + let combined: u16 = ((from.0 as u16) << 8) | (from.1 as u16); + ((combined & (0b1111110000000000u16 >> offset)) >> (10 - offset)) as u8 +} + +pub fn base64_encode(data: &[u8]) -> String { + let mut bits_encoded = 0usize; + let mut encoded_string = String::new(); + // Using modulo twice to prevent an underflow, Wolfram|Alpha says this is optimal + let padding_needed = ((6 - (data.len() * 8) % 6) / 2) % 3; + loop { + let lower_byte_index_to_encode = bits_encoded / 8usize; // Integer division + if lower_byte_index_to_encode == data.len() { + break; + } + let lower_byte_to_encode = data[lower_byte_index_to_encode]; + let upper_byte_to_encode = if (lower_byte_index_to_encode + 1) == data.len() { + 0u8 // Padding + } else { + data[lower_byte_index_to_encode + 1] + }; + let bytes_to_encode = (lower_byte_to_encode, upper_byte_to_encode); + let offset: u8 = (bits_encoded % 8) as u8; + encoded_string.push(CHARSET[collect_six_bits(bytes_to_encode, offset) as usize] as char); + bits_encoded += 6; + } + for _ in 0..padding_needed { + encoded_string.push(PADDING); + } + encoded_string +} + +/* + Performs the exact inverse of the above description of `base64_encode` +*/ +pub fn base64_decode(data: &str) -> Result, (&str, u8)> { + let mut collected_bits = 0; + let mut byte_buffer = 0u16; + let mut databytes = data.bytes(); + let mut outputbytes = Vec::::new(); + 'decodeloop: loop { + while collected_bits < 8 { + if let Some(nextbyte) = databytes.next() { + // Finds the first occurence of the latest byte + if let Some(idx) = CHARSET.iter().position(|&x| x == nextbyte) { + byte_buffer |= ((idx & 0b00111111) as u16) << (10 - collected_bits); + collected_bits += 6; + } else if nextbyte == (PADDING as u8) { + collected_bits -= 2; // Padding only comes at the end so this works + } else { + return Err(( + "Failed to decode base64: Expected byte from charset, found invalid byte.", + nextbyte, + )); + } + } else { + break 'decodeloop; + } + } + outputbytes.push(((0b1111111100000000 & byte_buffer) >> 8) as u8); + byte_buffer &= 0b0000000011111111; + byte_buffer <<= 8; + collected_bits -= 8; + } + if collected_bits != 0 { + return Err(("Failed to decode base64: Invalid padding.", collected_bits)); + } + Ok(outputbytes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pregenerated_random_bytes_encode() { + macro_rules! test_encode { + ($left: expr, $right: expr) => { + assert_eq!(base64_encode(&$left.to_vec()), $right); + }; + } + test_encode!( + b"\xd31\xc9\x87D\xfe\xaa\xb3\xff\xef\x8c\x0eoD", + "0zHJh0T+qrP/74wOb0Q=" + ); + test_encode!( + b"\x9f\x0e8\xbc\xf5\xd0-\xb4.\xd4\xf0?\x8f\xe7\t{.\xff/6\xcbTY!\xae9\x82", + "nw44vPXQLbQu1PA/j+cJey7/LzbLVFkhrjmC" + ); + test_encode!(b"\x7f3\x15\x1a\xd3\xf91\x9bS\xa44=", "fzMVGtP5MZtTpDQ9"); + test_encode!( + b"7:\xf5\xd1[\xbfV/P\x18\x03\x00\xdc\xcd\xa1\xecG", + "Nzr10Vu/Vi9QGAMA3M2h7Ec=" + ); + test_encode!( + b"\xc3\xc9\x18={\xc4\x08\x97wN\xda\x81\x84?\x94\xe6\x9e", + "w8kYPXvECJd3TtqBhD+U5p4=" + ); + test_encode!( + b"\x8cJ\xf8e\x13\r\x8fw\xa8\xe6G\xce\x93c*\xe7M\xb6\xd7", + "jEr4ZRMNj3eo5kfOk2Mq50221w==" + ); + test_encode!( + b"\xde\xc4~\xb2}\xb1\x14F.~\xa1z|s\x90\x8dd\x9b\x04\x81\xf2\x92{", + "3sR+sn2xFEYufqF6fHOQjWSbBIHykns=" + ); + test_encode!( + b"\xf0y\t\x14\xd161n\x03e\xed\x0e\x05\xdf\xc1\xb9\xda", + "8HkJFNE2MW4DZe0OBd/Budo=" + ); + test_encode!( + b"*.\x8e\x1d@\x1ac\xdd;\x9a\xcc \x0c\xc2KI", + "Ki6OHUAaY907mswgDMJLSQ==" + ); + test_encode!(b"\xd6\x829\x82\xbc\x00\xc9\xfe\x03", "1oI5grwAyf4D"); + test_encode!( + b"\r\xf2\xb4\xd4\xa1g\x8fhl\xaa@\x98\x00\xda\x95", + "DfK01KFnj2hsqkCYANqV" + ); + test_encode!( + b"\x1a\xfaV\x1a\xc2e\xc0\xad\xef|\x07\xcf\xa9\xb7O", + "GvpWGsJlwK3vfAfPqbdP" + ); + test_encode!(b"\xc20{_\x81\xac", "wjB7X4Gs"); + test_encode!( + b"B\xa85\xac\xe9\x0ev-\x8bT\xb3|\xde", + "Qqg1rOkOdi2LVLN83g==" + ); + test_encode!( + b"\x05\xe0\xeeSs\xfdY9\x0b7\x84\xfc-\xec", + "BeDuU3P9WTkLN4T8Lew=" + ); + test_encode!( + b"Qj\x92\xfa?\xa5\xe3_[\xde\x82\x97{$\xb2\xf9\xd5\x98\x0cy\x15\xe4R\x8d", + "UWqS+j+l419b3oKXeySy+dWYDHkV5FKN" + ); + test_encode!(b"\x853\xe0\xc0\x1d\xc1", "hTPgwB3B"); + test_encode!(b"}2\xd0\x13m\x8d\x8f#\x9c\xf5,\xc7", "fTLQE22NjyOc9SzH"); + } + + #[test] + fn pregenerated_random_bytes_decode() { + macro_rules! test_decode { + ($left: expr, $right: expr) => { + assert_eq!( + base64_decode(&String::from($left)).unwrap(), + $right.to_vec() + ); + }; + } + test_decode!( + "0zHJh0T+qrP/74wOb0Q=", + b"\xd31\xc9\x87D\xfe\xaa\xb3\xff\xef\x8c\x0eoD" + ); + test_decode!( + "nw44vPXQLbQu1PA/j+cJey7/LzbLVFkhrjmC", + b"\x9f\x0e8\xbc\xf5\xd0-\xb4.\xd4\xf0?\x8f\xe7\t{.\xff/6\xcbTY!\xae9\x82" + ); + test_decode!("fzMVGtP5MZtTpDQ9", b"\x7f3\x15\x1a\xd3\xf91\x9bS\xa44="); + test_decode!( + "Nzr10Vu/Vi9QGAMA3M2h7Ec=", + b"7:\xf5\xd1[\xbfV/P\x18\x03\x00\xdc\xcd\xa1\xecG" + ); + test_decode!( + "w8kYPXvECJd3TtqBhD+U5p4=", + b"\xc3\xc9\x18={\xc4\x08\x97wN\xda\x81\x84?\x94\xe6\x9e" + ); + test_decode!( + "jEr4ZRMNj3eo5kfOk2Mq50221w==", + b"\x8cJ\xf8e\x13\r\x8fw\xa8\xe6G\xce\x93c*\xe7M\xb6\xd7" + ); + test_decode!( + "3sR+sn2xFEYufqF6fHOQjWSbBIHykns=", + b"\xde\xc4~\xb2}\xb1\x14F.~\xa1z|s\x90\x8dd\x9b\x04\x81\xf2\x92{" + ); + test_decode!( + "8HkJFNE2MW4DZe0OBd/Budo=", + b"\xf0y\t\x14\xd161n\x03e\xed\x0e\x05\xdf\xc1\xb9\xda" + ); + test_decode!( + "Ki6OHUAaY907mswgDMJLSQ==", + b"*.\x8e\x1d@\x1ac\xdd;\x9a\xcc \x0c\xc2KI" + ); + test_decode!("1oI5grwAyf4D", b"\xd6\x829\x82\xbc\x00\xc9\xfe\x03"); + test_decode!( + "DfK01KFnj2hsqkCYANqV", + b"\r\xf2\xb4\xd4\xa1g\x8fhl\xaa@\x98\x00\xda\x95" + ); + test_decode!( + "GvpWGsJlwK3vfAfPqbdP", + b"\x1a\xfaV\x1a\xc2e\xc0\xad\xef|\x07\xcf\xa9\xb7O" + ); + test_decode!("wjB7X4Gs", b"\xc20{_\x81\xac"); + test_decode!( + "Qqg1rOkOdi2LVLN83g==", + b"B\xa85\xac\xe9\x0ev-\x8bT\xb3|\xde" + ); + test_decode!( + "BeDuU3P9WTkLN4T8Lew=", + b"\x05\xe0\xeeSs\xfdY9\x0b7\x84\xfc-\xec" + ); + test_decode!( + "UWqS+j+l419b3oKXeySy+dWYDHkV5FKN", + b"Qj\x92\xfa?\xa5\xe3_[\xde\x82\x97{$\xb2\xf9\xd5\x98\x0cy\x15\xe4R\x8d" + ); + test_decode!("hTPgwB3B", b"\x853\xe0\xc0\x1d\xc1"); + test_decode!("fTLQE22NjyOc9SzH", b"}2\xd0\x13m\x8d\x8f#\x9c\xf5,\xc7"); + } + + #[test] + fn encode_decode() { + macro_rules! test_e_d { + ($text: expr) => { + assert_eq!( + base64_decode(&base64_encode(&$text.to_vec())).unwrap(), + $text + ); + }; + } + test_e_d!(b"green"); + test_e_d!(b"The quick brown fox jumped over the lazy dog."); + test_e_d!(b"Lorem Ipsum sit dolor amet."); + test_e_d!(b"0"); + test_e_d!(b"01"); + test_e_d!(b"012"); + test_e_d!(b"0123"); + test_e_d!(b"0123456789"); + } + + #[test] + fn decode_encode() { + macro_rules! test_d_e { + ($data: expr) => { + assert_eq!( + base64_encode(&base64_decode(&String::from($data)).unwrap()), + String::from($data) + ); + }; + } + test_d_e!("TG9uZyBsaXZlIGVhc3RlciBlZ2dzIDop"); + test_d_e!("SGFwcHkgSGFja3RvYmVyZmVzdCE="); + test_d_e!("PVRoZSBBbGdvcml0aG1zPQ=="); + } +} diff --git a/src/ciphers/blake2b.rs b/src/ciphers/blake2b.rs new file mode 100644 index 00000000000..c28486489d6 --- /dev/null +++ b/src/ciphers/blake2b.rs @@ -0,0 +1,318 @@ +// For specification go to https://www.rfc-editor.org/rfc/rfc7693 + +use std::cmp::{max, min}; +use std::convert::{TryFrom, TryInto}; + +type Word = u64; + +const BB: usize = 128; + +const U64BYTES: usize = (u64::BITS as usize) / 8; + +type Block = [Word; BB / U64BYTES]; + +const KK_MAX: usize = 64; +const NN_MAX: u8 = 64; + +// Array of round constants used in mixing function G +const RC: [u32; 4] = [32, 24, 16, 63]; + +// IV[i] = floor(2**64 * frac(sqrt(prime(i+1)))) where prime(i) is the ith prime number +const IV: [Word; 8] = [ + 0x6A09E667F3BCC908, + 0xBB67AE8584CAA73B, + 0x3C6EF372FE94F82B, + 0xA54FF53A5F1D36F1, + 0x510E527FADE682D1, + 0x9B05688C2B3E6C1F, + 0x1F83D9ABFB41BD6B, + 0x5BE0CD19137E2179, +]; + +const SIGMA: [[usize; 16]; 10] = [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3], + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4], + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8], + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13], + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9], + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11], + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10], + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5], + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0], +]; + +#[inline] +const fn blank_block() -> Block { + [0u64; BB / U64BYTES] +} + +// Overflowing addition +#[inline] +fn add(a: &mut Word, b: Word) { + *a = a.overflowing_add(b).0; +} + +#[inline] +const fn ceil(dividend: usize, divisor: usize) -> usize { + (dividend / divisor) + ((dividend % divisor != 0) as usize) +} + +fn g(v: &mut [Word; 16], a: usize, b: usize, c: usize, d: usize, x: Word, y: Word) { + for (m, r) in [x, y].into_iter().zip(RC.chunks(2)) { + let v_b = v[b]; + add(&mut v[a], v_b); + add(&mut v[a], m); + + v[d] = (v[d] ^ v[a]).rotate_right(r[0]); + + let v_d = v[d]; + add(&mut v[c], v_d); + + v[b] = (v[b] ^ v[c]).rotate_right(r[1]); + } +} + +fn f(h: &mut [Word; 8], m: Block, t: u128, flag: bool) { + let mut v: [Word; 16] = [0; 16]; + + for (i, (h_i, iv_i)) in h.iter().zip(IV.iter()).enumerate() { + v[i] = *h_i; + v[i + 8] = *iv_i; + } + + v[12] ^= (t % (u64::MAX as u128)) as u64; + v[13] ^= (t >> 64) as u64; + + if flag { + v[14] = !v[14]; + } + + for i in 0..12 { + let s = SIGMA[i % 10]; + + let mut s_index = 0; + for j in 0..4 { + g( + &mut v, + j, + j + 4, + j + 8, + j + 12, + m[s[s_index]], + m[s[s_index + 1]], + ); + + s_index += 2; + } + + let i1d = |col, row| { + let col = col % 4; + let row = row % 4; + + (row * 4) + col + }; + + for j in 0..4 { + // Produces indeces for diagonals of a 4x4 matrix starting at 0,j + let idx: Vec = (0..4).map(|n| i1d(j + n, n) as usize).collect(); + + g( + &mut v, + idx[0], + idx[1], + idx[2], + idx[3], + m[s[s_index]], + m[s[s_index + 1]], + ); + + s_index += 2; + } + } + + for (i, n) in h.iter_mut().enumerate() { + *n ^= v[i] ^ v[i + 8]; + } +} + +fn blake2(d: Vec, ll: u128, kk: Word, nn: Word) -> Vec { + let mut h: [Word; 8] = IV + .iter() + .take(8) + .copied() + .collect::>() + .try_into() + .unwrap(); + + h[0] ^= 0x01010000u64 ^ (kk << 8) ^ nn; + + if d.len() > 1 { + for (i, w) in d.iter().enumerate().take(d.len() - 1) { + f(&mut h, *w, (i as u128 + 1) * BB as u128, false); + } + } + + let ll = if kk > 0 { ll + BB as u128 } else { ll }; + f(&mut h, d[d.len() - 1], ll, true); + + h.iter() + .flat_map(|n| n.to_le_bytes()) + .take(nn as usize) + .collect() +} + +// Take arbitrarily long slice of u8's and turn up to 8 bytes into u64 +fn bytes_to_word(bytes: &[u8]) -> Word { + if let Ok(arr) = <[u8; U64BYTES]>::try_from(bytes) { + Word::from_le_bytes(arr) + } else { + let mut arr = [0u8; 8]; + for (a_i, b_i) in arr.iter_mut().zip(bytes) { + *a_i = *b_i; + } + + Word::from_le_bytes(arr) + } +} + +pub fn blake2b(m: &[u8], k: &[u8], nn: u8) -> Vec { + let kk = min(k.len(), KK_MAX); + let nn = min(nn, NN_MAX); + + // Prevent user from giving a key that is too long + let k = &k[..kk]; + + let dd = max(ceil(kk, BB) + ceil(m.len(), BB), 1); + + let mut blocks: Vec = vec![blank_block(); dd]; + + // Copy key into blocks + for (w, c) in blocks[0].iter_mut().zip(k.chunks(U64BYTES)) { + *w = bytes_to_word(c); + } + + let first_index = (kk > 0) as usize; + + // Copy bytes from message into blocks + for (i, c) in m.chunks(U64BYTES).enumerate() { + let block_index = first_index + (i / (BB / U64BYTES)); + let word_in_block = i % (BB / U64BYTES); + + blocks[block_index][word_in_block] = bytes_to_word(c); + } + + blake2(blocks, m.len() as u128, kk as u64, nn as Word) +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! digest_test { + ($fname:ident, $message:expr, $key:expr, $nn:literal, $expected:expr) => { + #[test] + fn $fname() { + let digest = blake2b($message, $key, $nn); + + let expected = Vec::from($expected); + + assert_eq!(digest, expected); + } + }; + } + + digest_test!( + blake2b_from_rfc, + &[0x61, 0x62, 0x63], + &[0; 0], + 64, + [ + 0xBA, 0x80, 0xA5, 0x3F, 0x98, 0x1C, 0x4D, 0x0D, 0x6A, 0x27, 0x97, 0xB6, 0x9F, 0x12, + 0xF6, 0xE9, 0x4C, 0x21, 0x2F, 0x14, 0x68, 0x5A, 0xC4, 0xB7, 0x4B, 0x12, 0xBB, 0x6F, + 0xDB, 0xFF, 0xA2, 0xD1, 0x7D, 0x87, 0xC5, 0x39, 0x2A, 0xAB, 0x79, 0x2D, 0xC2, 0x52, + 0xD5, 0xDE, 0x45, 0x33, 0xCC, 0x95, 0x18, 0xD3, 0x8A, 0xA8, 0xDB, 0xF1, 0x92, 0x5A, + 0xB9, 0x23, 0x86, 0xED, 0xD4, 0x00, 0x99, 0x23 + ] + ); + + digest_test!( + blake2b_empty, + &[0; 0], + &[0; 0], + 64, + [ + 0x78, 0x6a, 0x02, 0xf7, 0x42, 0x01, 0x59, 0x03, 0xc6, 0xc6, 0xfd, 0x85, 0x25, 0x52, + 0xd2, 0x72, 0x91, 0x2f, 0x47, 0x40, 0xe1, 0x58, 0x47, 0x61, 0x8a, 0x86, 0xe2, 0x17, + 0xf7, 0x1f, 0x54, 0x19, 0xd2, 0x5e, 0x10, 0x31, 0xaf, 0xee, 0x58, 0x53, 0x13, 0x89, + 0x64, 0x44, 0x93, 0x4e, 0xb0, 0x4b, 0x90, 0x3a, 0x68, 0x5b, 0x14, 0x48, 0xb7, 0x55, + 0xd5, 0x6f, 0x70, 0x1a, 0xfe, 0x9b, 0xe2, 0xce + ] + ); + + digest_test!( + blake2b_empty_with_key, + &[0; 0], + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + 64, + [ + 0x10, 0xeb, 0xb6, 0x77, 0x00, 0xb1, 0x86, 0x8e, 0xfb, 0x44, 0x17, 0x98, 0x7a, 0xcf, + 0x46, 0x90, 0xae, 0x9d, 0x97, 0x2f, 0xb7, 0xa5, 0x90, 0xc2, 0xf0, 0x28, 0x71, 0x79, + 0x9a, 0xaa, 0x47, 0x86, 0xb5, 0xe9, 0x96, 0xe8, 0xf0, 0xf4, 0xeb, 0x98, 0x1f, 0xc2, + 0x14, 0xb0, 0x05, 0xf4, 0x2d, 0x2f, 0xf4, 0x23, 0x34, 0x99, 0x39, 0x16, 0x53, 0xdf, + 0x7a, 0xef, 0xcb, 0xc1, 0x3f, 0xc5, 0x15, 0x68 + ] + ); + + digest_test!( + blake2b_key_shortin, + &[0], + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + 64, + [ + 0x96, 0x1f, 0x6d, 0xd1, 0xe4, 0xdd, 0x30, 0xf6, 0x39, 0x01, 0x69, 0x0c, 0x51, 0x2e, + 0x78, 0xe4, 0xb4, 0x5e, 0x47, 0x42, 0xed, 0x19, 0x7c, 0x3c, 0x5e, 0x45, 0xc5, 0x49, + 0xfd, 0x25, 0xf2, 0xe4, 0x18, 0x7b, 0x0b, 0xc9, 0xfe, 0x30, 0x49, 0x2b, 0x16, 0xb0, + 0xd0, 0xbc, 0x4e, 0xf9, 0xb0, 0xf3, 0x4c, 0x70, 0x03, 0xfa, 0xc0, 0x9a, 0x5e, 0xf1, + 0x53, 0x2e, 0x69, 0x43, 0x02, 0x34, 0xce, 0xbd + ] + ); + + digest_test!( + blake2b_keyed_filled, + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + &[ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, + 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f + ], + 64, + [ + 0x65, 0x67, 0x6d, 0x80, 0x06, 0x17, 0x97, 0x2f, 0xbd, 0x87, 0xe4, 0xb9, 0x51, 0x4e, + 0x1c, 0x67, 0x40, 0x2b, 0x7a, 0x33, 0x10, 0x96, 0xd3, 0xbf, 0xac, 0x22, 0xf1, 0xab, + 0xb9, 0x53, 0x74, 0xab, 0xc9, 0x42, 0xf1, 0x6e, 0x9a, 0xb0, 0xea, 0xd3, 0x3b, 0x87, + 0xc9, 0x19, 0x68, 0xa6, 0xe5, 0x09, 0xe1, 0x19, 0xff, 0x07, 0x78, 0x7b, 0x3e, 0xf4, + 0x83, 0xe1, 0xdc, 0xdc, 0xcf, 0x6e, 0x30, 0x22 + ] + ); +} diff --git a/src/ciphers/caesar.rs b/src/ciphers/caesar.rs index d86a7e36b38..970ae575fce 100644 --- a/src/ciphers/caesar.rs +++ b/src/ciphers/caesar.rs @@ -1,43 +1,118 @@ -//! Caesar Cipher -//! Based on cipher_crypt::caesar -//! -//! # Algorithm -//! -//! Rotate each ascii character by shift. The most basic example is ROT 13, which rotates 'a' to -//! 'n'. This implementation does not rotate unicode characters. - -/// Caesar cipher to rotate cipher text by shift and return an owned String. -pub fn caesar(cipher: &str, shift: u8) -> String { - cipher +const ERROR_MESSAGE: &str = "Rotation must be in the range [0, 25]"; +const ALPHABET_LENGTH: u8 = b'z' - b'a' + 1; + +/// Encrypts a given text using the Caesar cipher technique. +/// +/// In cryptography, a Caesar cipher, also known as Caesar's cipher, the shift cipher, Caesar's code, +/// or Caesar shift, is one of the simplest and most widely known encryption techniques. +/// It is a type of substitution cipher in which each letter in the plaintext is replaced by a letter +/// some fixed number of positions down the alphabet. +/// +/// # Arguments +/// +/// * `text` - The text to be encrypted. +/// * `rotation` - The number of rotations (shift) to be applied. It should be within the range [0, 25]. +/// +/// # Returns +/// +/// Returns a `Result` containing the encrypted string if successful, or an error message if the rotation +/// is out of the valid range. +/// +/// # Errors +/// +/// Returns an error if the rotation value is out of the valid range [0, 25] +pub fn caesar(text: &str, rotation: isize) -> Result { + if !(0..ALPHABET_LENGTH as isize).contains(&rotation) { + return Err(ERROR_MESSAGE); + } + + let result = text .chars() .map(|c| { if c.is_ascii_alphabetic() { - let first = if c.is_ascii_lowercase() { b'a' } else { b'A' }; - // modulo the distance to keep character range - (first + (c as u8 + shift - first) % 26) as char + shift_char(c, rotation) } else { c } }) - .collect() + .collect(); + + Ok(result) +} + +/// Shifts a single ASCII alphabetic character by a specified number of positions in the alphabet. +/// +/// # Arguments +/// +/// * `c` - The ASCII alphabetic character to be shifted. +/// * `rotation` - The number of positions to shift the character. Should be within the range [0, 25]. +/// +/// # Returns +/// +/// Returns the shifted ASCII alphabetic character. +fn shift_char(c: char, rotation: isize) -> char { + let first = if c.is_ascii_lowercase() { b'a' } else { b'A' }; + let rotation = rotation as u8; // Safe cast as rotation is within [0, 25] + + (((c as u8 - first) + rotation) % ALPHABET_LENGTH + first) as char } #[cfg(test)] mod tests { use super::*; - #[test] - fn empty() { - assert_eq!(caesar("", 13), ""); + macro_rules! test_caesar_happy_path { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (text, rotation, expected) = $test_case; + assert_eq!(caesar(&text, rotation).unwrap(), expected); + + let backward_rotation = if rotation == 0 { 0 } else { ALPHABET_LENGTH as isize - rotation }; + assert_eq!(caesar(&expected, backward_rotation).unwrap(), text); + } + )* + }; } - #[test] - fn caesar_rot_13() { - assert_eq!(caesar("rust", 13), "ehfg"); + macro_rules! test_caesar_error_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (text, rotation) = $test_case; + assert_eq!(caesar(&text, rotation), Err(ERROR_MESSAGE)); + } + )* + }; } #[test] - fn caesar_unicode() { - assert_eq!(caesar("attack at dawn 攻", 5), "fyyfhp fy ifbs 攻"); + fn alphabet_length_should_be_26() { + assert_eq!(ALPHABET_LENGTH, 26); + } + + test_caesar_happy_path! { + empty_text: ("", 13, ""), + rot_13: ("rust", 13, "ehfg"), + unicode: ("attack at dawn 攻", 5, "fyyfhp fy ifbs 攻"), + rotation_within_alphabet_range: ("Hello, World!", 3, "Khoor, Zruog!"), + no_rotation: ("Hello, World!", 0, "Hello, World!"), + rotation_at_alphabet_end: ("Hello, World!", 25, "Gdkkn, Vnqkc!"), + longer: ("The quick brown fox jumps over the lazy dog.", 5, "Ymj vznhp gwtbs ktc ozrux tajw ymj qfed itl."), + non_alphabetic_characters: ("12345!@#$%", 3, "12345!@#$%"), + uppercase_letters: ("ABCDEFGHIJKLMNOPQRSTUVWXYZ", 1, "BCDEFGHIJKLMNOPQRSTUVWXYZA"), + mixed_case: ("HeLlO WoRlD", 7, "OlSsV DvYsK"), + with_whitespace: ("Hello, World!", 13, "Uryyb, Jbeyq!"), + with_special_characters: ("Hello!@#$%^&*()_+World", 4, "Lipps!@#$%^&*()_+Asvph"), + with_numbers: ("Abcd1234XYZ", 10, "Klmn1234HIJ"), + } + + test_caesar_error_cases! { + negative_rotation: ("Hello, World!", -5), + empty_input_negative_rotation: ("", -1), + empty_input_large_rotation: ("", 27), + large_rotation: ("Large rotation", 139), } } diff --git a/src/ciphers/chacha.rs b/src/ciphers/chacha.rs index c41425094cf..6b0440a9d11 100644 --- a/src/ciphers/chacha.rs +++ b/src/ciphers/chacha.rs @@ -1,10 +1,3 @@ -/* - * ChaCha20 implementation based on RFC8439 - * ChaCha20 is a stream cipher developed independently by Daniel J. Bernstein. - * To use it, the `chacha20` function should be called with appropriate - * parameters and the output of the function should be XORed with plain text. - */ - macro_rules! quarter_round { ($a:expr,$b:expr,$c:expr,$d:expr) => { $a = $a.wrapping_add($b); @@ -22,66 +15,74 @@ macro_rules! quarter_round { // "expand 32-byte k", written in little-endian order pub const C: [u32; 4] = [0x61707865, 0x3320646e, 0x79622d32, 0x6b206574]; -/** - * `chacha20` function takes as input an array of 16 32-bit integers (512 bits) - * of which 128 bits is the constant 'expand 32-byte k', 256 bits is the key, - * and 128 bits are nonce and counter. According to RFC8439, the nonce should - * be 96 bits long, which leaves 32 bits for the counter. Given that the block - * length is 512 bits, this leaves enough counter values to encrypt 256GB of - * data. - * - * The 16 input numbers can be thought of as the elements of a 4x4 matrix like - * the one bellow, on which we do the main operations of the cipher. - * - * +----+----+----+----+ - * | 00 | 01 | 02 | 03 | - * +----+----+----+----+ - * | 04 | 05 | 06 | 07 | - * +----+----+----+----+ - * | 08 | 09 | 10 | 11 | - * +----+----+----+----+ - * | 12 | 13 | 14 | 15 | - * +----+----+----+----+ - * - * As per the diagram bellow, input[0, 1, 2, 3] are the constants mentioned - * above, input[4..=11] is filled with the key, and input[6..=9] should be - * filled with nonce and counter values. The output of the function is stored - * in `output` variable and can be XORed with the plain text to produce the - * cipher text. - * - * +------+------+------+------+ - * | | | | | - * | C[0] | C[1] | C[2] | C[3] | - * | | | | | - * +------+------+------+------+ - * | | | | | - * | key0 | key1 | key2 | key3 | - * | | | | | - * +------+------+------+------+ - * | | | | | - * | key4 | key5 | key6 | key7 | - * | | | | | - * +------+------+------+------+ - * | | | | | - * | ctr0 | no.0 | no.1 | no.2 | - * | | | | | - * +------+------+------+------+ - * - * Note that the constants, the key, and the nonce should be written in - * little-endian order, meaning that for example if the key is 01:02:03:04 - * (in hex), it corresponds to the integer 0x04030201. It is important to - * know that the hex value of the counter is meaningless, and only its integer - * value matters, and it should start with (for example) 0x00000000, and then - * 0x00000001 and so on until 0xffffffff. Keep in mind that as soon as we get - * from bytes to words, we stop caring about their representation in memory, - * and we only need the math to be correct. - * - * The output of the function can be used without any change, as long as the - * plain text has the same endianness. For example if the plain text is - * "hello world", and the first word of the output is 0x01020304, then the - * first byte of plain text ('h') should be XORed with the least-significant - * byte of 0x01020304, which is 0x04. -*/ +/// ChaCha20 implementation based on RFC8439 +/// +/// ChaCha20 is a stream cipher developed independently by Daniel J. Bernstein.\ +/// To use it, the `chacha20` function should be called with appropriate +/// parameters and the output of the function should be XORed with plain text. +/// +/// `chacha20` function takes as input an array of 16 32-bit integers (512 bits) +/// of which 128 bits is the constant 'expand 32-byte k', 256 bits is the key, +/// and 128 bits are nonce and counter. According to RFC8439, the nonce should +/// be 96 bits long, which leaves 32 bits for the counter. Given that the block +/// length is 512 bits, this leaves enough counter values to encrypt 256GB of +/// data. +/// +/// The 16 input numbers can be thought of as the elements of a 4x4 matrix like +/// the one bellow, on which we do the main operations of the cipher. +/// +/// ```text +/// +----+----+----+----+ +/// | 00 | 01 | 02 | 03 | +/// +----+----+----+----+ +/// | 04 | 05 | 06 | 07 | +/// +----+----+----+----+ +/// | 08 | 09 | 10 | 11 | +/// +----+----+----+----+ +/// | 12 | 13 | 14 | 15 | +/// +----+----+----+----+ +/// ``` +/// +/// As per the diagram bellow, `input[0, 1, 2, 3]` are the constants mentioned +/// above, `input[4..=11]` is filled with the key, and `input[6..=9]` should be +/// filled with nonce and counter values. The output of the function is stored +/// in `output` variable and can be XORed with the plain text to produce the +/// cipher text. +/// +/// ```text +/// +------+------+------+------+ +/// | | | | | +/// | C[0] | C[1] | C[2] | C[3] | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | key0 | key1 | key2 | key3 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | key4 | key5 | key6 | key7 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | ctr0 | no.0 | no.1 | no.2 | +/// | | | | | +/// +------+------+------+------+ +/// ``` +/// +/// Note that the constants, the key, and the nonce should be written in +/// little-endian order, meaning that for example if the key is 01:02:03:04 +/// (in hex), it corresponds to the integer `0x04030201`. It is important to +/// know that the hex value of the counter is meaningless, and only its integer +/// value matters, and it should start with (for example) `0x00000000`, and then +/// `0x00000001` and so on until `0xffffffff`. Keep in mind that as soon as we get +/// from bytes to words, we stop caring about their representation in memory, +/// and we only need the math to be correct. +/// +/// The output of the function can be used without any change, as long as the +/// plain text has the same endianness. For example if the plain text is +/// "hello world", and the first word of the output is `0x01020304`, then the +/// first byte of plain text ('h') should be XORed with the least-significant +/// byte of `0x01020304`, which is `0x04`. pub fn chacha20(input: &[u32; 16], output: &mut [u32; 16]) { output.copy_from_slice(&input[..]); for _ in 0..10 { diff --git a/src/ciphers/diffie_hellman.rs b/src/ciphers/diffie_hellman.rs new file mode 100644 index 00000000000..3cfe53802bb --- /dev/null +++ b/src/ciphers/diffie_hellman.rs @@ -0,0 +1,347 @@ +// Based on the TheAlgorithms/Python +// RFC 3526 - More Modular Exponential (MODP) Diffie-Hellman groups for +// Internet Key Exchange (IKE) https://tools.ietf.org/html/rfc3526 + +use num_bigint::BigUint; +use num_traits::{Num, Zero}; +use std::{ + collections::HashMap, + sync::LazyLock, + time::{SystemTime, UNIX_EPOCH}, +}; + +// A map of predefined prime numbers for different bit lengths, as specified in RFC 3526 +static PRIMES: LazyLock> = LazyLock::new(|| { + let mut m: HashMap = HashMap::new(); + m.insert( + // 1536-bit + 5, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA237327FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m.insert( + // 2048-bit + 14, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AACAA68FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + + m.insert( + // 3072-bit + 15, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64\ + ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7\ + ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B\ + F12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31\ + 43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m.insert( + // 4096-bit + 16, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64\ + ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7\ + ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B\ + F12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31\ + 43DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D7\ + 88719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA\ + 2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6\ + 287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED\ + 1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA9\ + 93B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199\ + FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m.insert( + // 6144-bit + 17, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E08\ + 8A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B\ + 302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9\ + A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE6\ + 49286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8\ + FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C\ + 180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF695581718\ + 3995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D\ + 04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7D\ + B3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D226\ + 1AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFC\ + E0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B26\ + 99C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB\ + 04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2\ + 233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127\ + D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934028492\ + 36C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BDF8FF9406\ + AD9E530EE5DB382F413001AEB06A53ED9027D831179727B0865A8918\ + DA3EDBEBCF9B14ED44CE6CBACED4BB1BDB7F1447E6CC254B33205151\ + 2BD7AF426FB8F401378CD2BF5983CA01C64B92ECF032EA15D1721D03\ + F482D7CE6E74FEF6D55E702F46980C82B5A84031900B1C9E59E7C97F\ + BEC7E8F323A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AA\ + CC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE32806A1D58B\ + B7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55CDA56C9EC2EF29632\ + 387FE8D76E3C0468043E8F663F4860EE12BF2D5B0B7474D6E694F91E\ + 6DCC4024FFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + + m.insert( + // 8192-bit + 18, + BigUint::parse_bytes( + b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64\ + ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7\ + ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B\ + F12FFA06D98A0864D87602733EC86A64521F2B18177B200C\ + BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31\ + 43DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D7\ + 88719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA\ + 2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6\ + 287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED\ + 1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA9\ + 93B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934028492\ + 36C3FAB4D27C7026C1D4DCB2602646DEC9751E763DBA37BD\ + F8FF9406AD9E530EE5DB382F413001AEB06A53ED9027D831\ + 179727B0865A8918DA3EDBEBCF9B14ED44CE6CBACED4BB1B\ + DB7F1447E6CC254B332051512BD7AF426FB8F401378CD2BF\ + 5983CA01C64B92ECF032EA15D1721D03F482D7CE6E74FEF6\ + D55E702F46980C82B5A84031900B1C9E59E7C97FBEC7E8F3\ + 23A97A7E36CC88BE0F1D45B7FF585AC54BD407B22B4154AA\ + CC8F6D7EBF48E1D814CC5ED20F8037E0A79715EEF29BE328\ + 06A1D58BB7C5DA76F550AA3D8A1FBFF0EB19CCB1A313D55C\ + DA56C9EC2EF29632387FE8D76E3C0468043E8F663F4860EE\ + 12BF2D5B0B7474D6E694F91E6DBE115974A3926F12FEE5E4\ + 38777CB6A932DF8CD8BEC4D073B931BA3BC832B68D9DD300\ + 741FA7BF8AFC47ED2576F6936BA424663AAB639C5AE4F568\ + 3423B4742BF1C978238F16CBE39D652DE3FDB8BEFC848AD9\ + 22222E04A4037C0713EB57A81A23F0C73473FC646CEA306B\ + 4BCBC8862F8385DDFA9D4B7FA2C087E879683303ED5BDD3A\ + 062B3CF5B3A278A66D2A13F83F44F82DDF310EE074AB6A36\ + 4597E899A0255DC164F31CC50846851DF9AB48195DED7EA1\ + B1D510BD7EE74D73FAF36BC31ECFA268359046F4EB879F92\ + 4009438B481C6CD7889A002ED5EE382BC9190DA6FC026E47\ + 9558E4475677E9AA9E3050E2765694DFC81F56E880B96E71\ + 60C980DD98EDD3DFFFFFFFFFFFFFFFFF", + 16, + ) + .unwrap(), + ); + m +}); + +/// Generating random number, should use num_bigint::RandomBits if possible. +fn rand() -> usize { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos() as usize +} + +pub struct DiffieHellman { + prime: BigUint, + private_key: BigUint, + public_key: BigUint, + generator: u8, +} + +impl DiffieHellman { + // Diffie-Hellman key exchange algorithm is based on the following mathematical concepts: + + // - A large prime number p (known as the prime modulus) is chosen and shared by both parties. + + // - A base number g (known as the generator) is chosen and shared by both parties. + + // - Each party generates a private key a or b (which are secret and only known to that party) and calculates a corresponding public key A or B using the following formulas: + // - A = g^a mod p + // - B = g^b mod p + + // - Each party then exchanges their public keys with each other. + + // - Each party then calculates the shared secret key s using the following formula: + // - s = B^a mod p or s = A^b mod p + + // Both parties now have the same shared secret key s which can be used for encryption or authentication. + + pub fn new(group: Option) -> Self { + let mut _group: u8 = 14; + if let Some(x) = group { + _group = x; + } + + if !PRIMES.contains_key(&_group) { + panic!("group not in primes") + } + + // generate private key + let private_key: BigUint = BigUint::from(rand()); + + Self { + prime: PRIMES[&_group].clone(), + private_key, + generator: 2, // the generator is 2 for all the primes if this would not be the case it can be added to hashmap + public_key: BigUint::default(), + } + } + + /// get private key as hexadecimal String + pub fn get_private_key(&self) -> String { + self.private_key.to_str_radix(16) + } + + /// Generate public key A = g**a mod p + pub fn generate_public_key(&mut self) -> String { + self.public_key = BigUint::from(self.generator).modpow(&self.private_key, &self.prime); + self.public_key.to_str_radix(16) + } + + pub fn is_valid_public_key(&self, key_str: &str) -> bool { + // the unwrap_or_else will make sure it is false, because 2 <= 0 and therefor False is returned + let key = BigUint::from_str_radix(key_str, 16) + .unwrap_or_else(|_| BigUint::parse_bytes(b"0", 16).unwrap()); + + // Check if the other public key is valid based on NIST SP800-56 + if BigUint::from(2_u8) <= key + && key <= &self.prime - BigUint::from(2_u8) + && !key + .modpow( + &((&self.prime - BigUint::from(1_u8)) / BigUint::from(2_u8)), + &self.prime, + ) + .is_zero() + { + return true; + } + false + } + + /// Generate the shared key + pub fn generate_shared_key(self, other_key_str: &str) -> Option { + let other_key = BigUint::from_str_radix(other_key_str, 16) + .unwrap_or_else(|_| BigUint::parse_bytes(b"0", 16).unwrap()); + if !self.is_valid_public_key(&other_key.to_str_radix(16)) { + return None; + } + + let shared_key = other_key.modpow(&self.private_key, &self.prime); + Some(shared_key.to_str_radix(16)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn verify_invalid_pub_key() { + let diffie = DiffieHellman::new(Some(14)); + assert!(!diffie.is_valid_public_key("0000")); + } + + #[test] + fn verify_valid_pub_key() { + let diffie = DiffieHellman::new(Some(14)); + assert!(diffie.is_valid_public_key("EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245")); + } + + #[test] + fn verify_invalid_pub_key_same_as_prime() { + let diffie = DiffieHellman::new(Some(14)); + assert!(!diffie.is_valid_public_key( + "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1\ + 29024E088A67CC74020BBEA63B139B22514A08798E3404DD\ + EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245\ + E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED\ + EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D\ + C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F\ + 83655D23DCA3AD961C62F356208552BB9ED529077096966D\ + 670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B\ + E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9\ + DE2BCBF6955817183995497CEA956AE515D2261898FA0510\ + 15728E5A8AACAA68FFFFFFFFFFFFFFFF" + )); + } + + #[test] + fn verify_key_exchange() { + let mut alice = DiffieHellman::new(Some(16)); + let mut bob = DiffieHellman::new(Some(16)); + + // Private key not used, showed for illustrative purpose + let _alice_private = alice.get_private_key(); + let alice_public = alice.generate_public_key(); + + // Private key not used, showed for illustrative purpose + let _bob_private = bob.get_private_key(); + let bob_public = bob.generate_public_key(); + + // generating shared key using the struct implemenations + let alice_shared = alice.generate_shared_key(bob_public.as_str()).unwrap(); + let bob_shared = bob.generate_shared_key(alice_public.as_str()).unwrap(); + assert_eq!(alice_shared, bob_shared); + } +} diff --git a/src/ciphers/hashing_traits.rs b/src/ciphers/hashing_traits.rs index 3ef24f924ad..af8b8391f09 100644 --- a/src/ciphers/hashing_traits.rs +++ b/src/ciphers/hashing_traits.rs @@ -69,7 +69,7 @@ impl> #[cfg(test)] mod tests { - use super::super::sha256::get_hash_string; + use super::super::sha256::tests::get_hash_string; use super::super::SHA256; use super::HMAC; @@ -79,7 +79,7 @@ mod tests { // echo -n "Hello World" | openssl sha256 -hex -mac HMAC -macopt hexkey:"deadbeef" let mut hmac: HMAC<64, 32, SHA256> = HMAC::new_default(); hmac.add_key(&[0xde, 0xad, 0xbe, 0xef]).unwrap(); - hmac.update(&b"Hello World".to_vec()); + hmac.update(b"Hello World"); let hash = hmac.finalize(); assert_eq!( get_hash_string(&hash), diff --git a/src/ciphers/kerninghan.rs b/src/ciphers/kerninghan.rs new file mode 100644 index 00000000000..4263850ff3f --- /dev/null +++ b/src/ciphers/kerninghan.rs @@ -0,0 +1,23 @@ +pub fn kerninghan(n: u32) -> i32 { + let mut count = 0; + let mut n = n; + + while n > 0 { + n = n & (n - 1); + count += 1; + } + + count +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn count_set_bits() { + assert_eq!(kerninghan(0b0000_0000_0000_0000_0000_0000_0000_1011), 3); + assert_eq!(kerninghan(0b0000_0000_0000_0000_0000_0000_1000_0000), 1); + assert_eq!(kerninghan(0b1111_1111_1111_1111_1111_1111_1111_1101), 31); + } +} diff --git a/src/ciphers/mod.rs b/src/ciphers/mod.rs index c65202374cd..f7a55b0014d 100644 --- a/src/ciphers/mod.rs +++ b/src/ciphers/mod.rs @@ -1,30 +1,43 @@ mod aes; mod another_rot13; +mod baconian_cipher; +mod base64; +mod blake2b; mod caesar; mod chacha; +mod diffie_hellman; mod hashing_traits; +mod kerninghan; mod morse_code; mod polybius; +mod rail_fence; mod rot13; mod salsa; mod sha256; +mod sha3; mod tea; mod theoretical_rot13; mod transposition; mod vigenere; mod xor; - pub use self::aes::{aes_decrypt, aes_encrypt, AesKey}; pub use self::another_rot13::another_rot13; +pub use self::baconian_cipher::{baconian_decode, baconian_encode}; +pub use self::base64::{base64_decode, base64_encode}; +pub use self::blake2b::blake2b; pub use self::caesar::caesar; pub use self::chacha::chacha20; +pub use self::diffie_hellman::DiffieHellman; pub use self::hashing_traits::Hasher; pub use self::hashing_traits::HMAC; +pub use self::kerninghan::kerninghan; pub use self::morse_code::{decode, encode}; pub use self::polybius::{decode_ascii, encode_ascii}; +pub use self::rail_fence::{rail_fence_decrypt, rail_fence_encrypt}; pub use self::rot13::rot13; pub use self::salsa::salsa20; pub use self::sha256::SHA256; +pub use self::sha3::{sha3_224, sha3_256, sha3_384, sha3_512}; pub use self::tea::{tea_decrypt, tea_encrypt}; pub use self::theoretical_rot13::theoretical_rot13; pub use self::transposition::transposition; diff --git a/src/ciphers/morse_code.rs b/src/ciphers/morse_code.rs index 2b234dccbb8..c1ecaa5b2ad 100644 --- a/src/ciphers/morse_code.rs +++ b/src/ciphers/morse_code.rs @@ -8,15 +8,14 @@ pub fn encode(message: &str) -> String { let dictionary = _morse_dictionary(); message .chars() - .into_iter() .map(|char| char.to_uppercase().to_string()) .map(|letter| dictionary.get(letter.as_str())) - .map(|option| option.unwrap_or(&UNKNOWN_CHARACTER).to_string()) + .map(|option| (*option.unwrap_or(&UNKNOWN_CHARACTER)).to_string()) .collect::>() .join(" ") } -// Declaritive macro for creating readable map declarations, for more info see https://doc.rust-lang.org/book/ch19-06-macros.html +// Declarative macro for creating readable map declarations, for more info see https://doc.rust-lang.org/book/ch19-06-macros.html macro_rules! map { ($($key:expr => $value:expr),* $(,)?) => { std::iter::Iterator::collect(IntoIterator::into_iter([$(($key, $value),)*])) @@ -90,18 +89,14 @@ fn _check_all_parts(string: &str) -> bool { } fn _decode_token(string: &str) -> String { - _morse_to_alphanumeric_dictionary() + (*_morse_to_alphanumeric_dictionary() .get(string) - .unwrap_or(&_UNKNOWN_MORSE_CHARACTER) - .to_string() + .unwrap_or(&_UNKNOWN_MORSE_CHARACTER)) + .to_string() } fn _decode_part(string: &str) -> String { - string - .split(' ') - .map(_decode_token) - .collect::>() - .join("") + string.split(' ').map(_decode_token).collect::() } /// Convert morse code to ascii. @@ -173,12 +168,7 @@ mod tests { #[test] fn decrypt_valid_character_set_invalid_morsecode() { let expected = format!( - "{}{}{}{} {}", - _UNKNOWN_MORSE_CHARACTER, - _UNKNOWN_MORSE_CHARACTER, - _UNKNOWN_MORSE_CHARACTER, - _UNKNOWN_MORSE_CHARACTER, - _UNKNOWN_MORSE_CHARACTER, + "{_UNKNOWN_MORSE_CHARACTER}{_UNKNOWN_MORSE_CHARACTER}{_UNKNOWN_MORSE_CHARACTER}{_UNKNOWN_MORSE_CHARACTER} {_UNKNOWN_MORSE_CHARACTER}", ); let encypted = ".-.-.--.-.-. --------. ..---.-.-. .-.-.--.-.-. / .-.-.--.-.-.".to_string(); diff --git a/src/ciphers/rail_fence.rs b/src/ciphers/rail_fence.rs new file mode 100644 index 00000000000..aedff07ea31 --- /dev/null +++ b/src/ciphers/rail_fence.rs @@ -0,0 +1,41 @@ +// wiki: https://en.wikipedia.org/wiki/Rail_fence_cipher +pub fn rail_fence_encrypt(plain_text: &str, key: usize) -> String { + let mut cipher = vec![Vec::new(); key]; + + for (c, i) in plain_text.chars().zip(zigzag(key)) { + cipher[i].push(c); + } + + return cipher.iter().flatten().collect::(); +} + +pub fn rail_fence_decrypt(cipher: &str, key: usize) -> String { + let mut indices: Vec<_> = zigzag(key).zip(1..).take(cipher.len()).collect(); + indices.sort(); + + let mut cipher_text: Vec<_> = cipher + .chars() + .zip(indices) + .map(|(c, (_, i))| (i, c)) + .collect(); + + cipher_text.sort(); + return cipher_text.iter().map(|(_, c)| c).collect(); +} + +fn zigzag(n: usize) -> impl Iterator { + (0..n - 1).chain((1..n).rev()).cycle() +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn rails_basic() { + assert_eq!(rail_fence_encrypt("attack at once", 2), "atc toctaka ne"); + assert_eq!(rail_fence_decrypt("atc toctaka ne", 2), "attack at once"); + + assert_eq!(rail_fence_encrypt("rust is cool", 3), "r cuti olsso"); + assert_eq!(rail_fence_decrypt("r cuti olsso", 3), "rust is cool"); + } +} diff --git a/src/ciphers/salsa.rs b/src/ciphers/salsa.rs index 77ccd7514e0..83b37556ff1 100644 --- a/src/ciphers/salsa.rs +++ b/src/ciphers/salsa.rs @@ -1,10 +1,3 @@ -/* - * Salsa20 implementation based on https://en.wikipedia.org/wiki/Salsa20 - * Salsa20 is a stream cipher developed by Daniel J. Bernstein. To use it, the - * `salsa20` function should be called with appropriate parameters and the - * output of the function should be XORed with plain text. - */ - macro_rules! quarter_round { ($v1:expr,$v2:expr,$v3:expr,$v4:expr) => { $v2 ^= ($v1.wrapping_add($v4).rotate_left(7)); @@ -14,53 +7,57 @@ macro_rules! quarter_round { }; } -#[allow(dead_code)] -pub const C: [u32; 4] = [0x65787061, 0x6e642033, 0x322d6279, 0x7465206b]; - -/** - * `salsa20` function takes as input an array of 16 32-bit integers (512 bits) - * of which 128 bits is the constant 'expand 32-byte k', 256 bits is the key, - * and 128 bits are nonce and counter. It is up to the user to determine how - * many bits each of nonce and counter take, but a default of 64 bits each - * seems to be a sane choice. - * - * The 16 input numbers can be thought of as the elements of a 4x4 matrix like - * the one bellow, on which we do the main operations of the cipher. - * - * +----+----+----+----+ - * | 00 | 01 | 02 | 03 | - * +----+----+----+----+ - * | 04 | 05 | 06 | 07 | - * +----+----+----+----+ - * | 08 | 09 | 10 | 11 | - * +----+----+----+----+ - * | 12 | 13 | 14 | 15 | - * +----+----+----+----+ - * - * As per the diagram bellow, input[0, 5, 10, 15] are the constants mentioned - * above, input[1, 2, 3, 4, 11, 12, 13, 14] is filled with the key, and - * input[6, 7, 8, 9] should be filled with nonce and counter values. The output - * of the function is stored in `output` variable and can be XORed with the - * plain text to produce the cipher text. - * - * +------+------+------+------+ - * | | | | | - * | C[0] | key1 | key2 | key3 | - * | | | | | - * +------+------+------+------+ - * | | | | | - * | key4 | C[1] | no1 | no2 | - * | | | | | - * +------+------+------+------+ - * | | | | | - * | ctr1 | ctr2 | C[2] | key5 | - * | | | | | - * +------+------+------+------+ - * | | | | | - * | key6 | key7 | key8 | C[3] | - * | | | | | - * +------+------+------+------+ -*/ +/// This is a `Salsa20` implementation based on \ +/// `Salsa20` is a stream cipher developed by Daniel J. Bernstein.\ +/// To use it, the `salsa20` function should be called with appropriate parameters and the +/// output of the function should be XORed with plain text. +/// +/// `salsa20` function takes as input an array of 16 32-bit integers (512 bits) +/// of which 128 bits is the constant 'expand 32-byte k', 256 bits is the key, +/// and 128 bits are nonce and counter. It is up to the user to determine how +/// many bits each of nonce and counter take, but a default of 64 bits each +/// seems to be a sane choice. +/// +/// The 16 input numbers can be thought of as the elements of a 4x4 matrix like +/// the one bellow, on which we do the main operations of the cipher. +/// +/// ```text +/// +----+----+----+----+ +/// | 00 | 01 | 02 | 03 | +/// +----+----+----+----+ +/// | 04 | 05 | 06 | 07 | +/// +----+----+----+----+ +/// | 08 | 09 | 10 | 11 | +/// +----+----+----+----+ +/// | 12 | 13 | 14 | 15 | +/// +----+----+----+----+ +/// ``` +/// +/// As per the diagram bellow, `input[0, 5, 10, 15]` are the constants mentioned +/// above, `input[1, 2, 3, 4, 11, 12, 13, 14]` is filled with the key, and +/// `input[6, 7, 8, 9]` should be filled with nonce and counter values. The output +/// of the function is stored in `output` variable and can be XORed with the +/// plain text to produce the cipher text. +/// +/// ```text +/// +------+------+------+------+ +/// | | | | | +/// | C[0] | key1 | key2 | key3 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | key4 | C[1] | no1 | no2 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | ctr1 | ctr2 | C[2] | key5 | +/// | | | | | +/// +------+------+------+------+ +/// | | | | | +/// | key6 | key7 | key8 | C[3] | +/// | | | | | +/// +------+------+------+------+ +/// ``` pub fn salsa20(input: &[u32; 16], output: &mut [u32; 16]) { output.copy_from_slice(&input[..]); for _ in 0..10 { @@ -86,6 +83,8 @@ mod tests { use super::*; use std::fmt::Write; + const C: [u32; 4] = [0x65787061, 0x6e642033, 0x322d6279, 0x7465206b]; + fn output_hex(inp: &[u32; 16]) -> String { let mut res = String::new(); res.reserve(512 / 4); diff --git a/src/ciphers/sha256.rs b/src/ciphers/sha256.rs index 52591895e24..af6ce814434 100644 --- a/src/ciphers/sha256.rs +++ b/src/ciphers/sha256.rs @@ -5,8 +5,6 @@ * integer multiple of 8 */ -use std::fmt::Write; - // The constants are tested to make sure they are correct pub const H0: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, @@ -100,17 +98,6 @@ fn process_block(h: &mut [u32; 8], w: &mut [u32; 64], round: &mut [u32; 8], buf: } } -#[allow(dead_code)] -// Let's keep this utility function -pub fn get_hash_string(hash: &[u8; 32]) -> String { - let mut result = String::new(); - result.reserve(64); - for &ch in hash { - write!(&mut result, "{ch:02x}").unwrap(); - } - result -} - impl SHA256 { pub fn new_default() -> Self { SHA256 { @@ -219,9 +206,20 @@ impl super::Hasher<32> for SHA256 { } #[cfg(test)] -mod tests { +pub mod tests { use super::*; use crate::math::LinearSieve; + use std::fmt::Write; + + // Let's keep this utility function + pub fn get_hash_string(hash: &[u8; 32]) -> String { + let mut result = String::new(); + result.reserve(64); + for &ch in hash { + write!(&mut result, "{ch:02x}").unwrap(); + } + result + } #[test] fn test_constants() { @@ -272,7 +270,7 @@ mod tests { #[test] fn ascii() { let mut res = SHA256::new_default(); - res.update(&b"The quick brown fox jumps over the lazy dog".to_vec()); + res.update(b"The quick brown fox jumps over the lazy dog"); assert_eq!( res.get_hash(), [ @@ -286,7 +284,7 @@ mod tests { #[test] fn ascii_avalanche() { let mut res = SHA256::new_default(); - res.update(&b"The quick brown fox jumps over the lazy dog.".to_vec()); + res.update(b"The quick brown fox jumps over the lazy dog."); assert_eq!( res.get_hash(), [ @@ -308,7 +306,7 @@ mod tests { #[test] fn long_ascii() { let mut res = SHA256::new_default(); - let val = &b"The quick brown fox jumps over the lazy dog.".to_vec(); + let val = b"The quick brown fox jumps over the lazy dog."; for _ in 0..1000 { res.update(val); } @@ -318,7 +316,7 @@ mod tests { "c264fca077807d391df72fadf39dd63be21f1823f65ca530c9637760eabfc18c" ); let mut res = SHA256::new_default(); - let val = &b"a".to_vec(); + let val = b"a"; for _ in 0..999 { res.update(val); } @@ -331,7 +329,7 @@ mod tests { #[test] fn short_ascii() { let mut res = SHA256::new_default(); - let val = &b"a".to_vec(); + let val = b"a"; res.update(val); let hash = res.get_hash(); assert_eq!( diff --git a/src/ciphers/sha3.rs b/src/ciphers/sha3.rs new file mode 100644 index 00000000000..f3791214f3f --- /dev/null +++ b/src/ciphers/sha3.rs @@ -0,0 +1,590 @@ +/// Size of the state array in bits +const B: usize = 1600; + +const W: usize = B / 25; +const L: usize = W.ilog2() as usize; + +const U8BITS: usize = u8::BITS as usize; + +// Macro for looping through the whole state array +macro_rules! iterate { + ( $x:ident, $y:ident, $z:ident => $b:block ) => { + for $y in 0..5 { + for $x in 0..5 { + for $z in 0..W { + $b + } + } + } + }; +} + +/// A function that produces a padding string such that the length of the padding + the length of +/// the string to be padded (2nd parameter) is divisible by the 1st parameter +type PadFn = fn(isize, isize) -> Vec; +type SpongeFn = fn(&[bool]) -> [bool; B]; + +type State = [[[bool; W]; 5]; 5]; + +fn state_new() -> State { + [[[false; W]; 5]; 5] +} + +fn state_fill(dest: &mut State, bits: &[bool]) { + let mut i = 0usize; + + iterate!(x, y, z => { + if i >= bits.len() { return; } + dest[x][y][z] = bits[i]; + i += 1; + }); +} + +fn state_copy(dest: &mut State, src: &State) { + iterate!(x, y, z => { + dest[x][y][z] = src[x][y][z]; + }); +} + +fn state_dump(state: &State) -> [bool; B] { + let mut bits = [false; B]; + + let mut i = 0usize; + + iterate!(x, y, z => { + bits[i] = state[x][y][z]; + i += 1; + }); + + bits +} + +/// XORs the state with the parities of two columns in the state array +fn theta(state: &mut State) { + let mut c = [[false; W]; 5]; + let mut d = [[false; W]; 5]; + + // Assign values of C[x,z] + for x in 0..5 { + for z in 0..W { + c[x][z] = state[x][0][z]; + + for y in 1..5 { + c[x][z] ^= state[x][y][z]; + } + } + } + + // Assign values of D[x,z] + for x in 0..5 { + for z in 0..W { + let x1 = (x as isize - 1).rem_euclid(5) as usize; + let z2 = (z as isize - 1).rem_euclid(W as isize) as usize; + + d[x][z] = c[x1][z] ^ c[(x + 1) % 5][z2]; + } + } + + // Xor values of D[x,z] into our state array + iterate!(x, y, z => { + state[x][y][z] ^= d[x][z]; + }); +} + +/// Rotates each lane by an offset depending of the x and y indeces +fn rho(state: &mut State) { + let mut new_state = state_new(); + + for z in 0..W { + new_state[0][0][z] = state[0][0][z]; + } + + let mut x = 1; + let mut y = 0; + + for t in 0..=23isize { + for z in 0..W { + let z_offset: isize = ((t + 1) * (t + 2)) / 2; + let new_z = (z as isize - z_offset).rem_euclid(W as isize) as usize; + + new_state[x][y][z] = state[x][y][new_z]; + } + + let old_y = y; + y = ((2 * x) + (3 * y)) % 5; + x = old_y; + } + + state_copy(state, &new_state); +} + +/// Rearrange the positions of the lanes of the state array +fn pi(state: &mut State) { + let mut new_state = state_new(); + + iterate!(x, y, z => { + new_state[x][y][z] = state[(x + (3 * y)) % 5][x][z]; + }); + + state_copy(state, &new_state); +} + +fn chi(state: &mut State) { + let mut new_state = state_new(); + + iterate!(x, y, z => { + new_state[x][y][z] = state[x][y][z] ^ ((state[(x + 1) % 5][y][z] ^ true) & state[(x + 2) % 5][y][z]); + }); + + state_copy(state, &new_state); +} + +/// Calculates the round constant depending on what the round number is +fn rc(t: u8) -> bool { + let mut b1: u16; + let mut b2: u16; + let mut r: u16 = 0x80; // tread r as an array of bits + + //if t % 0xFF == 0 { return true; } + + for _i in 0..(t % 255) { + b1 = r >> 8; + b2 = r & 1; + r |= (b1 ^ b2) << 8; + + b1 = (r >> 4) & 1; + r &= 0x1EF; // clear r[4] + r |= (b1 ^ b2) << 4; + + b1 = (r >> 3) & 1; + r &= 0x1F7; // clear r[3] + r |= (b1 ^ b2) << 3; + + b1 = (r >> 2) & 1; + r &= 0x1FB; // clear r[2] + r |= (b1 ^ b2) << 2; + + r >>= 1; + } + + (r >> 7) != 0 +} + +/// Applies the round constant to the first lane of the state array +fn iota(state: &mut State, i_r: u8) { + let mut rc_arr = [false; W]; + + for j in 0..=L { + rc_arr[(1 << j) - 1] = rc((j as u8) + (7 * i_r)); + } + + for (z, bit) in rc_arr.iter().enumerate() { + state[0][0][z] ^= *bit; + } +} + +fn rnd(state: &mut State, i_r: u8) { + theta(state); + rho(state); + pi(state); + chi(state); + iota(state, i_r); +} + +fn keccak_f(bits: &[bool]) -> [bool; B] { + let n_r = 12 + (2 * L); + + let mut state = state_new(); + state_fill(&mut state, bits); + + for i_r in 0..n_r { + rnd(&mut state, i_r as u8); + } + + state_dump(&state) +} + +fn pad101(x: isize, m: isize) -> Vec { + let mut j = -m - 2; + + while j < 0 { + j += x; + } + + j %= x; + + let mut ret = vec![false; (j as usize) + 2]; + *ret.first_mut().unwrap() = true; + *ret.last_mut().unwrap() = true; + + ret +} + +/// Sponge construction is a method of compression needing 1) a function on fixed-length bit +/// strings( here we use keccak_f), 2) a padding function (pad10*1), and 3) a rate. The input and +/// output of this method can be arbitrarily long +fn sponge(f: SpongeFn, pad: PadFn, r: usize, n: &[bool], d: usize) -> Vec { + let mut p = Vec::from(n); + p.append(&mut pad(r as isize, n.len() as isize)); + + assert!(r < B); + + let mut s = [false; B]; + for chunk in p.chunks(r) { + for (s_i, c_i) in s.iter_mut().zip(chunk) { + *s_i ^= c_i; + } + + s = f(&s); + } + + let mut z = Vec::::new(); + while z.len() < d { + z.extend(&s); + + s = f(&s); + } + + z.truncate(d); + z +} + +fn keccak(c: usize, n: &[bool], d: usize) -> Vec { + sponge(keccak_f, pad101, B - c, n, d) +} + +fn h2b(h: &[u8], n: usize) -> Vec { + let mut bits = Vec::with_capacity(h.len() * U8BITS); + + for byte in h { + for i in 0..u8::BITS { + let mask: u8 = 1 << i; + + bits.push((byte & mask) != 0); + } + } + + assert!(bits.len() == h.len() * U8BITS); + + bits.truncate(n); + bits +} + +fn b2h(s: &[bool]) -> Vec { + let m = if s.len() % U8BITS != 0 { + (s.len() / 8) + 1 + } else { + s.len() / 8 + }; + let mut bytes = vec![0u8; m]; + + for (i, bit) in s.iter().enumerate() { + let byte_index = i / U8BITS; + let mask = (*bit as u8) << (i % U8BITS); + + bytes[byte_index] |= mask; + } + + bytes +} +/// Macro to implement all sha3 hash functions as they only differ in digest size +macro_rules! sha3 { + ($name:ident, $n:literal) => { + pub fn $name(m: &[u8]) -> [u8; ($n / U8BITS)] { + let mut temp = h2b(m, m.len() * U8BITS); + temp.append(&mut vec![false, true]); + + temp = keccak($n * 2, &temp, $n); + + let mut ret = [0u8; ($n / U8BITS)]; + + let temp = b2h(&temp); + assert!(temp.len() == $n / U8BITS); + + for (i, byte) in temp.iter().enumerate() { + ret[i] = *byte; + } + + ret + } + }; +} + +sha3!(sha3_224, 224); +sha3!(sha3_256, 256); +sha3!(sha3_384, 384); +sha3!(sha3_512, 512); + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! digest_test { + ($fname:ident, $hash:ident, $size:literal, $message:expr, $expected:expr) => { + #[test] + fn $fname() { + let digest = $hash(&$message); + + let expected: [u8; $size / U8BITS] = $expected; + + assert_eq!(digest, expected); + } + }; + } + + digest_test!( + sha3_224_0, + sha3_224, + 224, + [0; 0], + [ + 0x6b, 0x4e, 0x03, 0x42, 0x36, 0x67, 0xdb, 0xb7, 0x3b, 0x6e, 0x15, 0x45, 0x4f, 0x0e, + 0xb1, 0xab, 0xd4, 0x59, 0x7f, 0x9a, 0x1b, 0x07, 0x8e, 0x3f, 0x5b, 0x5a, 0x6b, 0xc7, + ] + ); + + digest_test!( + sha3_224_8, + sha3_224, + 224, + [1u8], + [ + 0x48, 0x82, 0x86, 0xd9, 0xd3, 0x27, 0x16, 0xe5, 0x88, 0x1e, 0xa1, 0xee, 0x51, 0xf3, + 0x6d, 0x36, 0x60, 0xd7, 0x0f, 0x0d, 0xb0, 0x3b, 0x3f, 0x61, 0x2c, 0xe9, 0xed, 0xa4, + ] + ); + + // Done on large input to verify sponge function is working properly + digest_test!( + sha3_224_2312, + sha3_224, + 224, + [ + 0x31, 0xc8, 0x2d, 0x71, 0x78, 0x5b, 0x7c, 0xa6, 0xb6, 0x51, 0xcb, 0x6c, 0x8c, 0x9a, + 0xd5, 0xe2, 0xac, 0xeb, 0x0b, 0x06, 0x33, 0xc0, 0x88, 0xd3, 0x3a, 0xa2, 0x47, 0xad, + 0xa7, 0xa5, 0x94, 0xff, 0x49, 0x36, 0xc0, 0x23, 0x25, 0x13, 0x19, 0x82, 0x0a, 0x9b, + 0x19, 0xfc, 0x6c, 0x48, 0xde, 0x8a, 0x6f, 0x7a, 0xda, 0x21, 0x41, 0x76, 0xcc, 0xda, + 0xad, 0xae, 0xef, 0x51, 0xed, 0x43, 0x71, 0x4a, 0xc0, 0xc8, 0x26, 0x9b, 0xbd, 0x49, + 0x7e, 0x46, 0xe7, 0x8b, 0xb5, 0xe5, 0x81, 0x96, 0x49, 0x4b, 0x24, 0x71, 0xb1, 0x68, + 0x0e, 0x2d, 0x4c, 0x6d, 0xbd, 0x24, 0x98, 0x31, 0xbd, 0x83, 0xa4, 0xd3, 0xbe, 0x06, + 0xc8, 0xa2, 0xe9, 0x03, 0x93, 0x39, 0x74, 0xaa, 0x05, 0xee, 0x74, 0x8b, 0xfe, 0x6e, + 0xf3, 0x59, 0xf7, 0xa1, 0x43, 0xed, 0xf0, 0xd4, 0x91, 0x8d, 0xa9, 0x16, 0xbd, 0x6f, + 0x15, 0xe2, 0x6a, 0x79, 0x0c, 0xff, 0x51, 0x4b, 0x40, 0xa5, 0xda, 0x7f, 0x72, 0xe1, + 0xed, 0x2f, 0xe6, 0x3a, 0x05, 0xb8, 0x14, 0x95, 0x87, 0xbe, 0xa0, 0x56, 0x53, 0x71, + 0x8c, 0xc8, 0x98, 0x0e, 0xad, 0xbf, 0xec, 0xa8, 0x5b, 0x7c, 0x9c, 0x28, 0x6d, 0xd0, + 0x40, 0x93, 0x65, 0x85, 0x93, 0x8b, 0xe7, 0xf9, 0x82, 0x19, 0x70, 0x0c, 0x83, 0xa9, + 0x44, 0x3c, 0x28, 0x56, 0xa8, 0x0f, 0xf4, 0x68, 0x52, 0xb2, 0x6d, 0x1b, 0x1e, 0xdf, + 0x72, 0xa3, 0x02, 0x03, 0xcf, 0x6c, 0x44, 0xa1, 0x0f, 0xa6, 0xea, 0xf1, 0x92, 0x01, + 0x73, 0xce, 0xdf, 0xb5, 0xc4, 0xcf, 0x3a, 0xc6, 0x65, 0xb3, 0x7a, 0x86, 0xed, 0x02, + 0x15, 0x5b, 0xbb, 0xf1, 0x7d, 0xc2, 0xe7, 0x86, 0xaf, 0x94, 0x78, 0xfe, 0x08, 0x89, + 0xd8, 0x6c, 0x5b, 0xfa, 0x85, 0xa2, 0x42, 0xeb, 0x08, 0x54, 0xb1, 0x48, 0x2b, 0x7b, + 0xd1, 0x6f, 0x67, 0xf8, 0x0b, 0xef, 0x9c, 0x7a, 0x62, 0x8f, 0x05, 0xa1, 0x07, 0x93, + 0x6a, 0x64, 0x27, 0x3a, 0x97, 0xb0, 0x08, 0x8b, 0x0e, 0x51, 0x54, 0x51, 0xf9, 0x16, + 0xb5, 0x65, 0x62, 0x30, 0xa1, 0x2b, 0xa6, 0xdc, 0x78 + ], + [ + 0xaa, 0xb2, 0x3c, 0x9e, 0x7f, 0xb9, 0xd7, 0xda, 0xce, 0xfd, 0xfd, 0x0b, 0x1a, 0xe8, + 0x5a, 0xb1, 0x37, 0x4a, 0xbf, 0xf7, 0xc4, 0xe3, 0xf7, 0x55, 0x6e, 0xca, 0xe4, 0x12 + ] + ); + + digest_test!( + sha3_256_0, + sha3_256, + 256, + [0; 0], + [ + 0xa7, 0xff, 0xc6, 0xf8, 0xbf, 0x1e, 0xd7, 0x66, 0x51, 0xc1, 0x47, 0x56, 0xa0, 0x61, + 0xd6, 0x62, 0xf5, 0x80, 0xff, 0x4d, 0xe4, 0x3b, 0x49, 0xfa, 0x82, 0xd8, 0x0a, 0x4b, + 0x80, 0xf8, 0x43, 0x4a, + ] + ); + + digest_test!( + sha3_256_8, + sha3_256, + 256, + [0xe9u8], + [ + 0xf0, 0xd0, 0x4d, 0xd1, 0xe6, 0xcf, 0xc2, 0x9a, 0x44, 0x60, 0xd5, 0x21, 0x79, 0x68, + 0x52, 0xf2, 0x5d, 0x9e, 0xf8, 0xd2, 0x8b, 0x44, 0xee, 0x91, 0xff, 0x5b, 0x75, 0x9d, + 0x72, 0xc1, 0xe6, 0xd6, + ] + ); + + digest_test!( + sha3_256_2184, + sha3_256, + 256, + [ + 0xb1, 0xca, 0xa3, 0x96, 0x77, 0x1a, 0x09, 0xa1, 0xdb, 0x9b, 0xc2, 0x05, 0x43, 0xe9, + 0x88, 0xe3, 0x59, 0xd4, 0x7c, 0x2a, 0x61, 0x64, 0x17, 0xbb, 0xca, 0x1b, 0x62, 0xcb, + 0x02, 0x79, 0x6a, 0x88, 0x8f, 0xc6, 0xee, 0xff, 0x5c, 0x0b, 0x5c, 0x3d, 0x50, 0x62, + 0xfc, 0xb4, 0x25, 0x6f, 0x6a, 0xe1, 0x78, 0x2f, 0x49, 0x2c, 0x1c, 0xf0, 0x36, 0x10, + 0xb4, 0xa1, 0xfb, 0x7b, 0x81, 0x4c, 0x05, 0x78, 0x78, 0xe1, 0x19, 0x0b, 0x98, 0x35, + 0x42, 0x5c, 0x7a, 0x4a, 0x0e, 0x18, 0x2a, 0xd1, 0xf9, 0x15, 0x35, 0xed, 0x2a, 0x35, + 0x03, 0x3a, 0x5d, 0x8c, 0x67, 0x0e, 0x21, 0xc5, 0x75, 0xff, 0x43, 0xc1, 0x94, 0xa5, + 0x8a, 0x82, 0xd4, 0xa1, 0xa4, 0x48, 0x81, 0xdd, 0x61, 0xf9, 0xf8, 0x16, 0x1f, 0xc6, + 0xb9, 0x98, 0x86, 0x0c, 0xbe, 0x49, 0x75, 0x78, 0x0b, 0xe9, 0x3b, 0x6f, 0x87, 0x98, + 0x0b, 0xad, 0x0a, 0x99, 0xaa, 0x2c, 0xb7, 0x55, 0x6b, 0x47, 0x8c, 0xa3, 0x5d, 0x1f, + 0x37, 0x46, 0xc3, 0x3e, 0x2b, 0xb7, 0xc4, 0x7a, 0xf4, 0x26, 0x64, 0x1c, 0xc7, 0xbb, + 0xb3, 0x42, 0x5e, 0x21, 0x44, 0x82, 0x03, 0x45, 0xe1, 0xd0, 0xea, 0x5b, 0x7d, 0xa2, + 0xc3, 0x23, 0x6a, 0x52, 0x90, 0x6a, 0xcd, 0xc3, 0xb4, 0xd3, 0x4e, 0x47, 0x4d, 0xd7, + 0x14, 0xc0, 0xc4, 0x0b, 0xf0, 0x06, 0xa3, 0xa1, 0xd8, 0x89, 0xa6, 0x32, 0x98, 0x38, + 0x14, 0xbb, 0xc4, 0xa1, 0x4f, 0xe5, 0xf1, 0x59, 0xaa, 0x89, 0x24, 0x9e, 0x7c, 0x73, + 0x8b, 0x3b, 0x73, 0x66, 0x6b, 0xac, 0x2a, 0x61, 0x5a, 0x83, 0xfd, 0x21, 0xae, 0x0a, + 0x1c, 0xe7, 0x35, 0x2a, 0xde, 0x7b, 0x27, 0x8b, 0x58, 0x71, 0x58, 0xfd, 0x2f, 0xab, + 0xb2, 0x17, 0xaa, 0x1f, 0xe3, 0x1d, 0x0b, 0xda, 0x53, 0x27, 0x20, 0x45, 0x59, 0x80, + 0x15, 0xa8, 0xae, 0x4d, 0x8c, 0xec, 0x22, 0x6f, 0xef, 0xa5, 0x8d, 0xaa, 0x05, 0x50, + 0x09, 0x06, 0xc4, 0xd8, 0x5e, 0x75, 0x67 + ], + [ + 0xcb, 0x56, 0x48, 0xa1, 0xd6, 0x1c, 0x6c, 0x5b, 0xda, 0xcd, 0x96, 0xf8, 0x1c, 0x95, + 0x91, 0xde, 0xbc, 0x39, 0x50, 0xdc, 0xf6, 0x58, 0x14, 0x5b, 0x8d, 0x99, 0x65, 0x70, + 0xba, 0x88, 0x1a, 0x05 + ] + ); + + digest_test!( + sha3_384_0, + sha3_384, + 384, + [0; 0], + [ + 0x0c, 0x63, 0xa7, 0x5b, 0x84, 0x5e, 0x4f, 0x7d, 0x01, 0x10, 0x7d, 0x85, 0x2e, 0x4c, + 0x24, 0x85, 0xc5, 0x1a, 0x50, 0xaa, 0xaa, 0x94, 0xfc, 0x61, 0x99, 0x5e, 0x71, 0xbb, + 0xee, 0x98, 0x3a, 0x2a, 0xc3, 0x71, 0x38, 0x31, 0x26, 0x4a, 0xdb, 0x47, 0xfb, 0x6b, + 0xd1, 0xe0, 0x58, 0xd5, 0xf0, 0x04, + ] + ); + + digest_test!( + sha3_384_8, + sha3_384, + 384, + [0x80u8], + [ + 0x75, 0x41, 0x38, 0x48, 0x52, 0xe1, 0x0f, 0xf1, 0x0d, 0x5f, 0xb6, 0xa7, 0x21, 0x3a, + 0x4a, 0x6c, 0x15, 0xcc, 0xc8, 0x6d, 0x8b, 0xc1, 0x06, 0x8a, 0xc0, 0x4f, 0x69, 0x27, + 0x71, 0x42, 0x94, 0x4f, 0x4e, 0xe5, 0x0d, 0x91, 0xfd, 0xc5, 0x65, 0x53, 0xdb, 0x06, + 0xb2, 0xf5, 0x03, 0x9c, 0x8a, 0xb7, + ] + ); + + digest_test!( + sha3_384_2512, + sha3_384, + 384, + [ + 0x03, 0x5a, 0xdc, 0xb6, 0x39, 0xe5, 0xf2, 0x8b, 0xb5, 0xc8, 0x86, 0x58, 0xf4, 0x5c, + 0x1c, 0xe0, 0xbe, 0x16, 0xe7, 0xda, 0xfe, 0x08, 0x3b, 0x98, 0xd0, 0xab, 0x45, 0xe8, + 0xdc, 0xdb, 0xfa, 0x38, 0xe3, 0x23, 0x4d, 0xfd, 0x97, 0x3b, 0xa5, 0x55, 0xb0, 0xcf, + 0x8e, 0xea, 0x3c, 0x82, 0xae, 0x1a, 0x36, 0x33, 0xfc, 0x56, 0x5b, 0x7f, 0x2c, 0xc8, + 0x39, 0x87, 0x6d, 0x39, 0x89, 0xf3, 0x57, 0x31, 0xbe, 0x37, 0x1f, 0x60, 0xde, 0x14, + 0x0e, 0x3c, 0x91, 0x62, 0x31, 0xec, 0x78, 0x0e, 0x51, 0x65, 0xbf, 0x5f, 0x25, 0xd3, + 0xf6, 0x7d, 0xc7, 0x3a, 0x1c, 0x33, 0x65, 0x5d, 0xfd, 0xf4, 0x39, 0xdf, 0xbf, 0x1c, + 0xbb, 0xa8, 0xb7, 0x79, 0x15, 0x8a, 0x81, 0x0a, 0xd7, 0x24, 0x4f, 0x06, 0xec, 0x07, + 0x81, 0x20, 0xcd, 0x18, 0x76, 0x0a, 0xf4, 0x36, 0xa2, 0x38, 0x94, 0x1c, 0xe1, 0xe6, + 0x87, 0x88, 0x0b, 0x5c, 0x87, 0x9d, 0xc9, 0x71, 0xa2, 0x85, 0xa7, 0x4e, 0xe8, 0x5c, + 0x6a, 0x74, 0x67, 0x49, 0xa3, 0x01, 0x59, 0xee, 0x84, 0x2e, 0x9b, 0x03, 0xf3, 0x1d, + 0x61, 0x3d, 0xdd, 0xd2, 0x29, 0x75, 0xcd, 0x7f, 0xed, 0x06, 0xbd, 0x04, 0x9d, 0x77, + 0x2c, 0xb6, 0xcc, 0x5a, 0x70, 0x5f, 0xaa, 0x73, 0x4e, 0x87, 0x32, 0x1d, 0xc8, 0xf2, + 0xa4, 0xea, 0x36, 0x6a, 0x36, 0x8a, 0x98, 0xbf, 0x06, 0xee, 0x2b, 0x0b, 0x54, 0xac, + 0x3a, 0x3a, 0xee, 0xa6, 0x37, 0xca, 0xeb, 0xe7, 0x0a, 0xd0, 0x9c, 0xcd, 0xa9, 0x3c, + 0xc0, 0x6d, 0xe9, 0x5d, 0xf7, 0x33, 0x94, 0xa8, 0x7a, 0xc9, 0xbb, 0xb5, 0x08, 0x3a, + 0x4d, 0x8a, 0x24, 0x58, 0xe9, 0x1c, 0x7d, 0x5b, 0xf1, 0x13, 0xae, 0xca, 0xe0, 0xce, + 0x27, 0x9f, 0xdd, 0xa7, 0x6b, 0xa6, 0x90, 0x78, 0x7d, 0x26, 0x34, 0x5e, 0x94, 0xc3, + 0xed, 0xbc, 0x16, 0xa3, 0x5c, 0x83, 0xc4, 0xd0, 0x71, 0xb1, 0x32, 0xdd, 0x81, 0x18, + 0x7b, 0xcd, 0x99, 0x61, 0x32, 0x30, 0x11, 0x50, 0x9c, 0x8f, 0x64, 0x4a, 0x1c, 0x0a, + 0x3f, 0x14, 0xee, 0x40, 0xd7, 0xdd, 0x18, 0x6f, 0x80, 0x7f, 0x9e, 0xdc, 0x7c, 0x02, + 0xf6, 0x76, 0x10, 0x61, 0xbb, 0xb6, 0xdd, 0x91, 0xa6, 0xc9, 0x6e, 0xc0, 0xb9, 0xf1, + 0x0e, 0xdb, 0xbd, 0x29, 0xdc, 0x52 + ], + [ + 0x02, 0x53, 0x5d, 0x86, 0xcc, 0x75, 0x18, 0x48, 0x4a, 0x2a, 0x23, 0x8c, 0x92, 0x1b, + 0x73, 0x9b, 0x17, 0x04, 0xa5, 0x03, 0x70, 0xa2, 0x92, 0x4a, 0xbf, 0x39, 0x95, 0x8c, + 0x59, 0x76, 0xe6, 0x58, 0xdc, 0x5e, 0x87, 0x44, 0x00, 0x63, 0x11, 0x24, 0x59, 0xbd, + 0xdb, 0x40, 0x30, 0x8b, 0x1c, 0x70 + ] + ); + + digest_test!( + sha3_512_0, + sha3_512, + 512, + [0u8; 0], + [ + 0xa6, 0x9f, 0x73, 0xcc, 0xa2, 0x3a, 0x9a, 0xc5, 0xc8, 0xb5, 0x67, 0xdc, 0x18, 0x5a, + 0x75, 0x6e, 0x97, 0xc9, 0x82, 0x16, 0x4f, 0xe2, 0x58, 0x59, 0xe0, 0xd1, 0xdc, 0xc1, + 0x47, 0x5c, 0x80, 0xa6, 0x15, 0xb2, 0x12, 0x3a, 0xf1, 0xf5, 0xf9, 0x4c, 0x11, 0xe3, + 0xe9, 0x40, 0x2c, 0x3a, 0xc5, 0x58, 0xf5, 0x00, 0x19, 0x9d, 0x95, 0xb6, 0xd3, 0xe3, + 0x01, 0x75, 0x85, 0x86, 0x28, 0x1d, 0xcd, 0x26, + ] + ); + + digest_test!( + sha3_512_8, + sha3_512, + 512, + [0xe5u8], + [ + 0x15, 0x02, 0x40, 0xba, 0xf9, 0x5f, 0xb3, 0x6f, 0x8c, 0xcb, 0x87, 0xa1, 0x9a, 0x41, + 0x76, 0x7e, 0x7a, 0xed, 0x95, 0x12, 0x50, 0x75, 0xa2, 0xb2, 0xdb, 0xba, 0x6e, 0x56, + 0x5e, 0x1c, 0xe8, 0x57, 0x5f, 0x2b, 0x04, 0x2b, 0x62, 0xe2, 0x9a, 0x04, 0xe9, 0x44, + 0x03, 0x14, 0xa8, 0x21, 0xc6, 0x22, 0x41, 0x82, 0x96, 0x4d, 0x8b, 0x55, 0x7b, 0x16, + 0xa4, 0x92, 0xb3, 0x80, 0x6f, 0x4c, 0x39, 0xc1 + ] + ); + + digest_test!( + sha3_512_4080, + sha3_512, + 512, + [ + 0x43, 0x02, 0x56, 0x15, 0x52, 0x1d, 0x66, 0xfe, 0x8e, 0xc3, 0xa3, 0xf8, 0xcc, 0xc5, + 0xab, 0xfa, 0xb8, 0x70, 0xa4, 0x62, 0xc6, 0xb3, 0xd1, 0x39, 0x6b, 0x84, 0x62, 0xb9, + 0x8c, 0x7f, 0x91, 0x0c, 0x37, 0xd0, 0xea, 0x57, 0x91, 0x54, 0xea, 0xf7, 0x0f, 0xfb, + 0xcc, 0x0b, 0xe9, 0x71, 0xa0, 0x32, 0xcc, 0xfd, 0x9d, 0x96, 0xd0, 0xa9, 0xb8, 0x29, + 0xa9, 0xa3, 0x76, 0x2e, 0x21, 0xe3, 0xfe, 0xfc, 0xc6, 0x0e, 0x72, 0xfe, 0xdf, 0x9a, + 0x7f, 0xff, 0xa5, 0x34, 0x33, 0xa4, 0xb0, 0x5e, 0x0f, 0x3a, 0xb0, 0x5d, 0x5e, 0xb2, + 0x5d, 0x52, 0xc5, 0xea, 0xb1, 0xa7, 0x1a, 0x2f, 0x54, 0xac, 0x79, 0xff, 0x58, 0x82, + 0x95, 0x13, 0x26, 0x39, 0x4d, 0x9d, 0xb8, 0x35, 0x80, 0xce, 0x09, 0xd6, 0x21, 0x9b, + 0xca, 0x58, 0x8e, 0xc1, 0x57, 0xf7, 0x1d, 0x06, 0xe9, 0x57, 0xf8, 0xc2, 0x0d, 0x24, + 0x2c, 0x9f, 0x55, 0xf5, 0xfc, 0x9d, 0x4d, 0x77, 0x7b, 0x59, 0xb0, 0xc7, 0x5a, 0x8e, + 0xdc, 0x1f, 0xfe, 0xdc, 0x84, 0xb5, 0xd5, 0xc8, 0xa5, 0xe0, 0xeb, 0x05, 0xbb, 0x7d, + 0xb8, 0xf2, 0x34, 0x91, 0x3d, 0x63, 0x25, 0x30, 0x4f, 0xa4, 0x3c, 0x9d, 0x32, 0xbb, + 0xf6, 0xb2, 0x69, 0xee, 0x11, 0x82, 0xcd, 0x85, 0x45, 0x3e, 0xdd, 0xd1, 0x2f, 0x55, + 0x55, 0x6d, 0x8e, 0xdf, 0x02, 0xc4, 0xb1, 0x3c, 0xd4, 0xd3, 0x30, 0xf8, 0x35, 0x31, + 0xdb, 0xf2, 0x99, 0x4c, 0xf0, 0xbe, 0x56, 0xf5, 0x91, 0x47, 0xb7, 0x1f, 0x74, 0xb9, + 0x4b, 0xe3, 0xdd, 0x9e, 0x83, 0xc8, 0xc9, 0x47, 0x7c, 0x42, 0x6c, 0x6d, 0x1a, 0x78, + 0xde, 0x18, 0x56, 0x4a, 0x12, 0xc0, 0xd9, 0x93, 0x07, 0xb2, 0xc9, 0xab, 0x42, 0xb6, + 0xe3, 0x31, 0x7b, 0xef, 0xca, 0x07, 0x97, 0x02, 0x9e, 0x9d, 0xd6, 0x7b, 0xd1, 0x73, + 0x4e, 0x6c, 0x36, 0xd9, 0x98, 0x56, 0x5b, 0xfa, 0xc9, 0x4d, 0x19, 0x18, 0xa3, 0x58, + 0x69, 0x19, 0x0d, 0x17, 0x79, 0x43, 0xc1, 0xa8, 0x00, 0x44, 0x45, 0xca, 0xce, 0x75, + 0x1c, 0x43, 0xa7, 0x5f, 0x3d, 0x80, 0x51, 0x7f, 0xc4, 0x7c, 0xec, 0x46, 0xe8, 0xe3, + 0x82, 0x64, 0x2d, 0x76, 0xdf, 0x46, 0xda, 0xb1, 0xa3, 0xdd, 0xae, 0xab, 0x95, 0xa2, + 0xcf, 0x3f, 0x3a, 0xd7, 0x03, 0x69, 0xa7, 0x0f, 0x22, 0xf2, 0x93, 0xf0, 0xcc, 0x50, + 0xb0, 0x38, 0x57, 0xc8, 0x3c, 0xfe, 0x0b, 0xd5, 0xd2, 0x3b, 0x92, 0xcd, 0x87, 0x88, + 0xaa, 0xc2, 0x32, 0x29, 0x1d, 0xa6, 0x0b, 0x4b, 0xf3, 0xb3, 0x78, 0x8a, 0xe6, 0x0a, + 0x23, 0xb6, 0x16, 0x9b, 0x50, 0xd7, 0xfe, 0x44, 0x6e, 0x6e, 0xa7, 0x3d, 0xeb, 0xfe, + 0x1b, 0xb3, 0x4d, 0xcb, 0x1d, 0xb3, 0x7f, 0xe2, 0x17, 0x4a, 0x68, 0x59, 0x54, 0xeb, + 0xc2, 0xd8, 0x6f, 0x10, 0x2a, 0x59, 0x0c, 0x24, 0x73, 0x2b, 0xc5, 0xa1, 0x40, 0x3d, + 0x68, 0x76, 0xd2, 0x99, 0x5f, 0xab, 0x1e, 0x2f, 0x6f, 0x47, 0x23, 0xd4, 0xa6, 0x72, + 0x7a, 0x8a, 0x8e, 0xd7, 0x2f, 0x02, 0xa7, 0x4c, 0xcf, 0x5f, 0x14, 0xb5, 0xc2, 0x3d, + 0x95, 0x25, 0xdb, 0xf2, 0xb5, 0x47, 0x2e, 0x13, 0x45, 0xfd, 0x22, 0x3b, 0x08, 0x46, + 0xc7, 0x07, 0xb0, 0x65, 0x69, 0x65, 0x09, 0x40, 0x65, 0x0f, 0x75, 0x06, 0x3b, 0x52, + 0x98, 0x14, 0xe5, 0x14, 0x54, 0x1a, 0x67, 0x15, 0xf8, 0x79, 0xa8, 0x75, 0xb4, 0xf0, + 0x80, 0x77, 0x51, 0x78, 0x12, 0x84, 0x1e, 0x6c, 0x5c, 0x73, 0x2e, 0xed, 0x0c, 0x07, + 0xc0, 0x85, 0x95, 0xb9, 0xff, 0x0a, 0x83, 0xb8, 0xec, 0xc6, 0x0b, 0x2f, 0x98, 0xd4, + 0xe7, 0xc6, 0x96, 0xcd, 0x61, 0x6b, 0xb0, 0xa5, 0xad, 0x52, 0xd9, 0xcf, 0x7b, 0x3a, + 0x63, 0xa8, 0xcd, 0xf3, 0x72, 0x12 + ], + [ + 0xe7, 0xba, 0x73, 0x40, 0x7a, 0xa4, 0x56, 0xae, 0xce, 0x21, 0x10, 0x77, 0xd9, 0x20, + 0x87, 0xd5, 0xcd, 0x28, 0x3e, 0x38, 0x68, 0xd2, 0x84, 0xe0, 0x7e, 0xd1, 0x24, 0xb2, + 0x7c, 0xbc, 0x66, 0x4a, 0x6a, 0x47, 0x5a, 0x8d, 0x7b, 0x4c, 0xf6, 0xa8, 0xa4, 0x92, + 0x7e, 0xe0, 0x59, 0xa2, 0x62, 0x6a, 0x4f, 0x98, 0x39, 0x23, 0x36, 0x01, 0x45, 0xb2, + 0x65, 0xeb, 0xfd, 0x4f, 0x5b, 0x3c, 0x44, 0xfd + ] + ); +} diff --git a/src/ciphers/transposition.rs b/src/ciphers/transposition.rs index 2fc0d352a91..d5b2a75196e 100644 --- a/src/ciphers/transposition.rs +++ b/src/ciphers/transposition.rs @@ -5,22 +5,22 @@ //! original message. The most commonly referred to Transposition Cipher is the //! COLUMNAR TRANSPOSITION cipher, which is demonstrated below. -use std::ops::Range; +use std::ops::RangeInclusive; /// Encrypts or decrypts a message, using multiple keys. The /// encryption is based on the columnar transposition method. pub fn transposition(decrypt_mode: bool, msg: &str, key: &str) -> String { - let key_uppercase: String = key.to_uppercase(); + let key_uppercase = key.to_uppercase(); let mut cipher_msg: String = msg.to_string(); - let keys: Vec<&str> = match decrypt_mode { - false => key_uppercase.split_whitespace().collect(), - true => key_uppercase.split_whitespace().rev().collect(), + let keys: Vec<&str> = if decrypt_mode { + key_uppercase.split_whitespace().rev().collect() + } else { + key_uppercase.split_whitespace().collect() }; for cipher_key in keys.iter() { let mut key_order: Vec = Vec::new(); - let mut counter: u8 = 0; // Removes any non-alphabet characters from 'msg' cipher_msg = cipher_msg @@ -36,10 +36,9 @@ pub fn transposition(decrypt_mode: bool, msg: &str, key: &str) -> String { key_ascii.sort_by_key(|&(_, key)| key); - key_ascii.iter_mut().for_each(|(_, key)| { - *key = counter; - counter += 1; - }); + for (counter, (_, key)) in key_ascii.iter_mut().enumerate() { + *key = counter as u8; + } key_ascii.sort_by_key(|&(index, _)| index); @@ -49,9 +48,10 @@ pub fn transposition(decrypt_mode: bool, msg: &str, key: &str) -> String { // Determines whether to encrypt or decrypt the message, // and returns the result - cipher_msg = match decrypt_mode { - false => encrypt(cipher_msg, key_order), - true => decrypt(cipher_msg, key_order), + cipher_msg = if decrypt_mode { + decrypt(cipher_msg, key_order) + } else { + encrypt(cipher_msg, key_order) }; } @@ -63,7 +63,7 @@ fn encrypt(mut msg: String, key_order: Vec) -> String { let mut encrypted_msg: String = String::from(""); let mut encrypted_vec: Vec = Vec::new(); - let msg_len: usize = msg.len(); + let msg_len = msg.len(); let key_len: usize = key_order.len(); let mut msg_index: usize = msg_len; @@ -77,7 +77,7 @@ fn encrypt(mut msg: String, key_order: Vec) -> String { // Loop every nth character, determined by key length, to create a column while index < msg_index { - let ch: char = msg.remove(index); + let ch = msg.remove(index); chars.push(ch); index += key_index; @@ -91,18 +91,16 @@ fn encrypt(mut msg: String, key_order: Vec) -> String { // alphabetical order of the keyword's characters let mut indexed_vec: Vec<(usize, &String)> = Vec::new(); let mut indexed_msg: String = String::from(""); - let mut counter: usize = 0; - key_order.into_iter().for_each(|key_index| { + for (counter, key_index) in key_order.into_iter().enumerate() { indexed_vec.push((key_index, &encrypted_vec[counter])); - counter += 1; - }); + } indexed_vec.sort(); - indexed_vec.into_iter().for_each(|(_, column)| { + for (_, column) in indexed_vec { indexed_msg.push_str(column); - }); + } // Split the message by a space every nth character, determined by // 'message length divided by keyword length' to the next highest integer. @@ -127,7 +125,7 @@ fn decrypt(mut msg: String, key_order: Vec) -> String { let mut decrypted_vec: Vec = Vec::new(); let mut indexed_vec: Vec<(usize, String)> = Vec::new(); - let msg_len: usize = msg.len(); + let msg_len = msg.len(); let key_len: usize = key_order.len(); // Split the message into columns, determined by 'message length divided by keyword length'. @@ -144,8 +142,8 @@ fn decrypt(mut msg: String, key_order: Vec) -> String { split_large.iter_mut().rev().for_each(|key_index| { counter -= 1; - let range: Range = - ((*key_index * split_size) + counter)..(((*key_index + 1) * split_size) + counter + 1); + let range: RangeInclusive = + ((*key_index * split_size) + counter)..=(((*key_index + 1) * split_size) + counter); let slice: String = msg[range.clone()].to_string(); indexed_vec.push((*key_index, slice)); @@ -153,19 +151,19 @@ fn decrypt(mut msg: String, key_order: Vec) -> String { msg.replace_range(range, ""); }); - split_small.iter_mut().for_each(|key_index| { + for key_index in split_small.iter_mut() { let (slice, rest_of_msg) = msg.split_at(split_size); indexed_vec.push((*key_index, (slice.to_string()))); msg = rest_of_msg.to_string(); - }); + } indexed_vec.sort(); - key_order.into_iter().for_each(|key| { + for key in key_order { if let Some((_, column)) = indexed_vec.iter().find(|(key_index, _)| key_index == &key) { - decrypted_vec.push(column.to_string()); + decrypted_vec.push(column.clone()); } - }); + } // Concatenate the columns into a string, determined by the // alphabetical order of the keyword's characters diff --git a/src/ciphers/xor.rs b/src/ciphers/xor.rs index fe97f315957..a01351611da 100644 --- a/src/ciphers/xor.rs +++ b/src/ciphers/xor.rs @@ -35,7 +35,7 @@ mod tests { #[test] fn test_zero_byte() { let test_string = "The quick brown fox jumps over the lazy dog"; - let key = ' ' as u8; + let key = b' '; let ciphered_text = xor(test_string, key); assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, key)); } @@ -43,7 +43,7 @@ mod tests { #[test] fn test_invalid_byte() { let test_string = "The quick brown fox jumps over the lazy dog"; - let key = !0 as u8; + let key = !0; let ciphered_text = xor(test_string, key); assert_eq!(test_string.as_bytes(), xor_bytes(&ciphered_text, key)); } diff --git a/src/compression/mod.rs b/src/compression/mod.rs new file mode 100644 index 00000000000..7acbee56ec5 --- /dev/null +++ b/src/compression/mod.rs @@ -0,0 +1,5 @@ +mod move_to_front; +mod run_length_encoding; + +pub use self::move_to_front::{move_to_front_decode, move_to_front_encode}; +pub use self::run_length_encoding::{run_length_decode, run_length_encode}; diff --git a/src/compression/move_to_front.rs b/src/compression/move_to_front.rs new file mode 100644 index 00000000000..fe38b02ef7c --- /dev/null +++ b/src/compression/move_to_front.rs @@ -0,0 +1,60 @@ +// https://en.wikipedia.org/wiki/Move-to-front_transform + +fn blank_char_table() -> Vec { + (0..=255).map(|ch| ch as u8 as char).collect() +} + +pub fn move_to_front_encode(text: &str) -> Vec { + let mut char_table = blank_char_table(); + let mut result = Vec::new(); + + for ch in text.chars() { + if let Some(position) = char_table.iter().position(|&x| x == ch) { + result.push(position as u8); + char_table.remove(position); + char_table.insert(0, ch); + } + } + + result +} + +pub fn move_to_front_decode(encoded: &[u8]) -> String { + let mut char_table = blank_char_table(); + let mut result = String::new(); + + for &pos in encoded { + let ch = char_table[pos as usize]; + result.push(ch); + char_table.remove(pos as usize); + char_table.insert(0, ch); + } + + result +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! test_mtf { + ($($name:ident: ($text:expr, $encoded:expr),)*) => { + $( + #[test] + fn $name() { + assert_eq!(move_to_front_encode($text), $encoded); + assert_eq!(move_to_front_decode($encoded), $text); + } + )* + } + } + + test_mtf! { + empty: ("", &[]), + single_char: ("@", &[64]), + repeated_chars: ("aaba", &[97, 0, 98, 1]), + mixed_chars: ("aZ!", &[97, 91, 35]), + word: ("banana", &[98, 98, 110, 1, 1, 1]), + special_chars: ("\0\n\t", &[0, 10, 10]), + } +} diff --git a/src/compression/run_length_encoding.rs b/src/compression/run_length_encoding.rs new file mode 100644 index 00000000000..03b554b0a3e --- /dev/null +++ b/src/compression/run_length_encoding.rs @@ -0,0 +1,74 @@ +// https://en.wikipedia.org/wiki/Run-length_encoding + +pub fn run_length_encode(text: &str) -> Vec<(char, i32)> { + let mut count = 1; + let mut encoded: Vec<(char, i32)> = vec![]; + + for (i, c) in text.chars().enumerate() { + if i + 1 < text.len() && c == text.chars().nth(i + 1).unwrap() { + count += 1; + } else { + encoded.push((c, count)); + count = 1; + } + } + + encoded +} + +pub fn run_length_decode(encoded: &[(char, i32)]) -> String { + let res = encoded + .iter() + .map(|x| (x.0).to_string().repeat(x.1 as usize)) + .collect::(); + + res +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_run_length_decode() { + let res = run_length_decode(&[('A', 0)]); + assert_eq!(res, ""); + let res = run_length_decode(&[('B', 1)]); + assert_eq!(res, "B"); + let res = run_length_decode(&[('A', 5), ('z', 3), ('B', 1)]); + assert_eq!(res, "AAAAAzzzB"); + } + + #[test] + fn test_run_length_encode() { + let res = run_length_encode(""); + assert_eq!(res, []); + + let res = run_length_encode("A"); + assert_eq!(res, [('A', 1)]); + + let res = run_length_encode("AA"); + assert_eq!(res, [('A', 2)]); + + let res = run_length_encode("AAAABBBCCDAA"); + assert_eq!(res, [('A', 4), ('B', 3), ('C', 2), ('D', 1), ('A', 2)]); + + let res = run_length_encode("Rust-Trends"); + assert_eq!( + res, + [ + ('R', 1), + ('u', 1), + ('s', 1), + ('t', 1), + ('-', 1), + ('T', 1), + ('r', 1), + ('e', 1), + ('n', 1), + ('d', 1), + ('s', 1) + ] + ); + } +} diff --git a/src/conversions/binary_to_decimal.rs b/src/conversions/binary_to_decimal.rs new file mode 100644 index 00000000000..2be7b0f9f07 --- /dev/null +++ b/src/conversions/binary_to_decimal.rs @@ -0,0 +1,90 @@ +use num_traits::CheckedAdd; + +pub fn binary_to_decimal(binary: &str) -> Option { + if binary.len() > 128 { + return None; + } + let mut num = 0; + let mut idx_val = 1; + for bit in binary.chars().rev() { + match bit { + '1' => { + if let Some(sum) = num.checked_add(&idx_val) { + num = sum; + } else { + return None; + } + } + '0' => {} + _ => return None, + } + idx_val <<= 1; + } + Some(num) +} + +#[cfg(test)] +mod tests { + use super::binary_to_decimal; + + #[test] + fn basic_binary_to_decimal() { + assert_eq!(binary_to_decimal("0000000110"), Some(6)); + assert_eq!(binary_to_decimal("1000011110"), Some(542)); + assert_eq!(binary_to_decimal("1111111111"), Some(1023)); + } + #[test] + fn big_binary_to_decimal() { + assert_eq!( + binary_to_decimal("111111111111111111111111"), + Some(16_777_215) + ); + // 32 bits + assert_eq!( + binary_to_decimal("11111111111111111111111111111111"), + Some(4_294_967_295) + ); + // 64 bits + assert_eq!( + binary_to_decimal("1111111111111111111111111111111111111111111111111111111111111111"), + Some(18_446_744_073_709_551_615u128) + ); + } + #[test] + fn very_big_binary_to_decimal() { + // 96 bits + assert_eq!( + binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 11111111111111111111111111111111" + ), + Some(79_228_162_514_264_337_593_543_950_335u128) + ); + + // 128 bits + assert_eq!( + binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111111" + ), + Some(340_282_366_920_938_463_463_374_607_431_768_211_455u128) + ); + // 129 bits, should overflow + assert!(binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 11111111111111111111111111111111111111111111111111111111111111111" + ) + .is_none()); + // obviously none + assert!(binary_to_decimal( + "1111111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111\ + 1111111111111111111111111111111111111111111111111111111111111" + ) + .is_none()); + } +} diff --git a/src/conversions/binary_to_hexadecimal.rs b/src/conversions/binary_to_hexadecimal.rs new file mode 100644 index 00000000000..0a4e52291a6 --- /dev/null +++ b/src/conversions/binary_to_hexadecimal.rs @@ -0,0 +1,106 @@ +// Author : cyrixninja +// Binary to Hex Converter : Converts Binary to Hexadecimal +// Wikipedia References : 1. https://en.wikipedia.org/wiki/Hexadecimal +// 2. https://en.wikipedia.org/wiki/Binary_number + +static BITS_TO_HEX: &[(u8, &str)] = &[ + (0b0000, "0"), + (0b0001, "1"), + (0b0010, "2"), + (0b0011, "3"), + (0b0100, "4"), + (0b0101, "5"), + (0b0110, "6"), + (0b0111, "7"), + (0b1000, "8"), + (0b1001, "9"), + (0b1010, "a"), + (0b1011, "b"), + (0b1100, "c"), + (0b1101, "d"), + (0b1110, "e"), + (0b1111, "f"), +]; + +pub fn binary_to_hexadecimal(binary_str: &str) -> String { + let binary_str = binary_str.trim(); + + if binary_str.is_empty() { + return String::from("Invalid Input"); + } + + let is_negative = binary_str.starts_with('-'); + let binary_str = if is_negative { + &binary_str[1..] + } else { + binary_str + }; + + if !binary_str.chars().all(|c| c == '0' || c == '1') { + return String::from("Invalid Input"); + } + + let padded_len = (4 - (binary_str.len() % 4)) % 4; + let binary_str = format!( + "{:0width$}", + binary_str, + width = binary_str.len() + padded_len + ); + + // Convert binary to hexadecimal + let mut hexadecimal = String::with_capacity(binary_str.len() / 4 + 2); + hexadecimal.push_str("0x"); + + for chunk in binary_str.as_bytes().chunks(4) { + let mut nibble = 0; + for (i, &byte) in chunk.iter().enumerate() { + nibble |= (byte - b'0') << (3 - i); + } + + let hex_char = BITS_TO_HEX + .iter() + .find(|&&(bits, _)| bits == nibble) + .map(|&(_, hex)| hex) + .unwrap(); + hexadecimal.push_str(hex_char); + } + + if is_negative { + format!("-{hexadecimal}") + } else { + hexadecimal + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_string() { + let input = ""; + let expected = "Invalid Input"; + assert_eq!(binary_to_hexadecimal(input), expected); + } + + #[test] + fn test_invalid_binary() { + let input = "a"; + let expected = "Invalid Input"; + assert_eq!(binary_to_hexadecimal(input), expected); + } + + #[test] + fn test_binary() { + let input = "00110110"; + let expected = "0x36"; + assert_eq!(binary_to_hexadecimal(input), expected); + } + + #[test] + fn test_padded_binary() { + let input = " 1010 "; + let expected = "0xa"; + assert_eq!(binary_to_hexadecimal(input), expected); + } +} diff --git a/src/conversions/decimal_to_binary.rs b/src/conversions/decimal_to_binary.rs new file mode 100644 index 00000000000..40e180b2914 --- /dev/null +++ b/src/conversions/decimal_to_binary.rs @@ -0,0 +1,26 @@ +pub fn decimal_to_binary(base_num: u64) -> String { + let mut num = base_num; + let mut binary_num = String::new(); + loop { + let bit = (num % 2).to_string(); + binary_num.push_str(&bit); + num /= 2; + if num == 0 { + break; + } + } + + let bits = binary_num.chars(); + bits.rev().collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn converting_decimal_to_binary() { + assert_eq!(decimal_to_binary(542), "1000011110"); + assert_eq!(decimal_to_binary(92), "1011100"); + } +} diff --git a/src/conversions/decimal_to_hexadecimal.rs b/src/conversions/decimal_to_hexadecimal.rs new file mode 100644 index 00000000000..57ff3e11b71 --- /dev/null +++ b/src/conversions/decimal_to_hexadecimal.rs @@ -0,0 +1,56 @@ +pub fn decimal_to_hexadecimal(base_num: u64) -> String { + let mut num = base_num; + let mut hexadecimal_num = String::new(); + + loop { + let remainder = num % 16; + let hex_char = if remainder < 10 { + (remainder as u8 + b'0') as char + } else { + (remainder as u8 - 10 + b'A') as char + }; + + hexadecimal_num.insert(0, hex_char); + num /= 16; + if num == 0 { + break; + } + } + + hexadecimal_num +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_zero() { + assert_eq!(decimal_to_hexadecimal(0), "0"); + } + + #[test] + fn test_single_digit_decimal() { + assert_eq!(decimal_to_hexadecimal(9), "9"); + } + + #[test] + fn test_single_digit_hexadecimal() { + assert_eq!(decimal_to_hexadecimal(12), "C"); + } + + #[test] + fn test_multiple_digit_hexadecimal() { + assert_eq!(decimal_to_hexadecimal(255), "FF"); + } + + #[test] + fn test_big() { + assert_eq!(decimal_to_hexadecimal(u64::MAX), "FFFFFFFFFFFFFFFF"); + } + + #[test] + fn test_random() { + assert_eq!(decimal_to_hexadecimal(123456), "1E240"); + } +} diff --git a/src/conversions/hexadecimal_to_binary.rs b/src/conversions/hexadecimal_to_binary.rs new file mode 100644 index 00000000000..490b69e8fb0 --- /dev/null +++ b/src/conversions/hexadecimal_to_binary.rs @@ -0,0 +1,67 @@ +// Author : cyrixninja +// Hexadecimal to Binary Converter : Converts Hexadecimal to Binary +// Wikipedia References : 1. https://en.wikipedia.org/wiki/Hexadecimal +// 2. https://en.wikipedia.org/wiki/Binary_number +// Other References for Testing : https://www.rapidtables.com/convert/number/hex-to-binary.html + +pub fn hexadecimal_to_binary(hex_str: &str) -> Result { + let hex_chars = hex_str.chars().collect::>(); + let mut binary = String::new(); + + for c in hex_chars { + let bin_rep = match c { + '0' => "0000", + '1' => "0001", + '2' => "0010", + '3' => "0011", + '4' => "0100", + '5' => "0101", + '6' => "0110", + '7' => "0111", + '8' => "1000", + '9' => "1001", + 'a' | 'A' => "1010", + 'b' | 'B' => "1011", + 'c' | 'C' => "1100", + 'd' | 'D' => "1101", + 'e' | 'E' => "1110", + 'f' | 'F' => "1111", + _ => return Err("Invalid".to_string()), + }; + binary.push_str(bin_rep); + } + + Ok(binary) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_string() { + let input = ""; + let expected = Ok("".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } + + #[test] + fn test_hexadecimal() { + let input = "1a2"; + let expected = Ok("000110100010".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } + #[test] + fn test_hexadecimal2() { + let input = "1b3"; + let expected = Ok("000110110011".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } + + #[test] + fn test_invalid_hexadecimal() { + let input = "1g3"; + let expected = Err("Invalid".to_string()); + assert_eq!(hexadecimal_to_binary(input), expected); + } +} diff --git a/src/conversions/hexadecimal_to_decimal.rs b/src/conversions/hexadecimal_to_decimal.rs new file mode 100644 index 00000000000..5f71716d039 --- /dev/null +++ b/src/conversions/hexadecimal_to_decimal.rs @@ -0,0 +1,61 @@ +pub fn hexadecimal_to_decimal(hexadecimal_str: &str) -> Result { + if hexadecimal_str.is_empty() { + return Err("Empty input"); + } + + for hexadecimal_str in hexadecimal_str.chars() { + if !hexadecimal_str.is_ascii_hexdigit() { + return Err("Input was not a hexadecimal number"); + } + } + + match u64::from_str_radix(hexadecimal_str, 16) { + Ok(decimal) => Ok(decimal), + Err(_e) => Err("Failed to convert octal to hexadecimal"), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hexadecimal_to_decimal_empty() { + assert_eq!(hexadecimal_to_decimal(""), Err("Empty input")); + } + + #[test] + fn test_hexadecimal_to_decimal_invalid() { + assert_eq!( + hexadecimal_to_decimal("xyz"), + Err("Input was not a hexadecimal number") + ); + assert_eq!( + hexadecimal_to_decimal("0xabc"), + Err("Input was not a hexadecimal number") + ); + } + + #[test] + fn test_hexadecimal_to_decimal_valid1() { + assert_eq!(hexadecimal_to_decimal("45"), Ok(69)); + assert_eq!(hexadecimal_to_decimal("2b3"), Ok(691)); + assert_eq!(hexadecimal_to_decimal("4d2"), Ok(1234)); + assert_eq!(hexadecimal_to_decimal("1267a"), Ok(75386)); + } + + #[test] + fn test_hexadecimal_to_decimal_valid2() { + assert_eq!(hexadecimal_to_decimal("1a"), Ok(26)); + assert_eq!(hexadecimal_to_decimal("ff"), Ok(255)); + assert_eq!(hexadecimal_to_decimal("a1b"), Ok(2587)); + assert_eq!(hexadecimal_to_decimal("7fffffff"), Ok(2147483647)); + } + + #[test] + fn test_hexadecimal_to_decimal_valid3() { + assert_eq!(hexadecimal_to_decimal("0"), Ok(0)); + assert_eq!(hexadecimal_to_decimal("7f"), Ok(127)); + assert_eq!(hexadecimal_to_decimal("80000000"), Ok(2147483648)); + } +} diff --git a/src/conversions/length_conversion.rs b/src/conversions/length_conversion.rs new file mode 100644 index 00000000000..4a056ed3052 --- /dev/null +++ b/src/conversions/length_conversion.rs @@ -0,0 +1,94 @@ +/// Author : https://github.com/ali77gh +/// Conversion of length units. +/// +/// Available Units: +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Millimeter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Centimeter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Meter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Kilometer +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Inch +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Foot +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Yard +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Mile + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub enum LengthUnit { + Millimeter, + Centimeter, + Meter, + Kilometer, + Inch, + Foot, + Yard, + Mile, +} + +fn unit_to_meter_multiplier(from: LengthUnit) -> f64 { + match from { + LengthUnit::Millimeter => 0.001, + LengthUnit::Centimeter => 0.01, + LengthUnit::Meter => 1.0, + LengthUnit::Kilometer => 1000.0, + LengthUnit::Inch => 0.0254, + LengthUnit::Foot => 0.3048, + LengthUnit::Yard => 0.9144, + LengthUnit::Mile => 1609.34, + } +} + +fn unit_to_meter(input: f64, from: LengthUnit) -> f64 { + input * unit_to_meter_multiplier(from) +} + +fn meter_to_unit(input: f64, to: LengthUnit) -> f64 { + input / unit_to_meter_multiplier(to) +} + +/// This function will convert a value in unit of [from] to value in unit of [to] +/// by first converting it to meter and than convert it to destination unit +pub fn length_conversion(input: f64, from: LengthUnit, to: LengthUnit) -> f64 { + meter_to_unit(unit_to_meter(input, from), to) +} + +#[cfg(test)] +mod length_conversion_tests { + use std::collections::HashMap; + + use super::LengthUnit::*; + use super::*; + + #[test] + fn zero_to_zero() { + let units = vec![ + Millimeter, Centimeter, Meter, Kilometer, Inch, Foot, Yard, Mile, + ]; + + for u1 in units.clone() { + for u2 in units.clone() { + assert_eq!(length_conversion(0f64, u1, u2), 0f64); + } + } + } + + #[test] + fn length_of_one_meter() { + let meter_in_different_units = HashMap::from([ + (Millimeter, 1000f64), + (Centimeter, 100f64), + (Kilometer, 0.001f64), + (Inch, 39.37007874015748f64), + (Foot, 3.280839895013123f64), + (Yard, 1.0936132983377078f64), + (Mile, 0.0006213727366498068f64), + ]); + for (input_unit, input_value) in &meter_in_different_units { + for (target_unit, target_value) in &meter_in_different_units { + assert!( + num_traits::abs( + length_conversion(*input_value, *input_unit, *target_unit) - *target_value + ) < 0.0000001 + ); + } + } + } +} diff --git a/src/conversions/mod.rs b/src/conversions/mod.rs new file mode 100644 index 00000000000..a83c46bf600 --- /dev/null +++ b/src/conversions/mod.rs @@ -0,0 +1,20 @@ +mod binary_to_decimal; +mod binary_to_hexadecimal; +mod decimal_to_binary; +mod decimal_to_hexadecimal; +mod hexadecimal_to_binary; +mod hexadecimal_to_decimal; +mod length_conversion; +mod octal_to_binary; +mod octal_to_decimal; +mod rgb_cmyk_conversion; +pub use self::binary_to_decimal::binary_to_decimal; +pub use self::binary_to_hexadecimal::binary_to_hexadecimal; +pub use self::decimal_to_binary::decimal_to_binary; +pub use self::decimal_to_hexadecimal::decimal_to_hexadecimal; +pub use self::hexadecimal_to_binary::hexadecimal_to_binary; +pub use self::hexadecimal_to_decimal::hexadecimal_to_decimal; +pub use self::length_conversion::length_conversion; +pub use self::octal_to_binary::octal_to_binary; +pub use self::octal_to_decimal::octal_to_decimal; +pub use self::rgb_cmyk_conversion::rgb_to_cmyk; diff --git a/src/conversions/octal_to_binary.rs b/src/conversions/octal_to_binary.rs new file mode 100644 index 00000000000..ba4a9ccebd2 --- /dev/null +++ b/src/conversions/octal_to_binary.rs @@ -0,0 +1,60 @@ +// Author : cyrixninja +// Octal to Binary Converter : Converts Octal to Binary +// Wikipedia References : 1. https://en.wikipedia.org/wiki/Octal +// 2. https://en.wikipedia.org/wiki/Binary_number + +pub fn octal_to_binary(octal_str: &str) -> Result { + let octal_str = octal_str.trim(); + + if octal_str.is_empty() { + return Err("Empty"); + } + + if !octal_str.chars().all(|c| ('0'..'7').contains(&c)) { + return Err("Non-octal Value"); + } + + // Convert octal to binary + let binary = octal_str + .chars() + .map(|c| match c { + '0' => "000", + '1' => "001", + '2' => "010", + '3' => "011", + '4' => "100", + '5' => "101", + '6' => "110", + '7' => "111", + _ => unreachable!(), + }) + .collect::(); + + Ok(binary) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_string() { + let input = ""; + let expected = Err("Empty"); + assert_eq!(octal_to_binary(input), expected); + } + + #[test] + fn test_invalid_octal() { + let input = "89"; + let expected = Err("Non-octal Value"); + assert_eq!(octal_to_binary(input), expected); + } + + #[test] + fn test_valid_octal() { + let input = "123"; + let expected = Ok("001010011".to_string()); + assert_eq!(octal_to_binary(input), expected); + } +} diff --git a/src/conversions/octal_to_decimal.rs b/src/conversions/octal_to_decimal.rs new file mode 100644 index 00000000000..18ab5076916 --- /dev/null +++ b/src/conversions/octal_to_decimal.rs @@ -0,0 +1,60 @@ +// Author: cyrixninja +// Octal to Decimal Converter: Converts Octal to Decimal +// Wikipedia References: +// 1. https://en.wikipedia.org/wiki/Octal +// 2. https://en.wikipedia.org/wiki/Decimal + +pub fn octal_to_decimal(octal_str: &str) -> Result { + let octal_str = octal_str.trim(); + + if octal_str.is_empty() { + return Err("Empty"); + } + + if !octal_str.chars().all(|c| ('0'..='7').contains(&c)) { + return Err("Non-octal Value"); + } + + // Convert octal to decimal and directly return the Result + u64::from_str_radix(octal_str, 8).map_err(|_| "Conversion error") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_string() { + let input = ""; + let expected = Err("Empty"); + assert_eq!(octal_to_decimal(input), expected); + } + + #[test] + fn test_invalid_octal() { + let input = "89"; + let expected = Err("Non-octal Value"); + assert_eq!(octal_to_decimal(input), expected); + } + + #[test] + fn test_valid_octal() { + let input = "123"; + let expected = Ok(83); + assert_eq!(octal_to_decimal(input), expected); + } + + #[test] + fn test_valid_octal2() { + let input = "1234"; + let expected = Ok(668); + assert_eq!(octal_to_decimal(input), expected); + } + + #[test] + fn test_valid_octal3() { + let input = "12345"; + let expected = Ok(5349); + assert_eq!(octal_to_decimal(input), expected); + } +} diff --git a/src/conversions/rgb_cmyk_conversion.rs b/src/conversions/rgb_cmyk_conversion.rs new file mode 100644 index 00000000000..30a8bc9bd84 --- /dev/null +++ b/src/conversions/rgb_cmyk_conversion.rs @@ -0,0 +1,60 @@ +/// Author : https://github.com/ali77gh\ +/// References:\ +/// RGB: https://en.wikipedia.org/wiki/RGB_color_model\ +/// CMYK: https://en.wikipedia.org/wiki/CMYK_color_model\ + +/// This function Converts RGB to CMYK format +/// +/// ### Params +/// * `r` - red +/// * `g` - green +/// * `b` - blue +/// +/// ### Returns +/// (C, M, Y, K) +pub fn rgb_to_cmyk(rgb: (u8, u8, u8)) -> (u8, u8, u8, u8) { + // Safety: no need to check if input is positive and less than 255 because it's u8 + + // change scale from [0,255] to [0,1] + let (r, g, b) = ( + rgb.0 as f64 / 255f64, + rgb.1 as f64 / 255f64, + rgb.2 as f64 / 255f64, + ); + + match 1f64 - r.max(g).max(b) { + 1f64 => (0, 0, 0, 100), // pure black + k => ( + (100f64 * (1f64 - r - k) / (1f64 - k)) as u8, // c + (100f64 * (1f64 - g - k) / (1f64 - k)) as u8, // m + (100f64 * (1f64 - b - k) / (1f64 - k)) as u8, // y + (100f64 * k) as u8, // k + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_rgb_to_cmyk { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (rgb, cmyk) = $tc; + assert_eq!(rgb_to_cmyk(rgb), cmyk); + } + )* + } + } + + test_rgb_to_cmyk! { + white: ((255, 255, 255), (0, 0, 0, 0)), + gray: ((128, 128, 128), (0, 0, 0, 49)), + black: ((0, 0, 0), (0, 0, 0, 100)), + red: ((255, 0, 0), (0, 100, 100, 0)), + green: ((0, 255, 0), (100, 0, 100, 0)), + blue: ((0, 0, 255), (100, 100, 0, 0)), + } +} diff --git a/src/data_structures/avl_tree.rs b/src/data_structures/avl_tree.rs index 8c7cd887e66..64800a405ca 100644 --- a/src/data_structures/avl_tree.rs +++ b/src/data_structures/avl_tree.rs @@ -365,7 +365,7 @@ mod tests { #[test] fn sorted() { let tree: AVLTree<_> = (1..8).rev().collect(); - assert!((1..8).eq(tree.iter().map(|&x| x))); + assert!((1..8).eq(tree.iter().copied())); } #[test] diff --git a/src/data_structures/b_tree.rs b/src/data_structures/b_tree.rs index a17a85215d3..bf7266932ae 100644 --- a/src/data_structures/b_tree.rs +++ b/src/data_structures/b_tree.rs @@ -100,19 +100,18 @@ impl BTreeProps { self.insert_non_full(&mut node.children[u_index], key); } } - - fn traverse_node(&self, node: &Node, depth: usize) { + fn traverse_node(node: &Node, depth: usize) { if node.is_leaf() { print!(" {0:{<1$}{2:?}{0:}<1$} ", "", depth, node.keys); } else { let _depth = depth + 1; for (index, key) in node.keys.iter().enumerate() { - self.traverse_node(&node.children[index], _depth); + Self::traverse_node(&node.children[index], _depth); // Check https://doc.rust-lang.org/std/fmt/index.html // And https://stackoverflow.com/a/35280799/2849127 print!("{0:{<1$}{2:?}{0:}<1$}", "", depth, key); } - self.traverse_node(node.children.last().unwrap(), _depth); + Self::traverse_node(node.children.last().unwrap(), _depth); } } } @@ -141,26 +140,21 @@ where } pub fn traverse(&self) { - self.props.traverse_node(&self.root, 0); + BTreeProps::traverse_node(&self.root, 0); println!(); } pub fn search(&self, key: T) -> bool { let mut current_node = &self.root; - let mut index: isize; loop { - index = isize::try_from(current_node.keys.len()).ok().unwrap() - 1; - while index >= 0 && current_node.keys[index as usize] > key { - index -= 1; - } - - let u_index: usize = usize::try_from(index + 1).ok().unwrap(); - if index >= 0 && current_node.keys[u_index - 1] == key { - break true; - } else if current_node.is_leaf() { - break false; - } else { - current_node = ¤t_node.children[u_index]; + match current_node.keys.binary_search(&key) { + Ok(_) => return true, + Err(index) => { + if current_node.is_leaf() { + return false; + } + current_node = ¤t_node.children[index]; + } } } } @@ -170,19 +164,46 @@ where mod test { use super::BTree; - #[test] - fn test_search() { - let mut tree = BTree::new(2); - tree.insert(10); - tree.insert(20); - tree.insert(30); - tree.insert(5); - tree.insert(6); - tree.insert(7); - tree.insert(11); - tree.insert(12); - tree.insert(15); - assert!(tree.search(15)); - assert_eq!(tree.search(16), false); + macro_rules! test_search { + ($($name:ident: $number_of_children:expr,)*) => { + $( + #[test] + fn $name() { + let mut tree = BTree::new($number_of_children); + tree.insert(10); + tree.insert(20); + tree.insert(30); + tree.insert(5); + tree.insert(6); + tree.insert(7); + tree.insert(11); + tree.insert(12); + tree.insert(15); + assert!(!tree.search(4)); + assert!(tree.search(5)); + assert!(tree.search(6)); + assert!(tree.search(7)); + assert!(!tree.search(8)); + assert!(!tree.search(9)); + assert!(tree.search(10)); + assert!(tree.search(11)); + assert!(tree.search(12)); + assert!(!tree.search(13)); + assert!(!tree.search(14)); + assert!(tree.search(15)); + assert!(!tree.search(16)); + } + )* + } + } + + test_search! { + children_2: 2, + children_3: 3, + children_4: 4, + children_5: 5, + children_10: 10, + children_60: 60, + children_101: 101, } } diff --git a/src/data_structures/binary_search_tree.rs b/src/data_structures/binary_search_tree.rs index 9470dd0f8f3..193fb485408 100644 --- a/src/data_structures/binary_search_tree.rs +++ b/src/data_structures/binary_search_tree.rs @@ -34,7 +34,7 @@ where } } - /// Find a value in this tree. Returns True iff value is in this + /// Find a value in this tree. Returns True if value is in this /// tree, and false otherwise pub fn search(&self, value: &T) -> bool { match &self.value { @@ -71,26 +71,22 @@ where /// Insert a value into the appropriate location in this tree. pub fn insert(&mut self, value: T) { - if self.value.is_none() { - self.value = Some(value); - } else { - match &self.value { - None => (), - Some(key) => { - let target_node = if value < *key { - &mut self.left - } else { - &mut self.right - }; - match target_node { - Some(ref mut node) => { - node.insert(value); - } - None => { - let mut node = BinarySearchTree::new(); - node.insert(value); - *target_node = Some(Box::new(node)); - } + match &self.value { + None => self.value = Some(value), + Some(key) => { + let target_node = if value < *key { + &mut self.left + } else { + &mut self.right + }; + match target_node { + Some(ref mut node) => { + node.insert(value); + } + None => { + let mut node = BinarySearchTree::new(); + node.value = Some(value); + *target_node = Some(Box::new(node)); } } } @@ -188,7 +184,7 @@ where stack: Vec<&'a BinarySearchTree>, } -impl<'a, T> BinarySearchTreeIter<'a, T> +impl BinarySearchTreeIter<'_, T> where T: Ord, { diff --git a/src/data_structures/fenwick_tree.rs b/src/data_structures/fenwick_tree.rs index 5066d2ccb9b..c4b9c571de4 100644 --- a/src/data_structures/fenwick_tree.rs +++ b/src/data_structures/fenwick_tree.rs @@ -1,75 +1,264 @@ -use std::ops::{Add, AddAssign}; +use std::ops::{Add, AddAssign, Sub, SubAssign}; -/// Fenwick Tree / Binary Indexed Tree -/// Consider we have an array arr[0 . . . n-1]. We would like to -/// 1. Compute the sum of the first i elements. -/// 2. Modify the value of a specified element of the array arr[i] = x where 0 <= i <= n-1.Fenwick tree -pub struct FenwickTree { +/// A Fenwick Tree (also known as a Binary Indexed Tree) that supports efficient +/// prefix sum, range sum and point queries, as well as point updates. +/// +/// The Fenwick Tree uses **1-based** indexing internally but presents a **0-based** interface to the user. +/// This design improves efficiency and simplifies both internal operations and external usage. +pub struct FenwickTree +where + T: Add + AddAssign + Sub + SubAssign + Copy + Default, +{ + /// Internal storage of the Fenwick Tree. The first element (index 0) is unused + /// to simplify index calculations, so the effective tree size is `data.len() - 1`. data: Vec, } -impl + AddAssign + Copy + Default> FenwickTree { - /// construct a new FenwickTree with given length - pub fn with_len(len: usize) -> Self { +/// Enum representing the possible errors that can occur during FenwickTree operations. +#[derive(Debug, PartialEq, Eq)] +pub enum FenwickTreeError { + /// Error indicating that an index was out of the valid range. + IndexOutOfBounds, + /// Error indicating that a provided range was invalid (e.g., left > right). + InvalidRange, +} + +impl FenwickTree +where + T: Add + AddAssign + Sub + SubAssign + Copy + Default, +{ + /// Creates a new Fenwick Tree with a specified capacity. + /// + /// The tree will have `capacity + 1` elements, all initialized to the default + /// value of type `T`. The additional element allows for 1-based indexing internally. + /// + /// # Arguments + /// + /// * `capacity` - The number of elements the tree can hold (excluding the extra element). + /// + /// # Returns + /// + /// A new `FenwickTree` instance. + pub fn with_capacity(capacity: usize) -> Self { FenwickTree { - data: vec![T::default(); len + 1], + data: vec![T::default(); capacity + 1], } } - /// add `val` to `idx` - pub fn add(&mut self, i: usize, val: T) { - assert!(i < self.data.len()); - let mut i = i + 1; - while i < self.data.len() { - self.data[i] += val; - i += lowbit(i); + /// Updates the tree by adding a value to the element at a specified index. + /// + /// This operation also propagates the update to subsequent elements in the tree. + /// + /// # Arguments + /// + /// * `index` - The zero-based index where the value should be added. + /// * `value` - The value to add to the element at the specified index. + /// + /// # Returns + /// + /// A `Result` indicating success (`Ok`) or an error (`FenwickTreeError::IndexOutOfBounds`) + /// if the index is out of bounds. + pub fn update(&mut self, index: usize, value: T) -> Result<(), FenwickTreeError> { + if index >= self.data.len() - 1 { + return Err(FenwickTreeError::IndexOutOfBounds); + } + + let mut idx = index + 1; + while idx < self.data.len() { + self.data[idx] += value; + idx += lowbit(idx); } + + Ok(()) } - /// get the sum of [0, i] - pub fn prefix_sum(&self, i: usize) -> T { - assert!(i < self.data.len()); - let mut i = i + 1; - let mut res = T::default(); - while i > 0 { - res += self.data[i]; - i -= lowbit(i); + /// Computes the sum of elements from the start of the tree up to a specified index. + /// + /// This operation efficiently calculates the prefix sum using the tree structure. + /// + /// # Arguments + /// + /// * `index` - The zero-based index up to which the sum should be computed. + /// + /// # Returns + /// + /// A `Result` containing the prefix sum (`Ok(sum)`) or an error (`FenwickTreeError::IndexOutOfBounds`) + /// if the index is out of bounds. + pub fn prefix_query(&self, index: usize) -> Result { + if index >= self.data.len() - 1 { + return Err(FenwickTreeError::IndexOutOfBounds); + } + + let mut idx = index + 1; + let mut result = T::default(); + while idx > 0 { + result += self.data[idx]; + idx -= lowbit(idx); } - res + + Ok(result) + } + + /// Computes the sum of elements within a specified range `[left, right]`. + /// + /// This operation calculates the range sum by performing two prefix sum queries. + /// + /// # Arguments + /// + /// * `left` - The zero-based starting index of the range. + /// * `right` - The zero-based ending index of the range. + /// + /// # Returns + /// + /// A `Result` containing the range sum (`Ok(sum)`) or an error (`FenwickTreeError::InvalidRange`) + /// if the left index is greater than the right index or the right index is out of bounds. + pub fn range_query(&self, left: usize, right: usize) -> Result { + if left > right || right >= self.data.len() - 1 { + return Err(FenwickTreeError::InvalidRange); + } + + let right_query = self.prefix_query(right)?; + let left_query = if left == 0 { + T::default() + } else { + self.prefix_query(left - 1)? + }; + + Ok(right_query - left_query) + } + + /// Retrieves the value at a specific index by isolating it from the prefix sum. + /// + /// This operation determines the value at `index` by subtracting the prefix sum up to `index - 1` + /// from the prefix sum up to `index`. + /// + /// # Arguments + /// + /// * `index` - The zero-based index of the element to retrieve. + /// + /// # Returns + /// + /// A `Result` containing the value at the specified index (`Ok(value)`) or an error (`FenwickTreeError::IndexOutOfBounds`) + /// if the index is out of bounds. + pub fn point_query(&self, index: usize) -> Result { + if index >= self.data.len() - 1 { + return Err(FenwickTreeError::IndexOutOfBounds); + } + + let index_query = self.prefix_query(index)?; + let prev_query = if index == 0 { + T::default() + } else { + self.prefix_query(index - 1)? + }; + + Ok(index_query - prev_query) + } + + /// Sets the value at a specific index in the tree, updating the structure accordingly. + /// + /// This operation updates the value at `index` by computing the difference between the + /// desired value and the current value, then applying that difference using `update`. + /// + /// # Arguments + /// + /// * `index` - The zero-based index of the element to set. + /// * `value` - The new value to set at the specified index. + /// + /// # Returns + /// + /// A `Result` indicating success (`Ok`) or an error (`FenwickTreeError::IndexOutOfBounds`) + /// if the index is out of bounds. + pub fn set(&mut self, index: usize, value: T) -> Result<(), FenwickTreeError> { + self.update(index, value - self.point_query(index)?) } } -/// get the lowest bit of `i` +/// Computes the lowest set bit (rightmost `1` bit) of a number. +/// +/// This function isolates the lowest set bit in the binary representation of `x`. +/// It's used to navigate the Fenwick Tree by determining the next index to update or query. +/// +/// +/// In a Fenwick Tree, operations like updating and querying use bitwise manipulation +/// (via the lowbit function). These operations naturally align with 1-based indexing, +/// making traversal between parent and child nodes more straightforward. +/// +/// # Arguments +/// +/// * `x` - The input number whose lowest set bit is to be determined. +/// +/// # Returns +/// +/// The value of the lowest set bit in `x`. const fn lowbit(x: usize) -> usize { - let x = x as isize; - (x & (-x)) as usize + x & (!x + 1) } #[cfg(test)] mod tests { use super::*; + #[test] - fn it_works() { - let mut ft = FenwickTree::with_len(10); - ft.add(0, 1); - ft.add(1, 2); - ft.add(2, 3); - ft.add(3, 4); - ft.add(4, 5); - ft.add(5, 6); - ft.add(6, 7); - ft.add(7, 8); - ft.add(8, 9); - ft.add(9, 10); - assert_eq!(ft.prefix_sum(0), 1); - assert_eq!(ft.prefix_sum(1), 3); - assert_eq!(ft.prefix_sum(2), 6); - assert_eq!(ft.prefix_sum(3), 10); - assert_eq!(ft.prefix_sum(4), 15); - assert_eq!(ft.prefix_sum(5), 21); - assert_eq!(ft.prefix_sum(6), 28); - assert_eq!(ft.prefix_sum(7), 36); - assert_eq!(ft.prefix_sum(8), 45); - assert_eq!(ft.prefix_sum(9), 55); + fn test_fenwick_tree() { + let mut fenwick_tree = FenwickTree::with_capacity(10); + + assert_eq!(fenwick_tree.update(0, 5), Ok(())); + assert_eq!(fenwick_tree.update(1, 3), Ok(())); + assert_eq!(fenwick_tree.update(2, -2), Ok(())); + assert_eq!(fenwick_tree.update(3, 6), Ok(())); + assert_eq!(fenwick_tree.update(4, -4), Ok(())); + assert_eq!(fenwick_tree.update(5, 7), Ok(())); + assert_eq!(fenwick_tree.update(6, -1), Ok(())); + assert_eq!(fenwick_tree.update(7, 2), Ok(())); + assert_eq!(fenwick_tree.update(8, -3), Ok(())); + assert_eq!(fenwick_tree.update(9, 4), Ok(())); + assert_eq!(fenwick_tree.set(3, 10), Ok(())); + assert_eq!(fenwick_tree.point_query(3), Ok(10)); + assert_eq!(fenwick_tree.set(5, 0), Ok(())); + assert_eq!(fenwick_tree.point_query(5), Ok(0)); + assert_eq!( + fenwick_tree.update(10, 11), + Err(FenwickTreeError::IndexOutOfBounds) + ); + assert_eq!( + fenwick_tree.set(10, 11), + Err(FenwickTreeError::IndexOutOfBounds) + ); + + assert_eq!(fenwick_tree.prefix_query(0), Ok(5)); + assert_eq!(fenwick_tree.prefix_query(1), Ok(8)); + assert_eq!(fenwick_tree.prefix_query(2), Ok(6)); + assert_eq!(fenwick_tree.prefix_query(3), Ok(16)); + assert_eq!(fenwick_tree.prefix_query(4), Ok(12)); + assert_eq!(fenwick_tree.prefix_query(5), Ok(12)); + assert_eq!(fenwick_tree.prefix_query(6), Ok(11)); + assert_eq!(fenwick_tree.prefix_query(7), Ok(13)); + assert_eq!(fenwick_tree.prefix_query(8), Ok(10)); + assert_eq!(fenwick_tree.prefix_query(9), Ok(14)); + assert_eq!( + fenwick_tree.prefix_query(10), + Err(FenwickTreeError::IndexOutOfBounds) + ); + + assert_eq!(fenwick_tree.range_query(0, 4), Ok(12)); + assert_eq!(fenwick_tree.range_query(3, 7), Ok(7)); + assert_eq!(fenwick_tree.range_query(2, 5), Ok(4)); + assert_eq!( + fenwick_tree.range_query(4, 3), + Err(FenwickTreeError::InvalidRange) + ); + assert_eq!( + fenwick_tree.range_query(2, 10), + Err(FenwickTreeError::InvalidRange) + ); + + assert_eq!(fenwick_tree.point_query(0), Ok(5)); + assert_eq!(fenwick_tree.point_query(4), Ok(-4)); + assert_eq!(fenwick_tree.point_query(9), Ok(4)); + assert_eq!( + fenwick_tree.point_query(10), + Err(FenwickTreeError::IndexOutOfBounds) + ); } } diff --git a/src/data_structures/floyds_algorithm.rs b/src/data_structures/floyds_algorithm.rs new file mode 100644 index 00000000000..b475d07d963 --- /dev/null +++ b/src/data_structures/floyds_algorithm.rs @@ -0,0 +1,95 @@ +// floyds_algorithm.rs +// https://github.com/rust-lang/rust/blob/master/library/alloc/src/collections/linked_list.rs#L113 +// use std::collections::linked_list::LinkedList; +// https://www.reddit.com/r/rust/comments/t7wquc/is_it_possible_to_solve_leetcode_problem141/ + +use crate::data_structures::linked_list::LinkedList; // Import the LinkedList from linked_list.rs + +pub fn detect_cycle(linked_list: &LinkedList) -> Option { + let mut current = linked_list.head; + let mut checkpoint = linked_list.head; + let mut steps_until_reset = 1; + let mut times_reset = 0; + + while let Some(node) = current { + steps_until_reset -= 1; + if steps_until_reset == 0 { + checkpoint = current; + times_reset += 1; + steps_until_reset = 1 << times_reset; // 2^times_reset + } + + unsafe { + let node_ptr = node.as_ptr(); + let next = (*node_ptr).next; + current = next; + } + if current == checkpoint { + return Some(linked_list.length as usize); + } + } + + None +} + +pub fn has_cycle(linked_list: &LinkedList) -> bool { + let mut slow = linked_list.head; + let mut fast = linked_list.head; + + while let (Some(slow_node), Some(fast_node)) = (slow, fast) { + unsafe { + slow = slow_node.as_ref().next; + fast = fast_node.as_ref().next; + + if let Some(fast_next) = fast { + // fast = (*fast_next.as_ptr()).next; + fast = fast_next.as_ref().next; + } else { + return false; // If fast reaches the end, there's no cycle + } + + if slow == fast { + return true; // Cycle detected + } + } + } + // println!("{}", flag); + false // No cycle detected +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_cycle_no_cycle() { + let mut linked_list = LinkedList::new(); + linked_list.insert_at_tail(1); + linked_list.insert_at_tail(2); + linked_list.insert_at_tail(3); + + assert!(!has_cycle(&linked_list)); + + assert_eq!(detect_cycle(&linked_list), None); + } + + #[test] + fn test_detect_cycle_with_cycle() { + let mut linked_list = LinkedList::new(); + linked_list.insert_at_tail(1); + linked_list.insert_at_tail(2); + linked_list.insert_at_tail(3); + + // Create a cycle for testing + unsafe { + if let Some(mut tail) = linked_list.tail { + if let Some(head) = linked_list.head { + tail.as_mut().next = Some(head); + } + } + } + + assert!(has_cycle(&linked_list)); + assert_eq!(detect_cycle(&linked_list), Some(3)); + } +} diff --git a/src/data_structures/graph.rs b/src/data_structures/graph.rs index 5740f82ec24..2bf3e64046b 100644 --- a/src/data_structures/graph.rs +++ b/src/data_structures/graph.rs @@ -135,7 +135,7 @@ mod test_undirected_graph { (&String::from("c"), &String::from("b"), 10), ]; for edge in expected_edges.iter() { - assert_eq!(graph.edges().contains(edge), true); + assert!(graph.edges().contains(edge)); } } @@ -188,7 +188,7 @@ mod test_directed_graph { (&String::from("b"), &String::from("c"), 10), ]; for edge in expected_edges.iter() { - assert_eq!(graph.edges().contains(edge), true); + assert!(graph.edges().contains(edge)); } } @@ -212,9 +212,9 @@ mod test_directed_graph { graph.add_node("a"); graph.add_node("b"); graph.add_node("c"); - assert_eq!(graph.contains("a"), true); - assert_eq!(graph.contains("b"), true); - assert_eq!(graph.contains("c"), true); - assert_eq!(graph.contains("d"), false); + assert!(graph.contains("a")); + assert!(graph.contains("b")); + assert!(graph.contains("c")); + assert!(!graph.contains("d")); } } diff --git a/src/data_structures/hash_table.rs b/src/data_structures/hash_table.rs new file mode 100644 index 00000000000..8eb39bdefb3 --- /dev/null +++ b/src/data_structures/hash_table.rs @@ -0,0 +1,145 @@ +use std::collections::LinkedList; + +pub struct HashTable { + elements: Vec>, + count: usize, +} + +impl Default for HashTable { + fn default() -> Self { + Self::new() + } +} + +pub trait Hashable { + fn hash(&self) -> usize; +} + +impl HashTable { + pub fn new() -> HashTable { + let initial_capacity = 3000; + let mut elements = Vec::with_capacity(initial_capacity); + + for _ in 0..initial_capacity { + elements.push(LinkedList::new()); + } + + HashTable { elements, count: 0 } + } + + pub fn insert(&mut self, key: K, value: V) { + if self.count >= self.elements.len() * 3 / 4 { + self.resize(); + } + let index = key.hash() % self.elements.len(); + self.elements[index].push_back((key, value)); + self.count += 1; + } + + pub fn search(&self, key: K) -> Option<&V> { + let index = key.hash() % self.elements.len(); + self.elements[index] + .iter() + .find(|(k, _)| *k == key) + .map(|(_, v)| v) + } + + fn resize(&mut self) { + let new_size = self.elements.len() * 2; + let mut new_elements = Vec::with_capacity(new_size); + + for _ in 0..new_size { + new_elements.push(LinkedList::new()); + } + + for old_list in self.elements.drain(..) { + for (key, value) in old_list { + let new_index = key.hash() % new_size; + new_elements[new_index].push_back((key, value)); + } + } + + self.elements = new_elements; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, PartialEq, Eq)] + struct TestKey(usize); + + impl Hashable for TestKey { + fn hash(&self) -> usize { + self.0 + } + } + + #[test] + fn test_insert_and_search() { + let mut hash_table = HashTable::new(); + let key = TestKey(1); + let value = TestKey(10); + + hash_table.insert(key, value); + let result = hash_table.search(TestKey(1)); + + assert_eq!(result, Some(&TestKey(10))); + } + + #[test] + fn test_resize() { + let mut hash_table = HashTable::new(); + let initial_capacity = hash_table.elements.capacity(); + + for i in 0..=initial_capacity * 3 / 4 { + hash_table.insert(TestKey(i), TestKey(i + 10)); + } + + assert!(hash_table.elements.capacity() > initial_capacity); + } + + #[test] + fn test_search_nonexistent() { + let mut hash_table = HashTable::new(); + let key = TestKey(1); + let value = TestKey(10); + + hash_table.insert(key, value); + let result = hash_table.search(TestKey(2)); + + assert_eq!(result, None); + } + + #[test] + fn test_multiple_inserts_and_searches() { + let mut hash_table = HashTable::new(); + for i in 0..10 { + hash_table.insert(TestKey(i), TestKey(i + 100)); + } + + for i in 0..10 { + let result = hash_table.search(TestKey(i)); + assert_eq!(result, Some(&TestKey(i + 100))); + } + } + + #[test] + fn test_not_overwrite_existing_key() { + let mut hash_table = HashTable::new(); + hash_table.insert(TestKey(1), TestKey(100)); + hash_table.insert(TestKey(1), TestKey(200)); + + let result = hash_table.search(TestKey(1)); + assert_eq!(result, Some(&TestKey(100))); + } + + #[test] + fn test_empty_search() { + let hash_table: HashTable = HashTable::new(); + let result = hash_table.search(TestKey(1)); + + assert_eq!(result, None); + } +} diff --git a/src/data_structures/heap.rs b/src/data_structures/heap.rs index 03c2b6d1bcb..cb48ff1bbd1 100644 --- a/src/data_structures/heap.rs +++ b/src/data_structures/heap.rs @@ -1,215 +1,333 @@ -// Heap data structure -// Takes a closure as a comparator to allow for min-heap, max-heap, and works with custom key functions +//! A generic heap data structure. +//! +//! This module provides a `Heap` implementation that can function as either a +//! min-heap or a max-heap. It supports common heap operations such as adding, +//! removing, and iterating over elements. The heap can also be created from +//! an unsorted vector and supports custom comparators for flexible sorting +//! behavior. -use std::cmp::Ord; -use std::default::Default; +use std::{cmp::Ord, slice::Iter}; -pub struct Heap -where - T: Default, -{ - count: usize, +/// A heap data structure that can be used as a min-heap, max-heap or with +/// custom comparators. +/// +/// This struct manages a collection of items where the heap property is maintained. +/// The heap can be configured to order elements based on a provided comparator function, +/// allowing for both min-heap and max-heap functionalities, as well as custom sorting orders. +pub struct Heap { items: Vec, comparator: fn(&T, &T) -> bool, } -impl Heap -where - T: Default, -{ +impl Heap { + /// Creates a new, empty heap with a custom comparator function. + /// + /// # Parameters + /// - `comparator`: A function that defines the heap's ordering. + /// + /// # Returns + /// A new `Heap` instance. pub fn new(comparator: fn(&T, &T) -> bool) -> Self { Self { - count: 0, - // Add a default in the first spot to offset indexes - // for the parent/child math to work out. - // Vecs have to have all the same type so using Default - // is a way to add an unused item. - items: vec![T::default()], + items: vec![], comparator, } } + /// Creates a heap from a vector and a custom comparator function. + /// + /// # Parameters + /// - `items`: A vector of items to be turned into a heap. + /// - `comparator`: A function that defines the heap's ordering. + /// + /// # Returns + /// A `Heap` instance with the elements from the provided vector. + pub fn from_vec(items: Vec, comparator: fn(&T, &T) -> bool) -> Self { + let mut heap = Self { items, comparator }; + heap.build_heap(); + heap + } + + /// Constructs the heap from an unsorted vector by applying the heapify process. + fn build_heap(&mut self) { + let last_parent_idx = (self.len() / 2).wrapping_sub(1); + for idx in (0..=last_parent_idx).rev() { + self.heapify_down(idx); + } + } + + /// Returns the number of elements in the heap. + /// + /// # Returns + /// The number of elements in the heap. pub fn len(&self) -> usize { - self.count + self.items.len() } + /// Checks if the heap is empty. + /// + /// # Returns + /// `true` if the heap is empty, `false` otherwise. pub fn is_empty(&self) -> bool { self.len() == 0 } + /// Adds a new element to the heap and maintains the heap property. + /// + /// # Parameters + /// - `value`: The value to add to the heap. pub fn add(&mut self, value: T) { - self.count += 1; self.items.push(value); + self.heapify_up(self.len() - 1); + } + + /// Removes and returns the root element from the heap. + /// + /// # Returns + /// The root element if the heap is not empty, otherwise `None`. + pub fn pop(&mut self) -> Option { + if self.is_empty() { + return None; + } + let next = Some(self.items.swap_remove(0)); + if !self.is_empty() { + self.heapify_down(0); + } + next + } - // Heapify Up - let mut idx = self.count; - while self.parent_idx(idx) > 0 { - let pdx = self.parent_idx(idx); + /// Returns an iterator over the elements in the heap. + /// + /// # Returns + /// An iterator over the elements in the heap, in their internal order. + pub fn iter(&self) -> Iter<'_, T> { + self.items.iter() + } + + /// Moves an element upwards to restore the heap property. + /// + /// # Parameters + /// - `idx`: The index of the element to heapify up. + fn heapify_up(&mut self, mut idx: usize) { + while let Some(pdx) = self.parent_idx(idx) { if (self.comparator)(&self.items[idx], &self.items[pdx]) { self.items.swap(idx, pdx); + idx = pdx; + } else { + break; } - idx = pdx; } } - fn parent_idx(&self, idx: usize) -> usize { - idx / 2 + /// Moves an element downwards to restore the heap property. + /// + /// # Parameters + /// - `idx`: The index of the element to heapify down. + fn heapify_down(&mut self, mut idx: usize) { + while self.children_present(idx) { + let cdx = { + if self.right_child_idx(idx) >= self.len() { + self.left_child_idx(idx) + } else { + let ldx = self.left_child_idx(idx); + let rdx = self.right_child_idx(idx); + if (self.comparator)(&self.items[ldx], &self.items[rdx]) { + ldx + } else { + rdx + } + } + }; + + if (self.comparator)(&self.items[cdx], &self.items[idx]) { + self.items.swap(idx, cdx); + idx = cdx; + } else { + break; + } + } } + /// Returns the index of the parent of the element at `idx`. + /// + /// # Parameters + /// - `idx`: The index of the element. + /// + /// # Returns + /// The index of the parent element if it exists, otherwise `None`. + fn parent_idx(&self, idx: usize) -> Option { + if idx > 0 { + Some((idx - 1) / 2) + } else { + None + } + } + + /// Checks if the element at `idx` has children. + /// + /// # Parameters + /// - `idx`: The index of the element. + /// + /// # Returns + /// `true` if the element has children, `false` otherwise. fn children_present(&self, idx: usize) -> bool { - self.left_child_idx(idx) <= self.count + self.left_child_idx(idx) < self.len() } + /// Returns the index of the left child of the element at `idx`. + /// + /// # Parameters + /// - `idx`: The index of the element. + /// + /// # Returns + /// The index of the left child. fn left_child_idx(&self, idx: usize) -> usize { - idx * 2 + idx * 2 + 1 } + /// Returns the index of the right child of the element at `idx`. + /// + /// # Parameters + /// - `idx`: The index of the element. + /// + /// # Returns + /// The index of the right child. fn right_child_idx(&self, idx: usize) -> usize { self.left_child_idx(idx) + 1 } - - fn smallest_child_idx(&self, idx: usize) -> usize { - if self.right_child_idx(idx) > self.count { - self.left_child_idx(idx) - } else { - let ldx = self.left_child_idx(idx); - let rdx = self.right_child_idx(idx); - if (self.comparator)(&self.items[ldx], &self.items[rdx]) { - ldx - } else { - rdx - } - } - } } impl Heap where - T: Default + Ord, + T: Ord, { - /// Create a new MinHeap - pub fn new_min() -> Self { + /// Creates a new min-heap. + /// + /// # Returns + /// A new `Heap` instance configured as a min-heap. + pub fn new_min() -> Heap { Self::new(|a, b| a < b) } - /// Create a new MaxHeap - pub fn new_max() -> Self { + /// Creates a new max-heap. + /// + /// # Returns + /// A new `Heap` instance configured as a max-heap. + pub fn new_max() -> Heap { Self::new(|a, b| a > b) } -} -impl Iterator for Heap -where - T: Default, -{ - type Item = T; - - fn next(&mut self) -> Option { - if self.count == 0 { - return None; - } - // This feels like a function built for heap impl :) - // Removes an item at an index and fills in with the last item - // of the Vec - let next = Some(self.items.swap_remove(1)); - self.count -= 1; - - if self.count > 0 { - // Heapify Down - let mut idx = 1; - while self.children_present(idx) { - let cdx = self.smallest_child_idx(idx); - if !(self.comparator)(&self.items[idx], &self.items[cdx]) { - self.items.swap(idx, cdx); - } - idx = cdx; - } - } - - next + /// Creates a min-heap from an unsorted vector. + /// + /// # Parameters + /// - `items`: A vector of items to be turned into a min-heap. + /// + /// # Returns + /// A `Heap` instance configured as a min-heap. + pub fn from_vec_min(items: Vec) -> Heap { + Self::from_vec(items, |a, b| a < b) } -} -pub struct MinHeap; - -impl MinHeap { - #[allow(clippy::new_ret_no_self)] - pub fn new() -> Heap - where - T: Default + Ord, - { - Heap::new(|a, b| a < b) - } -} - -pub struct MaxHeap; - -impl MaxHeap { - #[allow(clippy::new_ret_no_self)] - pub fn new() -> Heap - where - T: Default + Ord, - { - Heap::new(|a, b| a > b) + /// Creates a max-heap from an unsorted vector. + /// + /// # Parameters + /// - `items`: A vector of items to be turned into a max-heap. + /// + /// # Returns + /// A `Heap` instance configured as a max-heap. + pub fn from_vec_max(items: Vec) -> Heap { + Self::from_vec(items, |a, b| a > b) } } #[cfg(test)] mod tests { use super::*; + #[test] fn test_empty_heap() { - let mut heap = MaxHeap::new::(); - assert_eq!(heap.next(), None); + let mut heap: Heap = Heap::new_max(); + assert_eq!(heap.pop(), None); } #[test] fn test_min_heap() { - let mut heap = MinHeap::new(); + let mut heap = Heap::new_min(); heap.add(4); heap.add(2); heap.add(9); heap.add(11); assert_eq!(heap.len(), 4); - assert_eq!(heap.next(), Some(2)); - assert_eq!(heap.next(), Some(4)); - assert_eq!(heap.next(), Some(9)); + assert_eq!(heap.pop(), Some(2)); + assert_eq!(heap.pop(), Some(4)); + assert_eq!(heap.pop(), Some(9)); heap.add(1); - assert_eq!(heap.next(), Some(1)); + assert_eq!(heap.pop(), Some(1)); + assert_eq!(heap.pop(), Some(11)); + assert_eq!(heap.pop(), None); } #[test] fn test_max_heap() { - let mut heap = MaxHeap::new(); + let mut heap = Heap::new_max(); heap.add(4); heap.add(2); heap.add(9); heap.add(11); assert_eq!(heap.len(), 4); - assert_eq!(heap.next(), Some(11)); - assert_eq!(heap.next(), Some(9)); - assert_eq!(heap.next(), Some(4)); + assert_eq!(heap.pop(), Some(11)); + assert_eq!(heap.pop(), Some(9)); + assert_eq!(heap.pop(), Some(4)); heap.add(1); - assert_eq!(heap.next(), Some(2)); + assert_eq!(heap.pop(), Some(2)); + assert_eq!(heap.pop(), Some(1)); + assert_eq!(heap.pop(), None); } - struct Point(/* x */ i32, /* y */ i32); - impl Default for Point { - fn default() -> Self { - Self(0, 0) - } + #[test] + fn test_iter_heap() { + let mut heap = Heap::new_min(); + heap.add(4); + heap.add(2); + heap.add(9); + heap.add(11); + + let mut iter = heap.iter(); + assert_eq!(iter.next(), Some(&2)); + assert_eq!(iter.next(), Some(&4)); + assert_eq!(iter.next(), Some(&9)); + assert_eq!(iter.next(), Some(&11)); + assert_eq!(iter.next(), None); + + assert_eq!(heap.len(), 4); + assert_eq!(heap.pop(), Some(2)); + assert_eq!(heap.pop(), Some(4)); + assert_eq!(heap.pop(), Some(9)); + assert_eq!(heap.pop(), Some(11)); + assert_eq!(heap.pop(), None); + } + + #[test] + fn test_from_vec_min() { + let vec = vec![3, 1, 4, 1, 5, 9, 2, 6, 5]; + let mut heap = Heap::from_vec_min(vec); + assert_eq!(heap.len(), 9); + assert_eq!(heap.pop(), Some(1)); + assert_eq!(heap.pop(), Some(1)); + assert_eq!(heap.pop(), Some(2)); + heap.add(0); + assert_eq!(heap.pop(), Some(0)); } #[test] - fn test_key_heap() { - let mut heap: Heap = Heap::new(|a, b| a.0 < b.0); - heap.add(Point(1, 5)); - heap.add(Point(3, 10)); - heap.add(Point(-2, 4)); - assert_eq!(heap.len(), 3); - assert_eq!(heap.next().unwrap().0, -2); - assert_eq!(heap.next().unwrap().0, 1); - heap.add(Point(50, 34)); - assert_eq!(heap.next().unwrap().0, 3); + fn test_from_vec_max() { + let vec = vec![3, 1, 4, 1, 5, 9, 2, 6, 5]; + let mut heap = Heap::from_vec_max(vec); + assert_eq!(heap.len(), 9); + assert_eq!(heap.pop(), Some(9)); + assert_eq!(heap.pop(), Some(6)); + assert_eq!(heap.pop(), Some(5)); + heap.add(10); + assert_eq!(heap.pop(), Some(10)); } } diff --git a/src/data_structures/lazy_segment_tree.rs b/src/data_structures/lazy_segment_tree.rs new file mode 100644 index 00000000000..d34b0d35432 --- /dev/null +++ b/src/data_structures/lazy_segment_tree.rs @@ -0,0 +1,278 @@ +use std::fmt::{Debug, Display}; +use std::ops::{Add, AddAssign, Range}; + +pub struct LazySegmentTree> +{ + len: usize, + tree: Vec, + lazy: Vec>, + merge: fn(T, T) -> T, +} + +impl> LazySegmentTree { + pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self { + let len = arr.len(); + let mut sgtr = LazySegmentTree { + len, + tree: vec![T::default(); 4 * len], + lazy: vec![None; 4 * len], + merge, + }; + if len != 0 { + sgtr.build_recursive(arr, 1, 0..len, merge); + } + sgtr + } + + fn build_recursive( + &mut self, + arr: &[T], + idx: usize, + range: Range, + merge: fn(T, T) -> T, + ) { + if range.end - range.start == 1 { + self.tree[idx] = arr[range.start]; + } else { + let mid = range.start + (range.end - range.start) / 2; + self.build_recursive(arr, 2 * idx, range.start..mid, merge); + self.build_recursive(arr, 2 * idx + 1, mid..range.end, merge); + self.tree[idx] = merge(self.tree[2 * idx], self.tree[2 * idx + 1]); + } + } + + pub fn query(&mut self, range: Range) -> Option { + self.query_recursive(1, 0..self.len, &range) + } + + fn query_recursive( + &mut self, + idx: usize, + element_range: Range, + query_range: &Range, + ) -> Option { + if element_range.start >= query_range.end || element_range.end <= query_range.start { + return None; + } + if self.lazy[idx].is_some() { + self.propagation(idx, &element_range, T::default()); + } + if element_range.start >= query_range.start && element_range.end <= query_range.end { + return Some(self.tree[idx]); + } + let mid = element_range.start + (element_range.end - element_range.start) / 2; + let left = self.query_recursive(idx * 2, element_range.start..mid, query_range); + let right = self.query_recursive(idx * 2 + 1, mid..element_range.end, query_range); + match (left, right) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Some((self.merge)(l, r)), + } + } + + pub fn update(&mut self, target_range: Range, val: T) { + self.update_recursive(1, 0..self.len, &target_range, val); + } + + fn update_recursive( + &mut self, + idx: usize, + element_range: Range, + target_range: &Range, + val: T, + ) { + if element_range.start >= target_range.end || element_range.end <= target_range.start { + return; + } + if element_range.end - element_range.start == 1 { + self.tree[idx] += val; + return; + } + if element_range.start >= target_range.start && element_range.end <= target_range.end { + self.lazy[idx] = match self.lazy[idx] { + Some(lazy) => Some(lazy + val), + None => Some(val), + }; + return; + } + if self.lazy[idx].is_some() && self.lazy[idx].unwrap() != T::default() { + self.propagation(idx, &element_range, T::default()); + } + let mid = element_range.start + (element_range.end - element_range.start) / 2; + self.update_recursive(idx * 2, element_range.start..mid, target_range, val); + self.update_recursive(idx * 2 + 1, mid..element_range.end, target_range, val); + self.tree[idx] = (self.merge)(self.tree[idx * 2], self.tree[idx * 2 + 1]); + self.lazy[idx] = Some(T::default()); + } + + fn propagation(&mut self, idx: usize, element_range: &Range, parent_lazy: T) { + if element_range.end - element_range.start == 1 { + self.tree[idx] += parent_lazy; + return; + } + + let lazy = self.lazy[idx].unwrap_or_default(); + self.lazy[idx] = None; + + let mid = element_range.start + (element_range.end - element_range.start) / 2; + self.propagation(idx * 2, &(element_range.start..mid), parent_lazy + lazy); + self.propagation(idx * 2 + 1, &(mid..element_range.end), parent_lazy + lazy); + self.tree[idx] = (self.merge)(self.tree[idx * 2], self.tree[idx * 2 + 1]); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use quickcheck::TestResult; + use quickcheck_macros::quickcheck; + use std::cmp::{max, min}; + + #[test] + fn test_min_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut min_seg_tree = LazySegmentTree::from_vec(&vec, min); + // [-30, 2, -4, 7, (3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-5), min_seg_tree.query(4..7)); + // [(-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8)] + assert_eq!(Some(-30), min_seg_tree.query(0..vec.len())); + // [(-30, 2), -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-30), min_seg_tree.query(0..2)); + // [-30, (2, -4), 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-4), min_seg_tree.query(1..3)); + // [-30, (2, -4, 7, 3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-5), min_seg_tree.query(1..7)); + } + + #[test] + fn test_max_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut max_seg_tree = LazySegmentTree::from_vec(&vec, max); + // [-30, 2, -4, 7, (3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(6), max_seg_tree.query(4..7)); + // [(-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8)] + assert_eq!(Some(15), max_seg_tree.query(0..vec.len())); + // [(-30, 2), -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(2), max_seg_tree.query(0..2)); + // [-30, (2, -4), 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(2), max_seg_tree.query(1..3)); + // [-30, (2, -4, 7, 3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(7), max_seg_tree.query(1..7)); + } + + #[test] + fn test_sum_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut max_seg_tree = LazySegmentTree::from_vec(&vec, |x, y| x + y); + // [-30, 2, -4, 7, (3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(4), max_seg_tree.query(4..7)); + // [(-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8)] + assert_eq!(Some(7), max_seg_tree.query(0..vec.len())); + // [(-30, 2), -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-28), max_seg_tree.query(0..2)); + // [-30, (2, -4), 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-2), max_seg_tree.query(1..3)); + // [-30, (2, -4, 7, 3, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(9), max_seg_tree.query(1..7)); + } + + #[test] + fn test_update_segments_tiny() { + let vec = vec![0, 0, 0, 0, 0]; + let mut update_seg_tree = LazySegmentTree::from_vec(&vec, |x, y| x + y); + update_seg_tree.update(0..3, 3); + update_seg_tree.update(2..5, 3); + assert_eq!(Some(3), update_seg_tree.query(0..1)); + assert_eq!(Some(3), update_seg_tree.query(1..2)); + assert_eq!(Some(6), update_seg_tree.query(2..3)); + assert_eq!(Some(3), update_seg_tree.query(3..4)); + assert_eq!(Some(3), update_seg_tree.query(4..5)); + } + + #[test] + fn test_update_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut update_seg_tree = LazySegmentTree::from_vec(&vec, |x, y| x + y); + // -> [-30, (5, -1, 10, 6), -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + update_seg_tree.update(1..5, 3); + + // [-30, 5, -1, 10, (6 -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(7), update_seg_tree.query(4..7)); + // [(-30, 5, -1, 10, 6 , -5, 6, 11, -20, 9, 14, 15, 5, 2, -8)] + assert_eq!(Some(19), update_seg_tree.query(0..vec.len())); + // [(-30, 5), -1, 10, 6, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(-25), update_seg_tree.query(0..2)); + // [-30, (5, -1), 10, 6, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(4), update_seg_tree.query(1..3)); + // [-30, (5, -1, 10, 6, -5, 6), 11, -20, 9, 14, 15, 5, 2, -8] + assert_eq!(Some(21), update_seg_tree.query(1..7)); + } + + // Some properties over segment trees: + // When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc. + // When asking for an interval containing a single value, return this value, no matter the merge function + + #[quickcheck] + fn check_overall_interval_min(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, min); + TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len())) + } + + #[quickcheck] + fn check_overall_interval_max(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) + } + + #[quickcheck] + fn check_overall_interval_sum(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) + } + + #[quickcheck] + fn check_single_interval_min(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, min); + for (i, value) in array.into_iter().enumerate() { + let res = seg_tree.query(Range { + start: i, + end: i + 1, + }); + if res != Some(value) { + return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); + } + } + TestResult::passed() + } + + #[quickcheck] + fn check_single_interval_max(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + for (i, value) in array.into_iter().enumerate() { + let res = seg_tree.query(Range { + start: i, + end: i + 1, + }); + if res != Some(value) { + return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); + } + } + TestResult::passed() + } + + #[quickcheck] + fn check_single_interval_sum(array: Vec) -> TestResult { + let mut seg_tree = LazySegmentTree::from_vec(&array, max); + for (i, value) in array.into_iter().enumerate() { + let res = seg_tree.query(Range { + start: i, + end: i + 1, + }); + if res != Some(value) { + return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); + } + } + TestResult::passed() + } +} diff --git a/src/data_structures/linked_list.rs b/src/data_structures/linked_list.rs index adb0f5fdbc7..5f782d82967 100644 --- a/src/data_structures/linked_list.rs +++ b/src/data_structures/linked_list.rs @@ -2,9 +2,9 @@ use std::fmt::{self, Display, Formatter}; use std::marker::PhantomData; use std::ptr::NonNull; -struct Node { - val: T, - next: Option>>, +pub struct Node { + pub val: T, + pub next: Option>>, prev: Option>>, } @@ -19,9 +19,9 @@ impl Node { } pub struct LinkedList { - length: u32, - head: Option>>, - tail: Option>>, + pub length: u32, + pub head: Option>>, + pub tail: Option>>, // Act like we own boxed nodes since we construct and leak them marker: PhantomData>>, } @@ -46,7 +46,7 @@ impl LinkedList { let mut node = Box::new(Node::new(obj)); node.next = self.head; node.prev = None; - let node_ptr = Some(unsafe { NonNull::new_unchecked(Box::into_raw(node)) }); + let node_ptr = NonNull::new(Box::into_raw(node)); match self.head { None => self.tail = node_ptr, Some(head_ptr) => unsafe { (*head_ptr.as_ptr()).prev = node_ptr }, @@ -59,7 +59,7 @@ impl LinkedList { let mut node = Box::new(Node::new(obj)); node.next = None; node.prev = self.tail; - let node_ptr = Some(unsafe { NonNull::new_unchecked(Box::into_raw(node)) }); + let node_ptr = NonNull::new(Box::into_raw(node)); match self.tail { None => self.head = node_ptr, Some(tail_ptr) => unsafe { (*tail_ptr.as_ptr()).next = node_ptr }, @@ -73,7 +73,7 @@ impl LinkedList { panic!("Index out of bounds"); } - if index == 0 || self.head == None { + if index == 0 || self.head.is_none() { self.insert_at_head(obj); return; } @@ -98,9 +98,10 @@ impl LinkedList { node.prev = (*ith_node.as_ptr()).prev; node.next = Some(ith_node); if let Some(p) = (*ith_node.as_ptr()).prev { - let node_ptr = Some(NonNull::new_unchecked(Box::into_raw(node))); + let node_ptr = NonNull::new(Box::into_raw(node)); println!("{:?}", (*p.as_ptr()).next); (*p.as_ptr()).next = node_ptr; + (*ith_node.as_ptr()).prev = node_ptr; self.length += 1; } } @@ -110,6 +111,10 @@ impl LinkedList { pub fn delete_head(&mut self) -> Option { // Safety: head_ptr points to a leaked boxed node managed by this list // We reassign pointers that pointed to the head node + if self.length == 0 { + return None; + } + self.head.map(|head_ptr| unsafe { let old_head = Box::from_raw(head_ptr.as_ptr()); match old_head.next { @@ -117,9 +122,10 @@ impl LinkedList { None => self.tail = None, } self.head = old_head.next; - self.length -= 1; + self.length = self.length.checked_add_signed(-1).unwrap_or(0); old_head.val }) + // None } pub fn delete_tail(&mut self) -> Option { @@ -138,15 +144,15 @@ impl LinkedList { } pub fn delete_ith(&mut self, index: u32) -> Option { - if self.length < index { + if self.length <= index { panic!("Index out of bounds"); } - if index == 0 || self.head == None { + if index == 0 || self.head.is_none() { return self.delete_head(); } - if self.length == index { + if self.length - 1 == index { return self.delete_tail(); } @@ -177,16 +183,16 @@ impl LinkedList { } } - pub fn get(&mut self, index: i32) -> Option<&T> { - self.get_ith_node(self.head, index) + pub fn get(&self, index: i32) -> Option<&T> { + Self::get_ith_node(self.head, index).map(|ptr| unsafe { &(*ptr.as_ptr()).val }) } - fn get_ith_node(&mut self, node: Option>>, index: i32) -> Option<&T> { + fn get_ith_node(node: Option>>, index: i32) -> Option>> { match node { None => None, Some(next_ptr) => match index { - 0 => Some(unsafe { &(*next_ptr.as_ptr()).val }), - _ => self.get_ith_node(unsafe { (*next_ptr.as_ptr()).next }, index - 1), + 0 => Some(next_ptr), + _ => Self::get_ith_node(unsafe { (*next_ptr.as_ptr()).next }, index - 1), }, } } @@ -235,10 +241,10 @@ mod tests { let second_value = 2; list.insert_at_tail(1); list.insert_at_tail(second_value); - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(1) { Some(val) => assert_eq!(*val, second_value), - None => panic!("Expected to find {} at index 1", second_value), + None => panic!("Expected to find {second_value} at index 1"), } } #[test] @@ -247,10 +253,10 @@ mod tests { let second_value = 2; list.insert_at_head(1); list.insert_at_head(second_value); - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(0) { Some(val) => assert_eq!(*val, second_value), - None => panic!("Expected to find {} at index 0", second_value), + None => panic!("Expected to find {second_value} at index 0"), } } @@ -260,10 +266,10 @@ mod tests { let second_value = 2; list.insert_at_ith(0, 0); list.insert_at_ith(1, second_value); - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(1) { Some(val) => assert_eq!(*val, second_value), - None => panic!("Expected to find {} at index 1", second_value), + None => panic!("Expected to find {second_value} at index 1"), } } @@ -273,10 +279,10 @@ mod tests { let second_value = 2; list.insert_at_ith(0, 1); list.insert_at_ith(0, second_value); - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(0) { Some(val) => assert_eq!(*val, second_value), - None => panic!("Expected to find {} at index 0", second_value), + None => panic!("Expected to find {second_value} at index 0"), } } @@ -288,15 +294,45 @@ mod tests { list.insert_at_ith(0, 1); list.insert_at_ith(1, second_value); list.insert_at_ith(1, third_value); - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(1) { Some(val) => assert_eq!(*val, third_value), - None => panic!("Expected to find {} at index 1", third_value), + None => panic!("Expected to find {third_value} at index 1"), } match list.get(2) { Some(val) => assert_eq!(*val, second_value), - None => panic!("Expected to find {} at index 1", second_value), + None => panic!("Expected to find {second_value} at index 1"), + } + } + + #[test] + fn insert_at_ith_and_delete_at_ith_in_the_middle() { + // Insert and delete in the middle of the list to ensure pointers are updated correctly + let mut list = LinkedList::::new(); + let first_value = 0; + let second_value = 1; + let third_value = 2; + let fourth_value = 3; + + list.insert_at_ith(0, first_value); + list.insert_at_ith(1, fourth_value); + list.insert_at_ith(1, third_value); + list.insert_at_ith(1, second_value); + + list.delete_ith(2); + list.insert_at_ith(2, third_value); + + for (i, expected) in [ + (0, first_value), + (1, second_value), + (2, third_value), + (3, fourth_value), + ] { + match list.get(i) { + Some(val) => assert_eq!(*val, expected), + None => panic!("Expected to find {expected} at index {i}"), + } } } @@ -343,13 +379,13 @@ mod tests { list.insert_at_tail(second_value); match list.delete_tail() { Some(val) => assert_eq!(val, 2), - None => panic!("Expected to remove {} at tail", second_value), + None => panic!("Expected to remove {second_value} at tail"), } - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(0) { Some(val) => assert_eq!(*val, first_value), - None => panic!("Expected to find {} at index 0", first_value), + None => panic!("Expected to find {first_value} at index 0"), } } @@ -362,13 +398,13 @@ mod tests { list.insert_at_tail(second_value); match list.delete_head() { Some(val) => assert_eq!(val, 1), - None => panic!("Expected to remove {} at head", first_value), + None => panic!("Expected to remove {first_value} at head"), } - println!("Linked List is {}", list); + println!("Linked List is {list}"); match list.get(0) { Some(val) => assert_eq!(*val, second_value), - None => panic!("Expected to find {} at index 0", second_value), + None => panic!("Expected to find {second_value} at index 0"), } } @@ -381,7 +417,7 @@ mod tests { list.insert_at_tail(second_value); match list.delete_ith(1) { Some(val) => assert_eq!(val, 2), - None => panic!("Expected to remove {} at tail", second_value), + None => panic!("Expected to remove {second_value} at tail"), } assert_eq!(list.length, 1); @@ -396,7 +432,7 @@ mod tests { list.insert_at_tail(second_value); match list.delete_ith(0) { Some(val) => assert_eq!(val, 1), - None => panic!("Expected to remove {} at tail", first_value), + None => panic!("Expected to remove {first_value} at tail"), } assert_eq!(list.length, 1); @@ -413,12 +449,12 @@ mod tests { list.insert_at_tail(third_value); match list.delete_ith(1) { Some(val) => assert_eq!(val, 2), - None => panic!("Expected to remove {} at tail", second_value), + None => panic!("Expected to remove {second_value} at tail"), } match list.get(1) { Some(val) => assert_eq!(*val, third_value), - None => panic!("Expected to find {} at index 1", third_value), + None => panic!("Expected to find {third_value} at index 1"), } } @@ -428,7 +464,7 @@ mod tests { list.insert_at_tail(1); list.insert_at_tail(2); list.insert_at_tail(3); - println!("Linked List is {}", list); + println!("Linked List is {list}"); assert_eq!(3, list.length); } @@ -438,7 +474,7 @@ mod tests { list_str.insert_at_tail("A".to_string()); list_str.insert_at_tail("B".to_string()); list_str.insert_at_tail("C".to_string()); - println!("Linked List is {}", list_str); + println!("Linked List is {list_str}"); assert_eq!(3, list_str.length); } @@ -447,10 +483,10 @@ mod tests { let mut list = LinkedList::::new(); list.insert_at_tail(1); list.insert_at_tail(2); - println!("Linked List is {}", list); + println!("Linked List is {list}"); let retrived_item = list.get(1); assert!(retrived_item.is_some()); - assert_eq!(2 as i32, *retrived_item.unwrap()); + assert_eq!(2, *retrived_item.unwrap()); } #[test] @@ -458,9 +494,19 @@ mod tests { let mut list_str = LinkedList::::new(); list_str.insert_at_tail("A".to_string()); list_str.insert_at_tail("B".to_string()); - println!("Linked List is {}", list_str); + println!("Linked List is {list_str}"); let retrived_item = list_str.get(1); assert!(retrived_item.is_some()); assert_eq!("B", *retrived_item.unwrap()); } + + #[test] + #[should_panic(expected = "Index out of bounds")] + fn delete_ith_panics_if_index_equals_length() { + let mut list = LinkedList::::new(); + list.insert_at_tail(1); + list.insert_at_tail(2); + // length is 2, so index 2 is out of bounds + list.delete_ith(2); + } } diff --git a/src/data_structures/mod.rs b/src/data_structures/mod.rs index 42d25dc3791..621ff290360 100644 --- a/src/data_structures/mod.rs +++ b/src/data_structures/mod.rs @@ -2,27 +2,44 @@ mod avl_tree; mod b_tree; mod binary_search_tree; mod fenwick_tree; -mod graph; +mod floyds_algorithm; +pub mod graph; +mod hash_table; mod heap; +mod lazy_segment_tree; mod linked_list; +mod probabilistic; mod queue; +mod range_minimum_query; mod rb_tree; mod segment_tree; +mod segment_tree_recursive; mod stack_using_singly_linked_list; +mod treap; mod trie; mod union_find; +mod veb_tree; pub use self::avl_tree::AVLTree; pub use self::b_tree::BTree; pub use self::binary_search_tree::BinarySearchTree; pub use self::fenwick_tree::FenwickTree; +pub use self::floyds_algorithm::{detect_cycle, has_cycle}; pub use self::graph::DirectedGraph; pub use self::graph::UndirectedGraph; -pub use self::heap::{Heap, MaxHeap, MinHeap}; +pub use self::hash_table::HashTable; +pub use self::heap::Heap; +pub use self::lazy_segment_tree::LazySegmentTree; pub use self::linked_list::LinkedList; +pub use self::probabilistic::bloom_filter; +pub use self::probabilistic::count_min_sketch; pub use self::queue::Queue; +pub use self::range_minimum_query::RangeMinimumQuery; pub use self::rb_tree::RBTree; pub use self::segment_tree::SegmentTree; +pub use self::segment_tree_recursive::SegmentTree as SegmentTreeRecursive; pub use self::stack_using_singly_linked_list::Stack; +pub use self::treap::Treap; pub use self::trie::Trie; pub use self::union_find::UnionFind; +pub use self::veb_tree::VebTree; diff --git a/src/data_structures/probabilistic/bloom_filter.rs b/src/data_structures/probabilistic/bloom_filter.rs new file mode 100644 index 00000000000..b75fe5b1c90 --- /dev/null +++ b/src/data_structures/probabilistic/bloom_filter.rs @@ -0,0 +1,269 @@ +use std::collections::hash_map::{DefaultHasher, RandomState}; +use std::hash::{BuildHasher, Hash, Hasher}; + +/// A Bloom Filter is a probabilistic data structure testing whether an element belongs to a set or not +/// Therefore, its contract looks very close to the one of a set, for example a `HashSet` +pub trait BloomFilter { + fn insert(&mut self, item: Item); + fn contains(&self, item: &Item) -> bool; +} + +/// What is the point of using a Bloom Filter if it acts like a Set? +/// Let's imagine we have a huge number of elements to store (like un unbounded data stream) a Set storing every element will most likely take up too much space, at some point. +/// As other probabilistic data structures like Count-min Sketch, the goal of a Bloom Filter is to trade off exactitude for constant space. +/// We won't have a strictly exact result of whether the value belongs to the set, but we'll use constant space instead + +/// Let's start with the basic idea behind the implementation +/// Let's start by trying to make a `HashSet` with constant space: +/// Instead of storing every element and grow the set infinitely, let's use a vector with constant capacity `CAPACITY` +/// Each element of this vector will be a boolean. +/// When a new element is inserted, we hash its value and set the index at index `hash(item) % CAPACITY` to `true` +/// When looking for an item, we hash its value and retrieve the boolean at index `hash(item) % CAPACITY` +/// If it's `false` it's absolutely sure the item isn't present +/// If it's `true` the item may be present, or maybe another one produces the same hash +#[derive(Debug)] +struct BasicBloomFilter { + vec: [bool; CAPACITY], +} + +impl Default for BasicBloomFilter { + fn default() -> Self { + Self { + vec: [false; CAPACITY], + } + } +} + +impl BloomFilter for BasicBloomFilter { + fn insert(&mut self, item: Item) { + let mut hasher = DefaultHasher::new(); + item.hash(&mut hasher); + let idx = (hasher.finish() % CAPACITY as u64) as usize; + self.vec[idx] = true; + } + + fn contains(&self, item: &Item) -> bool { + let mut hasher = DefaultHasher::new(); + item.hash(&mut hasher); + let idx = (hasher.finish() % CAPACITY as u64) as usize; + self.vec[idx] + } +} + +/// Can we improve it? Certainly, in different ways. +/// One pattern you may have identified here is that we use a "binary array" (a vector of binary values) +/// For instance, we might have `[0,1,0,0,1,0]`, which is actually the binary representation of 9 +/// This means we can immediately replace our `Vec` by an actual number +/// What would it mean to set a `1` at index `i`? +/// Imagine a `CAPACITY` of `6`. The initial value for our mask is `000000`. +/// We want to store `"Bloom"`. Its hash modulo `CAPACITY` is `5`. Which means we need to set `1` at the last index. +/// It can be performed by doing `000000 | 000001` +/// Meaning we can hash the item value, use a modulo to find the index, and do a binary `or` between the current number and the index +#[allow(dead_code)] +#[derive(Debug, Default)] +struct SingleBinaryBloomFilter { + fingerprint: u128, // let's use 128 bits, the equivalent of using CAPACITY=128 in the previous example +} + +/// Given a value and a hash function, compute the hash and return the bit mask +fn mask_128(hasher: &mut DefaultHasher, item: T) -> u128 { + item.hash(hasher); + let idx = (hasher.finish() % 128) as u32; + // idx is where we want to put a 1, let's convert this into a proper binary mask + 2_u128.pow(idx) +} + +impl BloomFilter for SingleBinaryBloomFilter { + fn insert(&mut self, item: T) { + self.fingerprint |= mask_128(&mut DefaultHasher::new(), &item); + } + + fn contains(&self, item: &T) -> bool { + (self.fingerprint & mask_128(&mut DefaultHasher::new(), item)) > 0 + } +} + +/// We may have made some progress in term of CPU efficiency, using binary operators. +/// But we might still run into a lot of collisions with a single 128-bits number. +/// Can we use greater numbers then? Currently, our implementation is limited to 128 bits. +/// +/// Should we go back to using an array, then? +/// We could! But instead of using `Vec` we could use `Vec`. +/// Each `u8` can act as a mask as we've done before, and is actually 1 byte in memory (same as a boolean!) +/// That'd allow us to go over 128 bits, but would divide by 8 the memory footprint. +/// That's one thing, and will involve dividing / shifting by 8 in different places. +/// +/// But still, can we reduce the collisions furthermore? +/// +/// As we did with count-min-sketch, we could use multiple hash function. +/// When inserting a value, we compute its hash with every hash function (`hash_i`) and perform the same operation as above (the OR with `fingerprint`) +/// Then when looking for a value, if **ANY** of the tests (`hash` then `AND`) returns 0 then this means the value is missing from the set, otherwise it would have returned 1 +/// If it returns `1`, it **may** be that the item is present, but could also be a collision +/// This is what a Bloom Filter is about: returning `false` means the value is necessarily absent, and returning true means it may be present +pub struct MultiBinaryBloomFilter { + filter_size: usize, + bytes: Vec, + hash_builders: Vec, +} + +impl MultiBinaryBloomFilter { + pub fn with_dimensions(filter_size: usize, hash_count: usize) -> Self { + let bytes_count = filter_size / 8 + usize::from(filter_size % 8 > 0); // we need 8 times less entries in the array, since we are using bytes. Careful that we have at least one element though + Self { + filter_size, + bytes: vec![0; bytes_count], + hash_builders: vec![RandomState::new(); hash_count], + } + } + + pub fn from_estimate( + estimated_count_of_items: usize, + max_false_positive_probability: f64, + ) -> Self { + // Check Wikipedia for these formulae + let optimal_filter_size = (-(estimated_count_of_items as f64) + * max_false_positive_probability.ln() + / (2.0_f64.ln().powi(2))) + .ceil() as usize; + let optimal_hash_count = ((optimal_filter_size as f64 / estimated_count_of_items as f64) + * 2.0_f64.ln()) + .ceil() as usize; + Self::with_dimensions(optimal_filter_size, optimal_hash_count) + } +} + +impl BloomFilter for MultiBinaryBloomFilter { + fn insert(&mut self, item: Item) { + for builder in &self.hash_builders { + let mut hasher = builder.build_hasher(); + item.hash(&mut hasher); + let hash = builder.hash_one(&item); + let index = hash % self.filter_size as u64; + let byte_index = index as usize / 8; // this is this byte that we need to modify + let bit_index = (index % 8) as u8; // we cannot only OR with value 1 this time, since we have 8 bits + self.bytes[byte_index] |= 1 << bit_index; + } + } + + fn contains(&self, item: &Item) -> bool { + for builder in &self.hash_builders { + let mut hasher = builder.build_hasher(); + item.hash(&mut hasher); + let hash = builder.hash_one(item); + let index = hash % self.filter_size as u64; + let byte_index = index as usize / 8; // this is this byte that we need to modify + let bit_index = (index % 8) as u8; // we cannot only OR with value 1 this time, since we have 8 bits + if self.bytes[byte_index] & (1 << bit_index) == 0 { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use crate::data_structures::probabilistic::bloom_filter::{ + BasicBloomFilter, BloomFilter, MultiBinaryBloomFilter, SingleBinaryBloomFilter, + }; + use quickcheck::{Arbitrary, Gen}; + use quickcheck_macros::quickcheck; + use std::collections::HashSet; + + #[derive(Debug, Clone)] + struct TestSet { + to_insert: HashSet, + to_test: Vec, + } + + impl Arbitrary for TestSet { + fn arbitrary(g: &mut Gen) -> Self { + let mut qty = usize::arbitrary(g) % 5_000; + if qty < 50 { + qty += 50; // won't be perfectly uniformly distributed + } + let mut to_insert = HashSet::with_capacity(qty); + let mut to_test = Vec::with_capacity(qty); + for _ in 0..(qty) { + to_insert.insert(i32::arbitrary(g)); + to_test.push(i32::arbitrary(g)); + } + TestSet { to_insert, to_test } + } + } + + #[quickcheck] + fn basic_filter_must_not_return_false_negative(TestSet { to_insert, to_test }: TestSet) { + let mut basic_filter = BasicBloomFilter::<10_000>::default(); + for item in &to_insert { + basic_filter.insert(*item); + } + for other in to_test { + if !basic_filter.contains(&other) { + assert!(!to_insert.contains(&other)) + } + } + } + + #[quickcheck] + fn binary_filter_must_not_return_false_negative(TestSet { to_insert, to_test }: TestSet) { + let mut binary_filter = SingleBinaryBloomFilter::default(); + for item in &to_insert { + binary_filter.insert(*item); + } + for other in to_test { + if !binary_filter.contains(&other) { + assert!(!to_insert.contains(&other)) + } + } + } + + #[quickcheck] + fn a_basic_filter_of_capacity_128_is_the_same_as_a_binary_filter( + TestSet { to_insert, to_test }: TestSet, + ) { + let mut basic_filter = BasicBloomFilter::<128>::default(); // change 32 to anything else here, and the test won't pass + let mut binary_filter = SingleBinaryBloomFilter::default(); + for item in &to_insert { + basic_filter.insert(*item); + binary_filter.insert(*item); + } + for other in to_test { + // Since we use the same DefaultHasher::new(), and both have size 32, we should have exactly the same results + assert_eq!( + basic_filter.contains(&other), + binary_filter.contains(&other) + ); + } + } + + const FALSE_POSITIVE_MAX: f64 = 0.05; + + #[quickcheck] + fn a_multi_binary_bloom_filter_must_not_return_false_negatives( + TestSet { to_insert, to_test }: TestSet, + ) { + let n = to_insert.len(); + if n == 0 { + // avoid dividing by 0 when adjusting the size + return; + } + // See Wikipedia for those formula + let mut binary_filter = MultiBinaryBloomFilter::from_estimate(n, FALSE_POSITIVE_MAX); + for item in &to_insert { + binary_filter.insert(*item); + } + let tests = to_test.len(); + let mut false_positives = 0; + for other in to_test { + if !binary_filter.contains(&other) { + assert!(!to_insert.contains(&other)) + } else if !to_insert.contains(&other) { + // false positive + false_positives += 1; + } + } + let fp_rate = false_positives as f64 / tests as f64; + assert!(fp_rate < 1.0); // This isn't really a test, but so that you have the `fp_rate` variable to print out, or evaluate + } +} diff --git a/src/data_structures/probabilistic/count_min_sketch.rs b/src/data_structures/probabilistic/count_min_sketch.rs new file mode 100644 index 00000000000..0aec3bff577 --- /dev/null +++ b/src/data_structures/probabilistic/count_min_sketch.rs @@ -0,0 +1,247 @@ +use std::collections::hash_map::RandomState; +use std::fmt::{Debug, Formatter}; +use std::hash::{BuildHasher, Hash}; + +/// A probabilistic data structure holding an approximate count for diverse items efficiently (using constant space) +/// +/// Let's imagine we want to count items from an incoming (unbounded) data stream +/// One way to do this would be to hold a frequency hashmap, counting element hashes +/// This works extremely well, but unfortunately would require a lot of memory if we have a huge diversity of incoming items in the data stream +/// +/// CountMinSketch aims at solving this problem, trading off the exact count for an approximate one, but getting from potentially unbounded space complexity to constant complexity +/// See the implementation below for more details +/// +/// Here is the definition of the different allowed operations on a CountMinSketch: +/// * increment the count of an item +/// * retrieve the count of an item +pub trait CountMinSketch { + type Item; + + fn increment(&mut self, item: Self::Item); + fn increment_by(&mut self, item: Self::Item, count: usize); + fn get_count(&self, item: Self::Item) -> usize; +} + +/// The common implementation of a CountMinSketch +/// Holding a DEPTH x WIDTH matrix of counts +/// +/// The idea behind the implementation is the following: +/// Let's start from our problem statement above. We have a frequency map of counts, and want to go reduce its space complexity +/// The immediate way to do this would be to use a Vector with a fixed size, let this size be `WIDTH` +/// We will be holding the count of each item `item` in the Vector, at index `i = hash(item) % WIDTH` where `hash` is a hash function: `item -> usize` +/// We now have constant space. +/// +/// The problem though is that we'll potentially run into a lot of collisions. +/// Taking an extreme example, if `WIDTH = 1`, all items will have the same count, which is the sum of counts of every items +/// We could reduce the amount of collisions by using a bigger `WIDTH` but this wouldn't be way more efficient than the "big" frequency map +/// How do we improve the solution, but still keeping constant space? +/// +/// The idea is to use not just one vector, but multiple (`DEPTH`) ones and attach different `hash` functions to each vector +/// This would lead to the following data structure: +/// <- WIDTH = 5 -> +/// D hash1: [0, 0, 0, 0, 0] +/// E hash2: [0, 0, 0, 0, 0] +/// P hash3: [0, 0, 0, 0, 0] +/// T hash4: [0, 0, 0, 0, 0] +/// H hash5: [0, 0, 0, 0, 0] +/// = hash6: [0, 0, 0, 0, 0] +/// 7 hash7: [0, 0, 0, 0, 0] +/// Every hash function must return a different value for the same item. +/// Let's say we hash "TEST" and: +/// hash1("TEST") = 42 => idx = 2 +/// hash2("TEST") = 26 => idx = 1 +/// hash3("TEST") = 10 => idx = 0 +/// hash4("TEST") = 33 => idx = 3 +/// hash5("TEST") = 54 => idx = 4 +/// hash6("TEST") = 11 => idx = 1 +/// hash7("TEST") = 50 => idx = 0 +/// This would lead our structure to become: +/// <- WIDTH = 5 -> +/// D hash1: [0, 0, 1, 0, 0] +/// E hash2: [0, 1, 0, 0, 0] +/// P hash3: [1, 0, 0, 0, 0] +/// T hash4: [0, 0, 0, 1, 0] +/// H hash5: [0, 0, 0, 0, 1] +/// = hash6: [0, 1, 0, 0, 0] +/// 7 hash7: [1, 0, 0, 0, 0] +/// +/// Now say we hash "OTHER" and: +/// hash1("OTHER") = 23 => idx = 3 +/// hash2("OTHER") = 11 => idx = 1 +/// hash3("OTHER") = 52 => idx = 2 +/// hash4("OTHER") = 25 => idx = 0 +/// hash5("OTHER") = 31 => idx = 1 +/// hash6("OTHER") = 24 => idx = 4 +/// hash7("OTHER") = 30 => idx = 0 +/// Leading our data structure to become: +/// <- WIDTH = 5 -> +/// D hash1: [0, 0, 1, 1, 0] +/// E hash2: [0, 2, 0, 0, 0] +/// P hash3: [1, 0, 1, 0, 0] +/// T hash4: [1, 0, 0, 1, 0] +/// H hash5: [0, 1, 0, 0, 1] +/// = hash6: [0, 1, 0, 0, 1] +/// 7 hash7: [2, 0, 0, 0, 0] +/// +/// We actually can witness some collisions (invalid counts of `2` above in some rows). +/// This means that if we have to return the count for "TEST", we'd actually fetch counts from every row and return the minimum value +/// +/// This could potentially be overestimated if we have a huge number of entries and a lot of collisions. +/// But an interesting property is that the count we return for "TEST" cannot be underestimated +pub struct HashCountMinSketch { + phantom: std::marker::PhantomData, // just a marker for Item to be used + counts: [[usize; WIDTH]; DEPTH], + hashers: [RandomState; DEPTH], +} + +impl Debug + for HashCountMinSketch +{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Item").field("vecs", &self.counts).finish() + } +} + +impl Default + for HashCountMinSketch +{ + fn default() -> Self { + let hashers = std::array::from_fn(|_| RandomState::new()); + + Self { + phantom: std::marker::PhantomData, + counts: [[0; WIDTH]; DEPTH], + hashers, + } + } +} + +impl CountMinSketch + for HashCountMinSketch +{ + type Item = Item; + + fn increment(&mut self, item: Self::Item) { + self.increment_by(item, 1) + } + + fn increment_by(&mut self, item: Self::Item, count: usize) { + for (row, r) in self.hashers.iter_mut().enumerate() { + let mut h = r.build_hasher(); + item.hash(&mut h); + let hashed = r.hash_one(&item); + let col = (hashed % WIDTH as u64) as usize; + self.counts[row][col] += count; + } + } + + fn get_count(&self, item: Self::Item) -> usize { + self.hashers + .iter() + .enumerate() + .map(|(row, r)| { + let mut h = r.build_hasher(); + item.hash(&mut h); + let hashed = r.hash_one(&item); + let col = (hashed % WIDTH as u64) as usize; + self.counts[row][col] + }) + .min() + .unwrap() + } +} + +#[cfg(test)] +mod tests { + use crate::data_structures::probabilistic::count_min_sketch::{ + CountMinSketch, HashCountMinSketch, + }; + use quickcheck::{Arbitrary, Gen}; + use std::collections::HashSet; + + #[test] + fn hash_functions_should_hash_differently() { + let mut sketch: HashCountMinSketch<&str, 50, 50> = HashCountMinSketch::default(); // use a big DEPTH + sketch.increment("something"); + // We want to check that our hash functions actually produce different results, so we'll store the indices where we encounter a count=1 in a set + let mut indices_of_ones: HashSet = HashSet::default(); + for counts in sketch.counts { + let ones = counts + .into_iter() + .enumerate() + .filter_map(|(idx, count)| (count == 1).then_some(idx)) + .collect::>(); + assert_eq!(1, ones.len()); + indices_of_ones.insert(ones[0]); + } + // Given the parameters (WIDTH = 50, DEPTH = 50) it's extremely unlikely that all hash functions hash to the same index + assert!(indices_of_ones.len() > 1); // but we want to avoid a bug where all hash functions would produce the same hash (or hash to the same index) + } + + #[test] + fn inspect_counts() { + let mut sketch: HashCountMinSketch<&str, 5, 7> = HashCountMinSketch::default(); + sketch.increment("test"); + // Inspect internal state: + for counts in sketch.counts { + let zeroes = counts.iter().filter(|count| **count == 0).count(); + assert_eq!(4, zeroes); + let ones = counts.iter().filter(|count| **count == 1).count(); + assert_eq!(1, ones); + } + sketch.increment("test"); + for counts in sketch.counts { + let zeroes = counts.iter().filter(|count| **count == 0).count(); + assert_eq!(4, zeroes); + let twos = counts.iter().filter(|count| **count == 2).count(); + assert_eq!(1, twos); + } + + // This one is actually deterministic + assert_eq!(2, sketch.get_count("test")); + } + + #[derive(Debug, Clone, Eq, PartialEq, Hash)] + struct TestItem { + item: String, + count: usize, + } + + const MAX_STR_LEN: u8 = 30; + const MAX_COUNT: usize = 20; + + impl Arbitrary for TestItem { + fn arbitrary(g: &mut Gen) -> Self { + let str_len = u8::arbitrary(g) % MAX_STR_LEN; + let mut str = String::with_capacity(str_len as usize); + for _ in 0..str_len { + str.push(char::arbitrary(g)); + } + let count = usize::arbitrary(g) % MAX_COUNT; + TestItem { item: str, count } + } + } + + #[quickcheck_macros::quickcheck] + fn must_not_understimate_count(test_items: Vec) { + let test_items = test_items.into_iter().collect::>(); // remove duplicated (would lead to weird counts) + let n = test_items.len(); + let mut sketch: HashCountMinSketch = HashCountMinSketch::default(); + let mut exact_count = 0; + for TestItem { item, count } in &test_items { + sketch.increment_by(item.clone(), *count); + } + for TestItem { item, count } in test_items { + let stored_count = sketch.get_count(item); + assert!(stored_count >= count); + if count == stored_count { + exact_count += 1; + } + } + if n > 20 { + // if n is too short, the stat isn't really relevant + let exact_ratio = exact_count as f64 / n as f64; + assert!(exact_ratio > 0.7); // the proof is quite hard, but this should be OK + } + } +} diff --git a/src/data_structures/probabilistic/mod.rs b/src/data_structures/probabilistic/mod.rs new file mode 100644 index 00000000000..de55027f15f --- /dev/null +++ b/src/data_structures/probabilistic/mod.rs @@ -0,0 +1,2 @@ +pub mod bloom_filter; +pub mod count_min_sketch; diff --git a/src/data_structures/queue.rs b/src/data_structures/queue.rs index 3e4ed85334e..a0299155490 100644 --- a/src/data_structures/queue.rs +++ b/src/data_structures/queue.rs @@ -1,3 +1,8 @@ +//! This module provides a generic `Queue` data structure, implemented using +//! Rust's `LinkedList` from the standard library. The queue follows the FIFO +//! (First-In-First-Out) principle, where elements are added to the back of +//! the queue and removed from the front. + use std::collections::LinkedList; #[derive(Debug)] @@ -6,33 +11,50 @@ pub struct Queue { } impl Queue { + // Creates a new empty Queue pub fn new() -> Queue { Queue { elements: LinkedList::new(), } } + // Adds an element to the back of the queue pub fn enqueue(&mut self, value: T) { self.elements.push_back(value) } + // Removes and returns the front element from the queue, or None if empty pub fn dequeue(&mut self) -> Option { self.elements.pop_front() } + // Returns a reference to the front element of the queue, or None if empty pub fn peek_front(&self) -> Option<&T> { self.elements.front() } + // Returns a reference to the back element of the queue, or None if empty + pub fn peek_back(&self) -> Option<&T> { + self.elements.back() + } + + // Returns the number of elements in the queue pub fn len(&self) -> usize { self.elements.len() } + // Checks if the queue is empty pub fn is_empty(&self) -> bool { self.elements.is_empty() } + + // Clears all elements from the queue + pub fn drain(&mut self) { + self.elements.clear(); + } } +// Implementing the Default trait for Queue impl Default for Queue { fn default() -> Queue { Queue::new() @@ -44,35 +66,26 @@ mod tests { use super::Queue; #[test] - fn test_enqueue() { - let mut queue: Queue = Queue::new(); - queue.enqueue(64); - assert_eq!(queue.is_empty(), false); - } - - #[test] - fn test_dequeue() { - let mut queue: Queue = Queue::new(); - queue.enqueue(32); - queue.enqueue(64); - let retrieved_dequeue = queue.dequeue(); - assert_eq!(retrieved_dequeue, Some(32)); - } + fn test_queue_functionality() { + let mut queue: Queue = Queue::default(); - #[test] - fn test_peek_front() { - let mut queue: Queue = Queue::new(); + assert!(queue.is_empty()); queue.enqueue(8); queue.enqueue(16); - let retrieved_peek = queue.peek_front(); - assert_eq!(retrieved_peek, Some(&8)); - } + assert!(!queue.is_empty()); + assert_eq!(queue.len(), 2); - #[test] - fn test_size() { - let mut queue: Queue = Queue::new(); - queue.enqueue(8); - queue.enqueue(16); - assert_eq!(2, queue.len()); + assert_eq!(queue.peek_front(), Some(&8)); + assert_eq!(queue.peek_back(), Some(&16)); + + assert_eq!(queue.dequeue(), Some(8)); + assert_eq!(queue.len(), 1); + assert_eq!(queue.peek_front(), Some(&16)); + assert_eq!(queue.peek_back(), Some(&16)); + + queue.drain(); + assert!(queue.is_empty()); + assert_eq!(queue.len(), 0); + assert_eq!(queue.dequeue(), None); } } diff --git a/src/data_structures/range_minimum_query.rs b/src/data_structures/range_minimum_query.rs new file mode 100644 index 00000000000..8bb74a7a1fe --- /dev/null +++ b/src/data_structures/range_minimum_query.rs @@ -0,0 +1,194 @@ +//! Range Minimum Query (RMQ) Implementation +//! +//! This module provides an efficient implementation of a Range Minimum Query data structure using a +//! sparse table approach. It allows for quick retrieval of the minimum value within a specified subdata +//! of a given data after an initial preprocessing phase. +//! +//! The RMQ is particularly useful in scenarios requiring multiple queries on static data, as it +//! allows querying in constant time after an O(n log(n)) preprocessing time. +//! +//! References: [Wikipedia](https://en.wikipedia.org/wiki/Range_minimum_query) + +use std::cmp::PartialOrd; + +/// Custom error type for invalid range queries. +#[derive(Debug, PartialEq, Eq)] +pub enum RangeError { + /// Indicates that the provided range is invalid (start index is not less than end index). + InvalidRange, + /// Indicates that one or more indices are out of bounds for the data. + IndexOutOfBound, +} + +/// A data structure for efficiently answering range minimum queries on static data. +pub struct RangeMinimumQuery { + /// The original input data on which range queries are performed. + data: Vec, + /// The sparse table for storing preprocessed range minimum information. Each entry + /// contains the index of the minimum element in the range starting at `j` and having a length of `2^i`. + sparse_table: Vec>, +} + +impl RangeMinimumQuery { + /// Creates a new `RangeMinimumQuery` instance with the provided input data. + /// + /// # Arguments + /// + /// * `input` - A slice of elements of type `T` that implement `PartialOrd` and `Copy`. + /// + /// # Returns + /// + /// A `RangeMinimumQuery` instance that can be used to perform range minimum queries. + pub fn new(input: &[T]) -> RangeMinimumQuery { + RangeMinimumQuery { + data: input.to_vec(), + sparse_table: build_sparse_table(input), + } + } + + /// Retrieves the minimum value in the specified range [start, end). + /// + /// # Arguments + /// + /// * `start` - The starting index of the range (inclusive). + /// * `end` - The ending index of the range (exclusive). + /// + /// # Returns + /// + /// * `Ok(T)` - The minimum value found in the specified range. + /// * `Err(RangeError)` - An error indicating the reason for failure, such as an invalid range + /// or indices out of bounds. + pub fn get_range_min(&self, start: usize, end: usize) -> Result { + // Validate range + if start >= end { + return Err(RangeError::InvalidRange); + } + if start >= self.data.len() || end > self.data.len() { + return Err(RangeError::IndexOutOfBound); + } + + // Calculate the log length and the index for the sparse table + let log_len = (end - start).ilog2() as usize; + let idx: usize = end - (1 << log_len); + + // Retrieve the indices of the minimum values from the sparse table + let min_idx_start = self.sparse_table[log_len][start]; + let min_idx_end = self.sparse_table[log_len][idx]; + + // Compare the values at the retrieved indices and return the minimum + if self.data[min_idx_start] < self.data[min_idx_end] { + Ok(self.data[min_idx_start]) + } else { + Ok(self.data[min_idx_end]) + } + } +} + +/// Builds a sparse table for the provided data to support range minimum queries. +/// +/// # Arguments +/// +/// * `data` - A slice of elements of type `T` that implement `PartialOrd`. +/// +/// # Returns +/// +/// A 2D vector representing the sparse table, where each entry contains the index of the minimum +/// element in the range defined by the starting index and the power of two lengths. +fn build_sparse_table(data: &[T]) -> Vec> { + let mut sparse_table: Vec> = vec![(0..data.len()).collect()]; + let len = data.len(); + + // Fill the sparse table + for log_len in 1..=len.ilog2() { + let mut row = Vec::new(); + for idx in 0..=len - (1 << log_len) { + let min_idx_start = sparse_table[sparse_table.len() - 1][idx]; + let min_idx_end = sparse_table[sparse_table.len() - 1][idx + (1 << (log_len - 1))]; + if data[min_idx_start] < data[min_idx_end] { + row.push(min_idx_start); + } else { + row.push(min_idx_end); + } + } + sparse_table.push(row); + } + + sparse_table +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_build_sparse_table { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (data, expected) = $inputs; + assert_eq!(build_sparse_table(&data), expected); + } + )* + } + } + + test_build_sparse_table! { + small: ( + [1, 6, 3], + vec![ + vec![0, 1, 2], + vec![0, 2] + ] + ), + medium: ( + [1, 3, 6, 123, 7, 235, 3, -4, 6, 2], + vec![ + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + vec![0, 1, 2, 4, 4, 6, 7, 7, 9], + vec![0, 1, 2, 6, 7, 7, 7], + vec![7, 7, 7] + ] + ), + large: ( + [20, 13, -13, 2, 3634, -2, 56, 3, 67, 8, 23, 0, -23, 1, 5, 85, 3, 24, 5, -10, 3, 4, 20], + vec![ + vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22], + vec![1, 2, 2, 3, 5, 5, 7, 7, 9, 9, 11, 12, 12, 13, 14, 16, 16, 18, 19, 19, 20, 21], + vec![2, 2, 2, 5, 5, 5, 7, 7, 11, 12, 12, 12, 12, 13, 16, 16, 19, 19, 19, 19], + vec![2, 2, 2, 5, 5, 12, 12, 12, 12, 12, 12, 12, 12, 19, 19, 19], + vec![12, 12, 12, 12, 12, 12, 12, 12] + ] + ), + } + + #[test] + fn simple_query_tests() { + let rmq = RangeMinimumQuery::new(&[1, 3, 6, 123, 7, 235, 3, -4, 6, 2]); + + assert_eq!(rmq.get_range_min(1, 6), Ok(3)); + assert_eq!(rmq.get_range_min(0, 10), Ok(-4)); + assert_eq!(rmq.get_range_min(8, 9), Ok(6)); + assert_eq!(rmq.get_range_min(4, 3), Err(RangeError::InvalidRange)); + assert_eq!(rmq.get_range_min(0, 1000), Err(RangeError::IndexOutOfBound)); + assert_eq!( + rmq.get_range_min(1000, 1001), + Err(RangeError::IndexOutOfBound) + ); + } + + #[test] + fn float_query_tests() { + let rmq = RangeMinimumQuery::new(&[0.4, -2.3, 0.0, 234.22, 12.2, -3.0]); + + assert_eq!(rmq.get_range_min(0, 6), Ok(-3.0)); + assert_eq!(rmq.get_range_min(0, 4), Ok(-2.3)); + assert_eq!(rmq.get_range_min(3, 5), Ok(12.2)); + assert_eq!(rmq.get_range_min(2, 3), Ok(0.0)); + assert_eq!(rmq.get_range_min(4, 3), Err(RangeError::InvalidRange)); + assert_eq!(rmq.get_range_min(0, 1000), Err(RangeError::IndexOutOfBound)); + assert_eq!( + rmq.get_range_min(1000, 1001), + Err(RangeError::IndexOutOfBound) + ); + } +} diff --git a/src/data_structures/rb_tree.rs b/src/data_structures/rb_tree.rs index 1d4b721cb9f..3465ad5d4d3 100644 --- a/src/data_structures/rb_tree.rs +++ b/src/data_structures/rb_tree.rs @@ -215,7 +215,7 @@ impl RBTree { } /* release resource */ - Box::from_raw(node); + drop(Box::from_raw(node)); if matches!(deleted_color, Color::Black) { delete_fixup(self, parent); } @@ -616,7 +616,7 @@ mod tests { #[test] fn find() { let mut tree = RBTree::::new(); - for (k, v) in String::from("hello, world!").chars().enumerate() { + for (k, v) in "hello, world!".chars().enumerate() { tree.insert(k, v); } assert_eq!(*tree.find(&3).unwrap_or(&'*'), 'l'); @@ -628,7 +628,7 @@ mod tests { #[test] fn insert() { let mut tree = RBTree::::new(); - for (k, v) in String::from("hello, world!").chars().enumerate() { + for (k, v) in "hello, world!".chars().enumerate() { tree.insert(k, v); } let s: String = tree.iter().map(|x| x.value).collect(); @@ -638,7 +638,7 @@ mod tests { #[test] fn delete() { let mut tree = RBTree::::new(); - for (k, v) in String::from("hello, world!").chars().enumerate() { + for (k, v) in "hello, world!".chars().enumerate() { tree.insert(k, v); } tree.delete(&1); diff --git a/src/data_structures/segment_tree.rs b/src/data_structures/segment_tree.rs index 8897dd623f6..f569381967e 100644 --- a/src/data_structures/segment_tree.rs +++ b/src/data_structures/segment_tree.rs @@ -1,87 +1,224 @@ -/// This stucture implements a segmented tree that -/// can efficiently answer range queries on arrays. -pub struct SegmentTree { - len: usize, - buf: Vec, - op: Ops, +//! A module providing a Segment Tree data structure for efficient range queries +//! and updates. It supports operations like finding the minimum, maximum, +//! and sum of segments in an array. + +use std::fmt::Debug; +use std::ops::Range; + +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. +#[derive(Debug, PartialEq, Eq)] +pub enum SegmentTreeError { + /// Error indicating that an index is out of bounds. + IndexOutOfBounds, + /// Error indicating that a range provided for a query is invalid. + InvalidRange, } -pub enum Ops { - Max, - Min, +/// A structure representing a Segment Tree. This tree can be used to efficiently +/// perform range queries and updates on an array of elements. +pub struct SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// The length of the input array for which the segment tree is built. + size: usize, + /// A vector representing the segment tree. + nodes: Vec, + /// A merging function defined as a closure or callable type. + merge_fn: F, } -impl SegmentTree { - /// function to build the tree - pub fn from_vec(arr: &[T], op: Ops) -> Self { - let len = arr.len(); - let mut buf: Vec = vec![T::default(); 2 * len]; - buf[len..(len + len)].clone_from_slice(&arr[0..len]); - for i in (1..len).rev() { - buf[i] = match op { - Ops::Max => buf[2 * i].max(buf[2 * i + 1]), - Ops::Min => buf[2 * i].min(buf[2 * i + 1]), - }; +impl SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// Creates a new `SegmentTree` from the provided slice of elements. + /// + /// # Arguments + /// + /// * `arr`: A slice of elements of type `T` to initialize the segment tree. + /// * `merge`: A merging function that defines how to merge two elements of type `T`. + /// + /// # Returns + /// + /// A new `SegmentTree` instance populated with the given elements. + pub fn from_vec(arr: &[T], merge: F) -> Self { + let size = arr.len(); + let mut buffer: Vec = vec![T::default(); 2 * size]; + + // Populate the leaves of the tree + buffer[size..(2 * size)].clone_from_slice(arr); + for idx in (1..size).rev() { + buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]); + } + + SegmentTree { + size, + nodes: buffer, + merge_fn: merge, } - SegmentTree { len, buf, op } } - /// function to get sum on interval [l, r] - pub fn query(&self, mut l: usize, mut r: usize) -> T { - l += self.len; - r += self.len; - let mut res = self.buf[l]; - while l <= r { - if l % 2 == 1 { - res = match self.op { - Ops::Max => res.max(self.buf[l]), - Ops::Min => res.min(self.buf[l]), - }; - l += 1; + /// Queries the segment tree for the result of merging the elements in the given range. + /// + /// # Arguments + /// + /// * `range`: A range specified as `Range`, indicating the start (inclusive) + /// and end (exclusive) indices of the segment to query. + /// + /// # Returns + /// + /// * `Ok(Some(result))` if the query was successful and there are elements in the range, + /// * `Ok(None)` if the range is empty, + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. + pub fn query(&self, range: Range) -> Result, SegmentTreeError> { + if range.start >= self.size || range.end > self.size { + return Err(SegmentTreeError::InvalidRange); + } + + let mut left = range.start + self.size; + let mut right = range.end + self.size; + let mut result = None; + + // Iterate through the segment tree to accumulate results + while left < right { + if left % 2 == 1 { + result = Some(match result { + None => self.nodes[left], + Some(old) => (self.merge_fn)(old, self.nodes[left]), + }); + left += 1; } - if r % 2 == 0 { - res = match self.op { - Ops::Max => res.max(self.buf[r]), - Ops::Min => res.min(self.buf[r]), - }; - r -= 1; + if right % 2 == 1 { + right -= 1; + result = Some(match result { + None => self.nodes[right], + Some(old) => (self.merge_fn)(old, self.nodes[right]), + }); } - l /= 2; - r /= 2; + left /= 2; + right /= 2; } - res + + Ok(result) } - /// function to update a tree node - pub fn update(&mut self, mut idx: usize, val: T) { - idx += self.len; - self.buf[idx] = val; - idx /= 2; - - while idx != 0 { - self.buf[idx] = match self.op { - Ops::Max => self.buf[2 * idx].max(self.buf[2 * idx + 1]), - Ops::Min => self.buf[2 * idx].min(self.buf[2 * idx + 1]), - }; - idx /= 2; + /// Updates the value at the specified index in the segment tree. + /// + /// # Arguments + /// + /// * `idx`: The index (0-based) of the element to update. + /// * `val`: The new value of type `T` to set at the specified index. + /// + /// # Returns + /// + /// * `Ok(())` if the update was successful, + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. + pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> { + if idx >= self.size { + return Err(SegmentTreeError::IndexOutOfBounds); + } + + let mut index = idx + self.size; + if self.nodes[index] == val { + return Ok(()); } + + self.nodes[index] = val; + while index > 1 { + index /= 2; + self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]); + } + + Ok(()) } } #[cfg(test)] mod tests { use super::*; + use std::cmp::{max, min}; + + #[test] + fn test_min_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.update(5, 10), Ok(())); + assert_eq!(min_seg_tree.update(14, -8), Ok(())); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); + assert_eq!( + min_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(min_seg_tree.query(5..5), Ok(None)); + assert_eq!( + min_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + min_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_max_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut max_seg_tree = SegmentTree::from_vec(&vec, max); + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); + assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); + assert_eq!(max_seg_tree.update(4, 10), Ok(())); + assert_eq!(max_seg_tree.update(14, -8), Ok(())); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); + assert_eq!( + max_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(max_seg_tree.query(5..5), Ok(None)); + assert_eq!( + max_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + max_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } #[test] - fn it_works() { + fn test_sum_segments() { let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; - let min_seg_tree = SegmentTree::from_vec(&vec, Ops::Min); - assert_eq!(-5, min_seg_tree.query(4, 6)); - assert_eq!(-20, min_seg_tree.query(0, vec.len() - 1)); - let mut max_seg_tree = SegmentTree::from_vec(&vec, Ops::Max); - assert_eq!(6, max_seg_tree.query(4, 6)); - assert_eq!(15, max_seg_tree.query(0, vec.len() - 1)); - max_seg_tree.update(6, 8); - assert_eq!(8, max_seg_tree.query(4, 6)); + let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); + assert_eq!( + sum_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); + assert_eq!( + sum_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + sum_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); } } diff --git a/src/data_structures/segment_tree_recursive.rs b/src/data_structures/segment_tree_recursive.rs new file mode 100644 index 00000000000..7a64a978563 --- /dev/null +++ b/src/data_structures/segment_tree_recursive.rs @@ -0,0 +1,263 @@ +use std::fmt::Debug; +use std::ops::Range; + +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. +#[derive(Debug, PartialEq, Eq)] +pub enum SegmentTreeError { + /// Error indicating that an index is out of bounds. + IndexOutOfBounds, + /// Error indicating that a range provided for a query is invalid. + InvalidRange, +} + +/// A data structure representing a Segment Tree. Which is used for efficient +/// range queries and updates on an array of elements. +pub struct SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// The number of elements in the original input array for which the segment tree is built. + size: usize, + /// A vector representing the nodes of the segment tree. + nodes: Vec, + /// A function that merges two elements of type `T`. + merge_fn: F, +} + +impl SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// Creates a new `SegmentTree` from the provided slice of elements. + /// + /// # Arguments + /// + /// * `arr`: A slice of elements of type `T` that initializes the segment tree. + /// * `merge_fn`: A merging function that specifies how to combine two elements of type `T`. + /// + /// # Returns + /// + /// A new `SegmentTree` instance initialized with the given elements. + pub fn from_vec(arr: &[T], merge_fn: F) -> Self { + let size = arr.len(); + let mut seg_tree = SegmentTree { + size, + nodes: vec![T::default(); 4 * size], + merge_fn, + }; + if size != 0 { + seg_tree.build_recursive(arr, 1, 0..size); + } + seg_tree + } + + /// Recursively builds the segment tree from the provided array. + /// + /// # Parameters + /// + /// * `arr` - The original array of values. + /// * `node_idx` - The index of the current node in the segment tree. + /// * `node_range` - The range of elements in the original array that the current node covers. + fn build_recursive(&mut self, arr: &[T], node_idx: usize, node_range: Range) { + if node_range.end - node_range.start == 1 { + self.nodes[node_idx] = arr[node_range.start]; + } else { + let mid = node_range.start + (node_range.end - node_range.start) / 2; + self.build_recursive(arr, 2 * node_idx, node_range.start..mid); + self.build_recursive(arr, 2 * node_idx + 1, mid..node_range.end); + self.nodes[node_idx] = + (self.merge_fn)(self.nodes[2 * node_idx], self.nodes[2 * node_idx + 1]); + } + } + + /// Queries the segment tree for the result of merging the elements in the specified range. + /// + /// # Arguments + /// + /// * `target_range`: A range specified as `Range`, indicating the start (inclusive) + /// and end (exclusive) indices of the segment to query. + /// + /// # Returns + /// + /// * `Ok(Some(result))` if the query is successful and there are elements in the range, + /// * `Ok(None)` if the range is empty, + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. + pub fn query(&self, target_range: Range) -> Result, SegmentTreeError> { + if target_range.start >= self.size || target_range.end > self.size { + return Err(SegmentTreeError::InvalidRange); + } + Ok(self.query_recursive(1, 0..self.size, &target_range)) + } + + /// Recursively performs a range query to find the merged result of the specified range. + /// + /// # Parameters + /// + /// * `node_idx` - The index of the current node in the segment tree. + /// * `tree_range` - The range of elements covered by the current node. + /// * `target_range` - The range for which the query is being performed. + /// + /// # Returns + /// + /// An `Option` containing the result of the merge operation on the range if within bounds, + /// or `None` if the range is outside the covered range. + fn query_recursive( + &self, + node_idx: usize, + tree_range: Range, + target_range: &Range, + ) -> Option { + if tree_range.start >= target_range.end || tree_range.end <= target_range.start { + return None; + } + if tree_range.start >= target_range.start && tree_range.end <= target_range.end { + return Some(self.nodes[node_idx]); + } + let mid = tree_range.start + (tree_range.end - tree_range.start) / 2; + let left_res = self.query_recursive(node_idx * 2, tree_range.start..mid, target_range); + let right_res = self.query_recursive(node_idx * 2 + 1, mid..tree_range.end, target_range); + match (left_res, right_res) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Some((self.merge_fn)(l, r)), + } + } + + /// Updates the value at the specified index in the segment tree. + /// + /// # Arguments + /// + /// * `target_idx`: The index (0-based) of the element to update. + /// * `val`: The new value of type `T` to set at the specified index. + /// + /// # Returns + /// + /// * `Ok(())` if the update was successful, + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. + pub fn update(&mut self, target_idx: usize, val: T) -> Result<(), SegmentTreeError> { + if target_idx >= self.size { + return Err(SegmentTreeError::IndexOutOfBounds); + } + self.update_recursive(1, 0..self.size, target_idx, val); + Ok(()) + } + + /// Recursively updates the segment tree for a specific index with a new value. + /// + /// # Parameters + /// + /// * `node_idx` - The index of the current node in the segment tree. + /// * `tree_range` - The range of elements covered by the current node. + /// * `target_idx` - The index in the original array to update. + /// * `val` - The new value to set at `target_idx`. + fn update_recursive( + &mut self, + node_idx: usize, + tree_range: Range, + target_idx: usize, + val: T, + ) { + if tree_range.start > target_idx || tree_range.end <= target_idx { + return; + } + if tree_range.end - tree_range.start <= 1 && tree_range.start == target_idx { + self.nodes[node_idx] = val; + return; + } + let mid = tree_range.start + (tree_range.end - tree_range.start) / 2; + self.update_recursive(node_idx * 2, tree_range.start..mid, target_idx, val); + self.update_recursive(node_idx * 2 + 1, mid..tree_range.end, target_idx, val); + self.nodes[node_idx] = + (self.merge_fn)(self.nodes[node_idx * 2], self.nodes[node_idx * 2 + 1]); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::cmp::{max, min}; + + #[test] + fn test_min_segments() { + let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.update(5, 10), Ok(())); + assert_eq!(min_seg_tree.update(14, -8), Ok(())); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); + assert_eq!( + min_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(min_seg_tree.query(5..5), Ok(None)); + assert_eq!( + min_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + min_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_max_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut max_seg_tree = SegmentTree::from_vec(&vec, max); + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); + assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); + assert_eq!(max_seg_tree.update(4, 10), Ok(())); + assert_eq!(max_seg_tree.update(14, -8), Ok(())); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); + assert_eq!( + max_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(max_seg_tree.query(5..5), Ok(None)); + assert_eq!( + max_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + max_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } + + #[test] + fn test_sum_segments() { + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; + let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); + assert_eq!( + sum_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); + assert_eq!( + sum_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + sum_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); + } +} diff --git a/src/data_structures/stack_using_singly_linked_list.rs b/src/data_structures/stack_using_singly_linked_list.rs index c3a2ad0b169..f3f6db1d5c9 100644 --- a/src/data_structures/stack_using_singly_linked_list.rs +++ b/src/data_structures/stack_using_singly_linked_list.rs @@ -1,4 +1,4 @@ -// the public struct can hide the implementation detail +// The public struct can hide the implementation detail pub struct Stack { head: Link, } @@ -21,8 +21,8 @@ impl Stack { } // we need to return the variant, so there without the ; } - // As we know the primary forms that self can take: self, &mut self and &self, push will change the linked list - // so we need &mut + // Here are the primary forms that self can take are: self, &mut self and &self. + // Since push will modify the linked list, we need a mutable reference `&mut`. // The push method which the signature's first parameter is self pub fn push(&mut self, elem: T) { let new_node = Box::new(Node { @@ -32,17 +32,17 @@ impl Stack { // don't forget replace the head with new node for stack self.head = Some(new_node); } + + /// The pop function removes the head and returns its value. /// - /// In pop function, we trying to: - /// * check if the list is empty, so we use enum Option, it can either be Some(T) or None - /// * if it's empty, return None - /// * if it's not empty - /// * remove the head of the list - /// * remove its elem - /// * replace the list's head with its next - /// * return Some(elem), as the situation if need - /// - /// so, we need to remove the head, and return the value of the head + /// To do so, we'll need to match the `head` of the list, which is of enum type `Option`.\ + /// It has two variants: `Some(T)` and `None`. + /// * `None` - the list is empty: + /// * return an enum `Result` of variant `Err()`, as there is nothing to pop. + /// * `Some(node)` - the list is not empty: + /// * remove the head of the list, + /// * relink the list's head `head` to its following node `next`, + /// * return `Ok(elem)`. pub fn pop(&mut self) -> Result { match self.head.take() { None => Err("Stack is empty"), @@ -54,7 +54,7 @@ impl Stack { } pub fn is_empty(&self) -> bool { - // Returns true if the option is a [None] value. + // Returns true if head is of variant `None`. self.head.is_none() } @@ -95,16 +95,17 @@ impl Default for Stack { } } -/// The drop method of singly linked list. There's a question that do we need to worry about cleaning up our list? -/// As we all know the ownership and borrow mechanism, so we know the type will clean automatically after it goes out the scope, -/// this implement by the Rust compiler automatically did which mean add trait `drop` for the automatically. +/// The drop method of singly linked list. /// -/// So, the complier will implements Drop for `List->Link->Box ->Node` automatically and tail recursive to clean the elements -/// one by one. And we know the recursive will stop at Box -/// https://rust-unofficial.github.io/too-many-lists/first-drop.html +/// Here's a question: *Do we need to worry about cleaning up our list?*\ +/// With the help of the ownership mechanism, the type `List` will be cleaned up automatically (dropped) after it goes out of scope.\ +/// The Rust Compiler does so automacally. In other words, the `Drop` trait is implemented automatically.\ /// -/// As we know we can't drop the contents of the Box after deallocating, so we need to manually write the iterative drop - +/// The `Drop` trait is implemented for our type `List` with the following order: `List->Link->Box->Node`.\ +/// The `.drop()` method is tail recursive and will clean the element one by one, this recursion will stop at `Box`\ +/// +/// +/// We wouldn't be able to drop the contents contained by the box after deallocating, so we need to manually write the iterative drop. impl Drop for Stack { fn drop(&mut self) { let mut cur_link = self.head.take(); @@ -117,7 +118,7 @@ impl Drop for Stack { } } -/// Rust has nothing like a yield statement, and there's actually 3 different kinds of iterator should to implement +// Rust has nothing like a yield statement, and there are actually 3 different iterator traits to be implemented // Collections are iterated in Rust using the Iterator trait, we define a struct implement Iterator pub struct IntoIter(Stack); @@ -181,7 +182,7 @@ mod test_stack { list.push(4); list.push(5); - assert_eq!(list.is_empty(), false); + assert!(!list.is_empty()); assert_eq!(list.pop(), Ok(5)); assert_eq!(list.pop(), Ok(4)); @@ -189,7 +190,7 @@ mod test_stack { assert_eq!(list.pop(), Ok(1)); assert_eq!(list.pop(), Err("Stack is empty")); - assert_eq!(list.is_empty(), true); + assert!(list.is_empty()); } #[test] @@ -204,8 +205,8 @@ mod test_stack { assert_eq!(list.peek_mut(), Some(&mut 3)); match list.peek_mut() { - None => None, - Some(value) => Some(*value = 42), + None => (), + Some(value) => *value = 42, }; assert_eq!(list.peek(), Some(&42)); diff --git a/src/data_structures/treap.rs b/src/data_structures/treap.rs new file mode 100644 index 00000000000..e78e782a66d --- /dev/null +++ b/src/data_structures/treap.rs @@ -0,0 +1,352 @@ +use std::{ + cmp::Ordering, + iter::FromIterator, + mem, + ops::Not, + time::{SystemTime, UNIX_EPOCH}, +}; + +/// An internal node of an `Treap`. +struct TreapNode { + value: T, + priority: usize, + left: Option>>, + right: Option>>, +} + +/// A set based on a Treap (Randomized Binary Search Tree). +/// +/// A Treap is a self-balancing binary search tree. It matains a priority value for each node, such +/// that for every node, its children will have lower priority than itself. So, by just looking at +/// the priority, it is like a heap, and this is where the name, Treap, comes from, Tree + Heap. +pub struct Treap { + root: Option>>, + length: usize, +} + +/// Refers to the left or right subtree of a `Treap`. +#[derive(Clone, Copy)] +enum Side { + Left, + Right, +} + +impl Treap { + pub fn new() -> Treap { + Treap { + root: None, + length: 0, + } + } + + /// Returns `true` if the tree contains a value. + pub fn contains(&self, value: &T) -> bool { + let mut current = &self.root; + while let Some(node) = current { + current = match value.cmp(&node.value) { + Ordering::Equal => return true, + Ordering::Less => &node.left, + Ordering::Greater => &node.right, + } + } + false + } + + /// Adds a value to the tree + /// + /// Returns `true` if the tree did not yet contain the value. + pub fn insert(&mut self, value: T) -> bool { + let inserted = insert(&mut self.root, value); + if inserted { + self.length += 1; + } + inserted + } + + /// Removes a value from the tree. + /// + /// Returns `true` if the tree contained the value. + pub fn remove(&mut self, value: &T) -> bool { + let removed = remove(&mut self.root, value); + if removed { + self.length -= 1; + } + removed + } + + /// Returns the number of values in the tree. + pub fn len(&self) -> usize { + self.length + } + + /// Returns `true` if the tree contains no values. + pub fn is_empty(&self) -> bool { + self.length == 0 + } + + /// Returns an iterator that visits the nodes in the tree in order. + fn node_iter(&self) -> NodeIter { + let mut node_iter = NodeIter { stack: Vec::new() }; + // Initialize stack with path to leftmost child + let mut child = &self.root; + while let Some(node) = child { + node_iter.stack.push(node.as_ref()); + child = &node.left; + } + node_iter + } + + /// Returns an iterator that visits the values in the tree in ascending order. + pub fn iter(&self) -> Iter { + Iter { + node_iter: self.node_iter(), + } + } +} + +/// Generating random number, should use rand::Rng if possible. +fn rand() -> usize { + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos() as usize +} + +/// Recursive helper function for `Treap` insertion. +fn insert(tree: &mut Option>>, value: T) -> bool { + if let Some(node) = tree { + let inserted = match value.cmp(&node.value) { + Ordering::Equal => false, + Ordering::Less => insert(&mut node.left, value), + Ordering::Greater => insert(&mut node.right, value), + }; + if inserted { + node.rebalance(); + } + inserted + } else { + *tree = Some(Box::new(TreapNode { + value, + priority: rand(), + left: None, + right: None, + })); + true + } +} + +/// Recursive helper function for `Treap` deletion +fn remove(tree: &mut Option>>, value: &T) -> bool { + if let Some(node) = tree { + let removed = match value.cmp(&node.value) { + Ordering::Less => remove(&mut node.left, value), + Ordering::Greater => remove(&mut node.right, value), + Ordering::Equal => { + *tree = match (node.left.take(), node.right.take()) { + (None, None) => None, + (Some(b), None) | (None, Some(b)) => Some(b), + (Some(left), Some(right)) => { + let side = match left.priority.cmp(&right.priority) { + Ordering::Greater => Side::Right, + _ => Side::Left, + }; + node.left = Some(left); + node.right = Some(right); + node.rotate(side); + remove(node.child_mut(side), value); + Some(tree.take().unwrap()) + } + }; + return true; + } + }; + if removed { + node.rebalance(); + } + removed + } else { + false + } +} + +impl TreapNode { + /// Returns a reference to the left or right child. + fn child(&self, side: Side) -> &Option>> { + match side { + Side::Left => &self.left, + Side::Right => &self.right, + } + } + + /// Returns a mutable reference to the left or right child. + fn child_mut(&mut self, side: Side) -> &mut Option>> { + match side { + Side::Left => &mut self.left, + Side::Right => &mut self.right, + } + } + + /// Returns the priority of the left or right subtree. + fn priority(&self, side: Side) -> usize { + self.child(side).as_ref().map_or(0, |n| n.priority) + } + + /// Performs a left or right rotation + fn rotate(&mut self, side: Side) { + if self.child_mut(!side).is_none() { + return; + } + + let mut subtree = self.child_mut(!side).take().unwrap(); + *self.child_mut(!side) = subtree.child_mut(side).take(); + // Swap root and child nodes in memory + mem::swap(self, subtree.as_mut()); + // Set old root (subtree) as child of new root (self) + *self.child_mut(side) = Some(subtree); + } + + /// Performs left or right tree rotations to balance this node. + fn rebalance(&mut self) { + match ( + self.priority, + self.priority(Side::Left), + self.priority(Side::Right), + ) { + (v, p, q) if p >= q && p > v => self.rotate(Side::Right), + (v, p, q) if p < q && q > v => self.rotate(Side::Left), + _ => (), + }; + } + + #[cfg(test)] + fn is_valid(&self) -> bool { + self.priority >= self.priority(Side::Left) && self.priority >= self.priority(Side::Right) + } +} + +impl Default for Treap { + fn default() -> Self { + Self::new() + } +} + +impl Not for Side { + type Output = Side; + + fn not(self) -> Self::Output { + match self { + Side::Left => Side::Right, + Side::Right => Side::Left, + } + } +} + +impl FromIterator for Treap { + fn from_iter>(iter: I) -> Self { + let mut tree = Treap::new(); + for value in iter { + tree.insert(value); + } + tree + } +} + +/// An iterator over the nodes of an `Treap`. +/// +/// This struct is created by the `node_iter` method of `Treap`. +struct NodeIter<'a, T: Ord> { + stack: Vec<&'a TreapNode>, +} + +impl<'a, T: Ord> Iterator for NodeIter<'a, T> { + type Item = &'a TreapNode; + + fn next(&mut self) -> Option { + if let Some(node) = self.stack.pop() { + // Push left path of right subtree to stack + let mut child = &node.right; + while let Some(subtree) = child { + self.stack.push(subtree.as_ref()); + child = &subtree.left; + } + Some(node) + } else { + None + } + } +} + +/// An iterator over the items of an `Treap`. +/// +/// This struct is created by the `iter` method of `Treap`. +pub struct Iter<'a, T: Ord> { + node_iter: NodeIter<'a, T>, +} + +impl<'a, T: Ord> Iterator for Iter<'a, T> { + type Item = &'a T; + + fn next(&mut self) -> Option<&'a T> { + match self.node_iter.next() { + Some(node) => Some(&node.value), + None => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::Treap; + + /// Returns `true` if all nodes in the tree are valid. + fn is_valid(tree: &Treap) -> bool { + tree.node_iter().all(|n| n.is_valid()) + } + + #[test] + fn len() { + let tree: Treap<_> = (1..4).collect(); + assert_eq!(tree.len(), 3); + } + + #[test] + fn contains() { + let tree: Treap<_> = (1..4).collect(); + assert!(tree.contains(&1)); + assert!(!tree.contains(&4)); + } + + #[test] + fn insert() { + let mut tree = Treap::new(); + // First insert succeeds + assert!(tree.insert(1)); + // Second insert fails + assert!(!tree.insert(1)); + } + + #[test] + fn remove() { + let mut tree: Treap<_> = (1..8).collect(); + // First remove succeeds + assert!(tree.remove(&4)); + // Second remove fails + assert!(!tree.remove(&4)); + } + + #[test] + fn sorted() { + let tree: Treap<_> = (1..8).rev().collect(); + assert!((1..8).eq(tree.iter().copied())); + } + + #[test] + fn valid() { + let mut tree: Treap<_> = (1..8).collect(); + assert!(is_valid(&tree)); + for x in 1..8 { + tree.remove(&x); + assert!(is_valid(&tree)); + } + } +} diff --git a/src/data_structures/trie.rs b/src/data_structures/trie.rs index cf861cb1793..ed05b0a509f 100644 --- a/src/data_structures/trie.rs +++ b/src/data_structures/trie.rs @@ -1,18 +1,29 @@ +//! This module provides a generic implementation of a Trie (prefix tree). +//! A Trie is a tree-like data structure that is commonly used to store sequences of keys +//! (such as strings, integers, or other iterable types) where each node represents one element +//! of the key, and values can be associated with full sequences. + use std::collections::HashMap; use std::hash::Hash; +/// A single node in the Trie structure, representing a key and an optional value. #[derive(Debug, Default)] struct Node { + /// A map of children nodes where each key maps to another `Node`. children: HashMap>, + /// The value associated with this node, if any. value: Option, } +/// A generic Trie (prefix tree) data structure that allows insertion and lookup +/// based on a sequence of keys. #[derive(Debug, Default)] pub struct Trie where Key: Default + Eq + Hash, Type: Default, { + /// The root node of the Trie, which does not hold a value itself. root: Node, } @@ -21,34 +32,47 @@ where Key: Default + Eq + Hash, Type: Default, { + /// Creates a new, empty `Trie`. + /// + /// # Returns + /// A `Trie` instance with an empty root node. pub fn new() -> Self { Self { root: Node::default(), } } + /// Inserts a value into the Trie, associating it with a sequence of keys. + /// + /// # Arguments + /// - `key`: An iterable sequence of keys (e.g., characters in a string or integers in a vector). + /// - `value`: The value to associate with the sequence of keys. pub fn insert(&mut self, key: impl IntoIterator, value: Type) where Key: Eq + Hash, { let mut node = &mut self.root; - for c in key.into_iter() { - node = node.children.entry(c).or_insert_with(Node::default); + for c in key { + node = node.children.entry(c).or_default(); } node.value = Some(value); } + /// Retrieves a reference to the value associated with a sequence of keys, if it exists. + /// + /// # Arguments + /// - `key`: An iterable sequence of keys (e.g., characters in a string or integers in a vector). + /// + /// # Returns + /// An `Option` containing a reference to the value if the sequence of keys exists in the Trie, + /// or `None` if it does not. pub fn get(&self, key: impl IntoIterator) -> Option<&Type> where Key: Eq + Hash, { let mut node = &self.root; - for c in key.into_iter() { - if node.children.contains_key(&c) { - node = node.children.get(&c).unwrap() - } else { - return None; - } + for c in key { + node = node.children.get(&c)?; } node.value.as_ref() } @@ -56,42 +80,76 @@ where #[cfg(test)] mod tests { - use super::*; #[test] - fn test_insertion() { + fn test_insertion_and_retrieval_with_strings() { let mut trie = Trie::new(); - assert_eq!(trie.get("".chars()), None); trie.insert("foo".chars(), 1); + assert_eq!(trie.get("foo".chars()), Some(&1)); trie.insert("foobar".chars(), 2); + assert_eq!(trie.get("foobar".chars()), Some(&2)); + assert_eq!(trie.get("foo".chars()), Some(&1)); + trie.insert("bar".chars(), 3); + assert_eq!(trie.get("bar".chars()), Some(&3)); + assert_eq!(trie.get("baz".chars()), None); + assert_eq!(trie.get("foobarbaz".chars()), None); + } + #[test] + fn test_insertion_and_retrieval_with_integers() { let mut trie = Trie::new(); - assert_eq!(trie.get(vec![1, 2, 3]), None); trie.insert(vec![1, 2, 3], 1); - trie.insert(vec![3, 4, 5], 2); + assert_eq!(trie.get(vec![1, 2, 3]), Some(&1)); + trie.insert(vec![1, 2, 3, 4, 5], 2); + assert_eq!(trie.get(vec![1, 2, 3, 4, 5]), Some(&2)); + assert_eq!(trie.get(vec![1, 2, 3]), Some(&1)); + trie.insert(vec![10, 20, 30], 3); + assert_eq!(trie.get(vec![10, 20, 30]), Some(&3)); + assert_eq!(trie.get(vec![4, 5, 6]), None); + assert_eq!(trie.get(vec![1, 2, 3, 4, 5, 6]), None); } #[test] - fn test_get() { + fn test_empty_trie() { + let trie: Trie = Trie::new(); + + assert_eq!(trie.get("foo".chars()), None); + assert_eq!(trie.get("".chars()), None); + } + + #[test] + fn test_insert_empty_key() { + let mut trie: Trie = Trie::new(); + + trie.insert("".chars(), 42); + assert_eq!(trie.get("".chars()), Some(&42)); + assert_eq!(trie.get("foo".chars()), None); + } + + #[test] + fn test_overlapping_keys() { let mut trie = Trie::new(); - trie.insert("foo".chars(), 1); - trie.insert("foobar".chars(), 2); - trie.insert("bar".chars(), 3); - trie.insert("baz".chars(), 4); - assert_eq!(trie.get("foo".chars()), Some(&1)); - assert_eq!(trie.get("food".chars()), None); + trie.insert("car".chars(), 1); + trie.insert("cart".chars(), 2); + trie.insert("carter".chars(), 3); + assert_eq!(trie.get("car".chars()), Some(&1)); + assert_eq!(trie.get("cart".chars()), Some(&2)); + assert_eq!(trie.get("carter".chars()), Some(&3)); + assert_eq!(trie.get("care".chars()), None); + } + #[test] + fn test_partial_match() { let mut trie = Trie::new(); - trie.insert(vec![1, 2, 3, 4], 1); - trie.insert(vec![42], 2); - trie.insert(vec![42, 6, 1000], 3); - trie.insert(vec![1, 2, 4, 16, 32], 4); - assert_eq!(trie.get(vec![42, 6, 1000]), Some(&3)); - assert_eq!(trie.get(vec![43, 44, 45]), None); + trie.insert("apple".chars(), 10); + assert_eq!(trie.get("app".chars()), None); + assert_eq!(trie.get("appl".chars()), None); + assert_eq!(trie.get("apple".chars()), Some(&10)); + assert_eq!(trie.get("applepie".chars()), None); } } diff --git a/src/data_structures/union_find.rs b/src/data_structures/union_find.rs index 6a4dbdd646a..b7cebd18c06 100644 --- a/src/data_structures/union_find.rs +++ b/src/data_structures/union_find.rs @@ -1,90 +1,228 @@ -/// UnionFind data structure -pub struct UnionFind { - id: Vec, - size: Vec, - count: usize, +//! A Union-Find (Disjoint Set) data structure implementation in Rust. +//! +//! The Union-Find data structure keeps track of elements partitioned into +//! disjoint (non-overlapping) sets. +//! It provides near-constant-time operations to add new sets, to find the +//! representative of a set, and to merge sets. + +use std::cmp::Ordering; +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Debug)] +pub struct UnionFind { + payloads: HashMap, // Maps values to their indices in the parent_links array. + parent_links: Vec, // Holds the parent pointers; root elements are their own parents. + sizes: Vec, // Holds the sizes of the sets. + count: usize, // Number of disjoint sets. } -impl UnionFind { - /// Creates a new UnionFind data structure with n elements - pub fn new(n: usize) -> Self { - let mut id = vec![0; n]; - let mut size = vec![0; n]; - for i in 0..n { - id[i] = i; - size[i] = 1; +impl UnionFind { + /// Creates an empty Union-Find structure with a specified capacity. + pub fn with_capacity(capacity: usize) -> Self { + Self { + parent_links: Vec::with_capacity(capacity), + sizes: Vec::with_capacity(capacity), + payloads: HashMap::with_capacity(capacity), + count: 0, + } + } + + /// Inserts a new item (disjoint set) into the data structure. + pub fn insert(&mut self, item: T) { + let key = self.payloads.len(); + self.parent_links.push(key); + self.sizes.push(1); + self.payloads.insert(item, key); + self.count += 1; + } + + /// Returns the root index of the set containing the given value, or `None` if it doesn't exist. + pub fn find(&mut self, value: &T) -> Option { + self.payloads + .get(value) + .copied() + .map(|key| self.find_by_key(key)) + } + + /// Unites the sets containing the two given values. Returns: + /// - `None` if either value hasn't been inserted, + /// - `Some(true)` if two disjoint sets have been merged, + /// - `Some(false)` if both elements were already in the same set. + pub fn union(&mut self, first_item: &T, sec_item: &T) -> Option { + let (first_root, sec_root) = (self.find(first_item), self.find(sec_item)); + match (first_root, sec_root) { + (Some(first_root), Some(sec_root)) => Some(self.union_by_key(first_root, sec_root)), + _ => None, } - Self { id, size, count: n } } - /// Returns the parent of the element - pub fn find(&mut self, x: usize) -> usize { - let mut x = x; - while x != self.id[x] { - x = self.id[x]; - // self.id[x] = self.id[self.id[x]]; // path compression + /// Finds the root of the set containing the element with the given index. + fn find_by_key(&mut self, key: usize) -> usize { + if self.parent_links[key] != key { + self.parent_links[key] = self.find_by_key(self.parent_links[key]); } - x + self.parent_links[key] } - /// Unions the sets containing x and y - pub fn union(&mut self, x: usize, y: usize) -> bool { - let x = self.find(x); - let y = self.find(y); - if x == y { + /// Unites the sets containing the two elements identified by their indices. + fn union_by_key(&mut self, first_key: usize, sec_key: usize) -> bool { + let (first_root, sec_root) = (self.find_by_key(first_key), self.find_by_key(sec_key)); + + if first_root == sec_root { return false; } - if self.size[x] < self.size[y] { - self.id[x] = y; - self.size[y] += self.size[x]; - } else { - self.id[y] = x; - self.size[x] += self.size[y]; + + match self.sizes[first_root].cmp(&self.sizes[sec_root]) { + Ordering::Less => { + self.parent_links[first_root] = sec_root; + self.sizes[sec_root] += self.sizes[first_root]; + } + _ => { + self.parent_links[sec_root] = first_root; + self.sizes[first_root] += self.sizes[sec_root]; + } } + self.count -= 1; true } - /// Checks if x and y are in the same set - pub fn is_same_set(&mut self, x: usize, y: usize) -> bool { - self.find(x) == self.find(y) + /// Checks if two items belong to the same set. + pub fn is_same_set(&mut self, first_item: &T, sec_item: &T) -> bool { + matches!((self.find(first_item), self.find(sec_item)), (Some(first_root), Some(sec_root)) if first_root == sec_root) } - /// Returns the number of disjoint sets + /// Returns the number of disjoint sets. pub fn count(&self) -> usize { self.count } } +impl Default for UnionFind { + fn default() -> Self { + Self { + parent_links: Vec::default(), + sizes: Vec::default(), + payloads: HashMap::default(), + count: 0, + } + } +} + +impl FromIterator for UnionFind { + /// Creates a new UnionFind data structure from an iterable of disjoint elements. + fn from_iter>(iter: I) -> Self { + let mut uf = UnionFind::default(); + for item in iter { + uf.insert(item); + } + uf + } +} + #[cfg(test)] mod tests { use super::*; #[test] fn test_union_find() { - let mut uf = UnionFind::new(10); - assert_eq!(uf.find(0), 0); - assert_eq!(uf.find(1), 1); - assert_eq!(uf.find(2), 2); - assert_eq!(uf.find(3), 3); - assert_eq!(uf.find(4), 4); - assert_eq!(uf.find(5), 5); - assert_eq!(uf.find(6), 6); - assert_eq!(uf.find(7), 7); - assert_eq!(uf.find(8), 8); - assert_eq!(uf.find(9), 9); - - assert_eq!(uf.union(0, 1), true); - assert_eq!(uf.union(1, 2), true); - assert_eq!(uf.union(2, 3), true); - assert_eq!(uf.union(3, 4), true); - assert_eq!(uf.union(4, 5), true); - assert_eq!(uf.union(5, 6), true); - assert_eq!(uf.union(6, 7), true); - assert_eq!(uf.union(7, 8), true); - assert_eq!(uf.union(8, 9), true); - assert_eq!(uf.union(9, 0), false); - - assert_eq!(1, uf.count()); + let mut uf = (0..10).collect::>(); + assert_eq!(uf.find(&0), Some(0)); + assert_eq!(uf.find(&1), Some(1)); + assert_eq!(uf.find(&2), Some(2)); + assert_eq!(uf.find(&3), Some(3)); + assert_eq!(uf.find(&4), Some(4)); + assert_eq!(uf.find(&5), Some(5)); + assert_eq!(uf.find(&6), Some(6)); + assert_eq!(uf.find(&7), Some(7)); + assert_eq!(uf.find(&8), Some(8)); + assert_eq!(uf.find(&9), Some(9)); + + assert!(!uf.is_same_set(&0, &1)); + assert!(!uf.is_same_set(&2, &9)); + assert_eq!(uf.count(), 10); + + assert_eq!(uf.union(&0, &1), Some(true)); + assert_eq!(uf.union(&1, &2), Some(true)); + assert_eq!(uf.union(&2, &3), Some(true)); + assert_eq!(uf.union(&0, &2), Some(false)); + assert_eq!(uf.union(&4, &5), Some(true)); + assert_eq!(uf.union(&5, &6), Some(true)); + assert_eq!(uf.union(&6, &7), Some(true)); + assert_eq!(uf.union(&7, &8), Some(true)); + assert_eq!(uf.union(&8, &9), Some(true)); + assert_eq!(uf.union(&7, &9), Some(false)); + + assert_ne!(uf.find(&0), uf.find(&9)); + assert_eq!(uf.find(&0), uf.find(&3)); + assert_eq!(uf.find(&4), uf.find(&9)); + assert!(uf.is_same_set(&0, &3)); + assert!(uf.is_same_set(&4, &9)); + assert!(!uf.is_same_set(&0, &9)); + assert_eq!(uf.count(), 2); + + assert_eq!(Some(true), uf.union(&3, &4)); + assert_eq!(uf.find(&0), uf.find(&9)); + assert_eq!(uf.count(), 1); + assert!(uf.is_same_set(&0, &9)); + + assert_eq!(None, uf.union(&0, &11)); + } + + #[test] + fn test_spanning_tree() { + let mut uf = UnionFind::from_iter(["A", "B", "C", "D", "E", "F", "G"]); + uf.union(&"A", &"B"); + uf.union(&"B", &"C"); + uf.union(&"A", &"D"); + uf.union(&"F", &"G"); + + assert_eq!(None, uf.union(&"A", &"W")); + + assert_eq!(uf.find(&"A"), uf.find(&"B")); + assert_eq!(uf.find(&"A"), uf.find(&"C")); + assert_eq!(uf.find(&"B"), uf.find(&"D")); + assert_ne!(uf.find(&"A"), uf.find(&"E")); + assert_ne!(uf.find(&"A"), uf.find(&"F")); + assert_eq!(uf.find(&"G"), uf.find(&"F")); + assert_ne!(uf.find(&"G"), uf.find(&"E")); + + assert!(uf.is_same_set(&"A", &"B")); + assert!(uf.is_same_set(&"A", &"C")); + assert!(uf.is_same_set(&"B", &"D")); + assert!(!uf.is_same_set(&"B", &"F")); + assert!(!uf.is_same_set(&"E", &"A")); + assert!(!uf.is_same_set(&"E", &"G")); + assert_eq!(uf.count(), 3); + } + + #[test] + fn test_with_capacity() { + let mut uf: UnionFind = UnionFind::with_capacity(5); + uf.insert(0); + uf.insert(1); + uf.insert(2); + uf.insert(3); + uf.insert(4); + + assert_eq!(uf.count(), 5); + + assert_eq!(uf.union(&0, &1), Some(true)); + assert!(uf.is_same_set(&0, &1)); + assert_eq!(uf.count(), 4); + + assert_eq!(uf.union(&2, &3), Some(true)); + assert!(uf.is_same_set(&2, &3)); + assert_eq!(uf.count(), 3); + + assert_eq!(uf.union(&0, &2), Some(true)); + assert!(uf.is_same_set(&0, &1)); + assert!(uf.is_same_set(&2, &3)); + assert!(uf.is_same_set(&0, &3)); + assert_eq!(uf.count(), 2); + + assert_eq!(None, uf.union(&0, &10)); } } diff --git a/src/data_structures/veb_tree.rs b/src/data_structures/veb_tree.rs new file mode 100644 index 00000000000..fe5fd7fc06d --- /dev/null +++ b/src/data_structures/veb_tree.rs @@ -0,0 +1,342 @@ +// This struct implements Van Emde Boas tree (VEB tree). It stores integers in range [0, U), where +// O is any integer that is a power of 2. It supports operations such as insert, search, +// predecessor, and successor in O(log(log(U))) time. The structure takes O(U) space. +pub struct VebTree { + size: u32, + child_size: u32, // Set to square root of size. Cache here to avoid recomputation. + min: u32, + max: u32, + summary: Option>, + cluster: Vec, +} + +impl VebTree { + /// Create a new, empty VEB tree. The tree will contain number of elements equal to size + /// rounded up to the nearest power of two. + pub fn new(size: u32) -> VebTree { + let rounded_size = size.next_power_of_two(); + let child_size = (size as f64).sqrt().ceil() as u32; + + let mut cluster = Vec::new(); + if rounded_size > 2 { + for _ in 0..rounded_size { + cluster.push(VebTree::new(child_size)); + } + } + + VebTree { + size: rounded_size, + child_size, + min: u32::MAX, + max: u32::MIN, + cluster, + summary: if rounded_size <= 2 { + None + } else { + Some(Box::new(VebTree::new(child_size))) + }, + } + } + + fn high(&self, value: u32) -> u32 { + value / self.child_size + } + + fn low(&self, value: u32) -> u32 { + value % self.child_size + } + + fn index(&self, cluster: u32, offset: u32) -> u32 { + cluster * self.child_size + offset + } + + pub fn min(&self) -> u32 { + self.min + } + + pub fn max(&self) -> u32 { + self.max + } + + pub fn iter(&self) -> VebTreeIter { + VebTreeIter::new(self) + } + + // A VEB tree is empty if the min is greater than the max. + pub fn empty(&self) -> bool { + self.min > self.max + } + + // Returns true if value is in the tree, false otherwise. + pub fn search(&self, value: u32) -> bool { + if self.empty() { + return false; + } else if value == self.min || value == self.max { + return true; + } else if value < self.min || value > self.max { + return false; + } + self.cluster[self.high(value) as usize].search(self.low(value)) + } + + fn insert_empty(&mut self, value: u32) { + assert!(self.empty(), "tree should be empty"); + self.min = value; + self.max = value; + } + + // Inserts value into the tree. + pub fn insert(&mut self, mut value: u32) { + assert!(value < self.size); + + if self.empty() { + self.insert_empty(value); + return; + } + + if value < self.min { + // If the new value is less than the current tree's min, set the min to the new value + // and insert the old min. + (value, self.min) = (self.min, value); + } + + if self.size > 2 { + // Non base case. The checks for min/max will handle trees of size 2. + let high = self.high(value); + let low = self.low(value); + if self.cluster[high as usize].empty() { + // If the cluster tree for the value is empty, we set the min/max of the tree to + // value and record that the cluster tree has an elements in the summary. + self.cluster[high as usize].insert_empty(low); + if let Some(summary) = self.summary.as_mut() { + summary.insert(high); + } + } else { + // If the cluster tree already has a value, the summary does not need to be + // updated. Recursively insert the value into the cluster tree. + self.cluster[high as usize].insert(low); + } + } + + if value > self.max { + self.max = value; + } + } + + // Returns the next greatest value(successor) in the tree after pred. Returns + // `None` if there is no successor. + pub fn succ(&self, pred: u32) -> Option { + if self.empty() { + return None; + } + + if self.size == 2 { + // Base case. If pred is 0, and 1 exists in the tree (max is set to 1), the successor + // is 1. + return if pred == 0 && self.max == 1 { + Some(1) + } else { + None + }; + } + + if pred < self.min { + // If the predecessor is less than the minimum of this tree, the successor is the min. + return Some(self.min); + } + + let low = self.low(pred); + let high = self.high(pred); + + if !self.cluster[high as usize].empty() && low < self.cluster[high as usize].max { + // The successor is within the same cluster as the predecessor + return Some(self.index(high, self.cluster[high as usize].succ(low).unwrap())); + }; + + // If we reach this point, the successor exists in a different cluster. We use the summary + // to efficiently query which cluster the successor lives in. If there is no successor + // cluster, return None. + let succ_cluster = self.summary.as_ref().unwrap().succ(high); + succ_cluster + .map(|succ_cluster| self.index(succ_cluster, self.cluster[succ_cluster as usize].min)) + } + + // Returns the next smallest value(predecessor) in the tree after succ. Returns + // `None` if there is no predecessor. pred() is almost a mirror of succ(). + // Differences are noted in comments. + pub fn pred(&self, succ: u32) -> Option { + if self.empty() { + return None; + } + + // base case. + if self.size == 2 { + return if succ == 1 && self.min == 0 { + Some(0) + } else { + None + }; + } + + if succ > self.max { + return Some(self.max); + } + + let low = self.low(succ); + let high = self.high(succ); + + if !self.cluster[high as usize].empty() && low > self.cluster[high as usize].min { + return Some(self.index(high, self.cluster[high as usize].pred(low).unwrap())); + }; + + // Find the cluster that has the predecessor. The successor will be that cluster's max. + let succ_cluster = self.summary.as_ref().unwrap().pred(high); + match succ_cluster { + Some(succ_cluster) => { + Some(self.index(succ_cluster, self.cluster[succ_cluster as usize].max)) + } + // Special case for pred() that does not exist in succ(). The current tree's min + // does not exist in a cluster. So if we cannot find a cluster that could have the + // predecessor, the predecessor could be the min of the current tree. + None => { + if succ > self.min { + Some(self.min) + } else { + None + } + } + } + } +} + +pub struct VebTreeIter<'a> { + tree: &'a VebTree, + curr: Option, +} + +impl<'a> VebTreeIter<'a> { + pub fn new(tree: &'a VebTree) -> VebTreeIter<'a> { + let curr = if tree.empty() { None } else { Some(tree.min) }; + VebTreeIter { tree, curr } + } +} + +impl Iterator for VebTreeIter<'_> { + type Item = u32; + + fn next(&mut self) -> Option { + let curr = self.curr; + curr?; + self.curr = self.tree.succ(curr.unwrap()); + curr + } +} + +#[cfg(test)] +mod test { + use super::VebTree; + use rand::{rngs::StdRng, Rng, SeedableRng}; + + fn test_veb_tree(size: u32, mut elements: Vec, exclude: Vec) { + // Insert elements + let mut tree = VebTree::new(size); + for element in elements.iter() { + tree.insert(*element); + } + + // Test search + for element in elements.iter() { + assert!(tree.search(*element)); + } + for element in exclude { + assert!(!tree.search(element)); + } + + // Test iterator and successor, and predecessor + elements.sort(); + elements.dedup(); + for (i, element) in tree.iter().enumerate() { + assert!(elements[i] == element); + } + for i in 1..elements.len() { + assert!(tree.succ(elements[i - 1]) == Some(elements[i])); + assert!(tree.pred(elements[i]) == Some(elements[i - 1])); + } + } + + #[test] + fn test_empty() { + test_veb_tree(16, Vec::new(), (0..16).collect()); + } + + #[test] + fn test_single() { + test_veb_tree(16, Vec::from([5]), (0..16).filter(|x| *x != 5).collect()); + } + + #[test] + fn test_two() { + test_veb_tree( + 16, + Vec::from([4, 9]), + (0..16).filter(|x| *x != 4 && *x != 9).collect(), + ); + } + + #[test] + fn test_repeat_insert() { + let mut tree = VebTree::new(16); + for _ in 0..5 { + tree.insert(10); + } + assert!(tree.search(10)); + let elements: Vec = (0..16).filter(|x| *x != 10).collect(); + for element in elements { + assert!(!tree.search(element)); + } + } + + #[test] + fn test_linear() { + test_veb_tree(16, (0..10).collect(), (10..16).collect()); + } + + fn test_full(size: u32) { + test_veb_tree(size, (0..size).collect(), Vec::new()); + } + + #[test] + fn test_full_small() { + test_full(8); + test_full(10); + test_full(16); + test_full(20); + test_full(32); + } + + #[test] + fn test_full_256() { + test_full(256); + } + + #[test] + fn test_10_256() { + let mut rng = StdRng::seed_from_u64(0); + let elements: Vec = (0..10).map(|_| rng.random_range(0..255)).collect(); + test_veb_tree(256, elements, Vec::new()); + } + + #[test] + fn test_100_256() { + let mut rng = StdRng::seed_from_u64(0); + let elements: Vec = (0..100).map(|_| rng.random_range(0..255)).collect(); + test_veb_tree(256, elements, Vec::new()); + } + + #[test] + fn test_100_300() { + let mut rng = StdRng::seed_from_u64(0); + let elements: Vec = (0..100).map(|_| rng.random_range(0..255)).collect(); + test_veb_tree(300, elements, Vec::new()); + } +} diff --git a/src/dynamic_programming/coin_change.rs b/src/dynamic_programming/coin_change.rs index 7c10c31d63e..2bfd573a9c0 100644 --- a/src/dynamic_programming/coin_change.rs +++ b/src/dynamic_programming/coin_change.rs @@ -1,70 +1,94 @@ -/// Coin change via Dynamic Programming +//! This module provides a solution to the coin change problem using dynamic programming. +//! The `coin_change` function calculates the fewest number of coins required to make up +//! a given amount using a specified set of coin denominations. +//! +//! The implementation leverages dynamic programming to build up solutions for smaller +//! amounts and combines them to solve for larger amounts. It ensures optimal substructure +//! and overlapping subproblems are efficiently utilized to achieve the solution. -/// coin_change(coins, amount) returns the fewest number of coins that need to make up that amount. -/// If that amount of money cannot be made up by any combination of the coins, return `None`. +//! # Complexity +//! - Time complexity: O(amount * coins.length) +//! - Space complexity: O(amount) + +/// Returns the fewest number of coins needed to make up the given amount using the provided coin denominations. +/// If the amount cannot be made up by any combination of the coins, returns `None`. +/// +/// # Arguments +/// * `coins` - A slice of coin denominations. +/// * `amount` - The total amount of money to be made up. +/// +/// # Returns +/// * `Option` - The minimum number of coins required to make up the amount, or `None` if it's not possible. /// -/// Arguments: -/// * `coins` - coins of different denominations -/// * `amount` - a total amount of money be made up. -/// Complexity -/// - time complexity: O(amount * coins.length), -/// - space complexity: O(amount), +/// # Complexity +/// * Time complexity: O(amount * coins.length) +/// * Space complexity: O(amount) pub fn coin_change(coins: &[usize], amount: usize) -> Option { - let mut dp = vec![None; amount + 1]; - dp[0] = Some(0); + let mut min_coins = vec![None; amount + 1]; + min_coins[0] = Some(0); - // Assume dp[i] is the fewest number of coins making up amount i, - // then for every coin in coins, dp[i] = min(dp[i - coin] + 1). - for i in 0..=amount { - for &coin in coins { - if i >= coin { - dp[i] = match dp[i - coin] { - Some(prev_coins) => match dp[i] { - Some(curr_coins) => Some(curr_coins.min(prev_coins + 1)), - None => Some(prev_coins + 1), - }, - None => dp[i], - }; - } - } - } + (0..=amount).for_each(|curr_amount| { + coins + .iter() + .filter(|&&coin| curr_amount >= coin) + .for_each(|&coin| { + if let Some(prev_min_coins) = min_coins[curr_amount - coin] { + min_coins[curr_amount] = Some( + min_coins[curr_amount].map_or(prev_min_coins + 1, |curr_min_coins| { + curr_min_coins.min(prev_min_coins + 1) + }), + ); + } + }); + }); - dp[amount] + min_coins[amount] } #[cfg(test)] mod tests { use super::*; - #[test] - fn basic() { - // 11 = 5 * 2 + 1 * 1 - let coins = vec![1, 2, 5]; - assert_eq!(Some(3), coin_change(&coins, 11)); - - // 119 = 11 * 10 + 7 * 1 + 2 * 1 - let coins = vec![2, 3, 5, 7, 11]; - assert_eq!(Some(12), coin_change(&coins, 119)); - } - - #[test] - fn coins_empty() { - let coins = vec![]; - assert_eq!(None, coin_change(&coins, 1)); - } - - #[test] - fn amount_zero() { - let coins = vec![1, 2, 3]; - assert_eq!(Some(0), coin_change(&coins, 0)); + macro_rules! coin_change_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (coins, amount, expected) = $test_case; + assert_eq!(expected, coin_change(&coins, amount)); + } + )* + } } - #[test] - fn fail_change() { - // 3 can't be change by 2. - let coins = vec![2]; - assert_eq!(None, coin_change(&coins, 3)); - let coins = vec![10, 20, 50, 100]; - assert_eq!(None, coin_change(&coins, 5)); + coin_change_tests! { + test_basic_case: (vec![1, 2, 5], 11, Some(3)), + test_multiple_denominations: (vec![2, 3, 5, 7, 11], 119, Some(12)), + test_empty_coins: (vec![], 1, None), + test_zero_amount: (vec![1, 2, 3], 0, Some(0)), + test_no_solution_small_coin: (vec![2], 3, None), + test_no_solution_large_coin: (vec![10, 20, 50, 100], 5, None), + test_single_coin_large_amount: (vec![1], 100, Some(100)), + test_large_amount_multiple_coins: (vec![1, 2, 5], 10000, Some(2000)), + test_no_combination_possible: (vec![3, 7], 5, None), + test_exact_combination: (vec![1, 3, 4], 6, Some(2)), + test_large_denomination_multiple_coins: (vec![10, 50, 100], 1000, Some(10)), + test_small_amount_not_possible: (vec![5, 10], 1, None), + test_non_divisible_amount: (vec![2], 3, None), + test_all_multiples: (vec![1, 2, 4, 8], 15, Some(4)), + test_large_amount_mixed_coins: (vec![1, 5, 10, 25], 999, Some(45)), + test_prime_coins_and_amount: (vec![2, 3, 5, 7], 17, Some(3)), + test_coins_larger_than_amount: (vec![5, 10, 20], 1, None), + test_repeating_denominations: (vec![1, 1, 1, 5], 8, Some(4)), + test_non_standard_denominations: (vec![1, 4, 6, 9], 15, Some(2)), + test_very_large_denominations: (vec![1000, 2000, 5000], 1, None), + test_large_amount_performance: (vec![1, 5, 10, 25, 50, 100, 200, 500], 9999, Some(29)), + test_powers_of_two: (vec![1, 2, 4, 8, 16, 32, 64], 127, Some(7)), + test_fibonacci_sequence: (vec![1, 2, 3, 5, 8, 13, 21, 34], 55, Some(2)), + test_mixed_small_large: (vec![1, 100, 1000, 10000], 11001, Some(3)), + test_impossible_combinations: (vec![2, 4, 6, 8], 7, None), + test_greedy_approach_does_not_work: (vec![1, 12, 20], 24, Some(2)), + test_zero_denominations_no_solution: (vec![0], 1, None), + test_zero_denominations_solution: (vec![0], 0, Some(0)), } } diff --git a/src/dynamic_programming/edit_distance.rs b/src/dynamic_programming/edit_distance.rs deleted file mode 100644 index 913d58c87ed..00000000000 --- a/src/dynamic_programming/edit_distance.rs +++ /dev/null @@ -1,119 +0,0 @@ -//! Compute the edit distance between two strings - -use std::cmp::min; - -/// edit_distance(str_a, str_b) returns the edit distance between the two -/// strings This edit distance is defined as being 1 point per insertion, -/// substitution, or deletion which must be made to make the strings equal. -/// -/// This function iterates over the bytes in the string, so it may not behave -/// entirely as expected for non-ASCII strings. -/// -/// # Complexity -/// -/// - time complexity: O(nm), -/// - space complexity: O(nm), -/// -/// where n and m are lengths of `str_a` and `str_b` -pub fn edit_distance(str_a: &str, str_b: &str) -> u32 { - // distances[i][j] = distance between a[..i] and b[..j] - let mut distances = vec![vec![0; str_b.len() + 1]; str_a.len() + 1]; - // Initialize cases in which one string is empty - for j in 0..=str_b.len() { - distances[0][j] = j as u32; - } - for (i, item) in distances.iter_mut().enumerate() { - item[0] = i as u32; - } - for i in 1..=str_a.len() { - for j in 1..=str_b.len() { - distances[i][j] = min(distances[i - 1][j] + 1, distances[i][j - 1] + 1); - if str_a.as_bytes()[i - 1] == str_b.as_bytes()[j - 1] { - distances[i][j] = min(distances[i][j], distances[i - 1][j - 1]); - } else { - distances[i][j] = min(distances[i][j], distances[i - 1][j - 1] + 1); - } - } - } - distances[str_a.len()][str_b.len()] -} - -/// The space efficient version of the above algorithm. -/// -/// Instead of storing the `m * n` matrix expicitly, only one row (of length `n`) is stored. -/// It keeps overwriting itself based on its previous values with the help of two scalars, -/// gradually reaching the last row. Then, the score is `matrix[n]`. -/// -/// # Complexity -/// -/// - time complexity: O(nm), -/// - space complexity: O(n), -/// -/// where n and m are lengths of `str_a` and `str_b` -pub fn edit_distance_se(str_a: &str, str_b: &str) -> u32 { - let (str_a, str_b) = (str_a.as_bytes(), str_b.as_bytes()); - let (m, n) = (str_a.len(), str_b.len()); - let mut distances: Vec = vec![0; n + 1]; // the dynamic programming matrix (only 1 row stored) - let mut s: u32; // distances[i - 1][j - 1] or distances[i - 1][j] - let mut c: u32; // distances[i][j - 1] or distances[i][j] - let mut char_a: u8; // str_a[i - 1] the i-th character in str_a; only needs to be computed once per row - let mut char_b: u8; // str_b[j - 1] the j-th character in str_b - - // 0th row - for (j, v) in distances.iter_mut().enumerate().take(n + 1).skip(1) { - *v = j as u32; - } - // rows 1 to m - for i in 1..=m { - s = (i - 1) as u32; - c = i as u32; - char_a = str_a[i - 1]; - for j in 1..=n { - // c is distances[i][j-1] and s is distances[i-1][j-1] at the beginning of each round of iteration - char_b = str_b[j - 1]; - c = min( - s + if char_a == char_b { 0 } else { 1 }, - min(c + 1, distances[j] + 1), - ); - // c is updated to distances[i][j], and will thus become distances[i][j-1] for the next cell - s = distances[j]; // here distances[j] means distances[i-1][j] because it has not been overwritten yet - // s is updated to distances[i-1][j], and will thus become distances[i-1][j-1] for the next cell - distances[j] = c; // now distances[j] is updated to distances[i][j], and will thus become distances[i-1][j] for the next ROW - } - } - - distances[n] -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn equal_strings() { - assert_eq!(0, edit_distance("Hello, world!", "Hello, world!")); - assert_eq!(0, edit_distance_se("Hello, world!", "Hello, world!")); - assert_eq!(0, edit_distance("Test_Case_#1", "Test_Case_#1")); - assert_eq!(0, edit_distance_se("Test_Case_#1", "Test_Case_#1")); - } - - #[test] - fn one_edit_difference() { - assert_eq!(1, edit_distance("Hello, world!", "Hell, world!")); - assert_eq!(1, edit_distance("Test_Case_#1", "Test_Case_#2")); - assert_eq!(1, edit_distance("Test_Case_#1", "Test_Case_#10")); - assert_eq!(1, edit_distance_se("Hello, world!", "Hell, world!")); - assert_eq!(1, edit_distance_se("Test_Case_#1", "Test_Case_#2")); - assert_eq!(1, edit_distance_se("Test_Case_#1", "Test_Case_#10")); - } - - #[test] - fn several_differences() { - assert_eq!(2, edit_distance("My Cat", "My Case")); - assert_eq!(7, edit_distance("Hello, world!", "Goodbye, world!")); - assert_eq!(6, edit_distance("Test_Case_#3", "Case #3")); - assert_eq!(2, edit_distance_se("My Cat", "My Case")); - assert_eq!(7, edit_distance_se("Hello, world!", "Goodbye, world!")); - assert_eq!(6, edit_distance_se("Test_Case_#3", "Case #3")); - } -} diff --git a/src/dynamic_programming/egg_dropping.rs b/src/dynamic_programming/egg_dropping.rs index 1a403637ec5..ab24494c014 100644 --- a/src/dynamic_programming/egg_dropping.rs +++ b/src/dynamic_programming/egg_dropping.rs @@ -1,91 +1,84 @@ -/// # Egg Dropping Puzzle +//! This module contains the `egg_drop` function, which determines the minimum number of egg droppings +//! required to find the highest floor from which an egg can be dropped without breaking. It also includes +//! tests for the function using various test cases, including edge cases. -/// `egg_drop(eggs, floors)` returns the least number of egg droppings -/// required to determine the highest floor from which an egg will not -/// break upon dropping +/// Returns the least number of egg droppings required to determine the highest floor from which an egg will not break upon dropping. /// -/// Assumptions: n > 0 -pub fn egg_drop(eggs: u32, floors: u32) -> u32 { - assert!(eggs > 0); - - // Explicity handle edge cases (optional) - if eggs == 1 || floors == 0 || floors == 1 { - return floors; - } - - let eggs_index = eggs as usize; - let floors_index = floors as usize; - - // Store solutions to subproblems in 2D Vec, - // where egg_drops[i][j] represents the solution to the egg dropping - // problem with i eggs and j floors - let mut egg_drops: Vec> = vec![vec![0; floors_index + 1]; eggs_index + 1]; - - // Assign solutions for egg_drop(n, 0) = 0, egg_drop(n, 1) = 1 - for egg_drop in egg_drops.iter_mut().skip(1) { - egg_drop[0] = 0; - egg_drop[1] = 1; - } - - // Assign solutions to egg_drop(1, k) = k - for j in 1..=floors_index { - egg_drops[1][j] = j as u32; +/// # Arguments +/// +/// * `eggs` - The number of eggs available. +/// * `floors` - The number of floors in the building. +/// +/// # Returns +/// +/// * `Some(usize)` - The minimum number of drops required if the number of eggs is greater than 0. +/// * `None` - If the number of eggs is 0. +pub fn egg_drop(eggs: usize, floors: usize) -> Option { + if eggs == 0 { + return None; } - // Complete solutions vector using optimal substructure property - for i in 2..=eggs_index { - for j in 2..=floors_index { - egg_drops[i][j] = std::u32::MAX; - - for k in 1..=j { - let res = 1 + std::cmp::max(egg_drops[i - 1][k - 1], egg_drops[i][j - k]); - - if res < egg_drops[i][j] { - egg_drops[i][j] = res; - } - } - } + if eggs == 1 || floors == 0 || floors == 1 { + return Some(floors); } - egg_drops[eggs_index][floors_index] + // Create a 2D vector to store solutions to subproblems + let mut egg_drops: Vec> = vec![vec![0; floors + 1]; eggs + 1]; + + // Base cases: 0 floors -> 0 drops, 1 floor -> 1 drop + (1..=eggs).for_each(|i| { + egg_drops[i][1] = 1; + }); + + // Base case: 1 egg -> k drops for k floors + (1..=floors).for_each(|j| { + egg_drops[1][j] = j; + }); + + // Fill the table using the optimal substructure property + (2..=eggs).for_each(|i| { + (2..=floors).for_each(|j| { + egg_drops[i][j] = (1..=j) + .map(|k| 1 + std::cmp::max(egg_drops[i - 1][k - 1], egg_drops[i][j - k])) + .min() + .unwrap(); + }); + }); + + Some(egg_drops[eggs][floors]) } #[cfg(test)] mod tests { - use super::egg_drop; - - #[test] - fn zero_floors() { - assert_eq!(egg_drop(5, 0), 0); - } - - #[test] - fn one_egg() { - assert_eq!(egg_drop(1, 8), 8); - } - - #[test] - fn eggs2_floors2() { - assert_eq!(egg_drop(2, 2), 2); - } - - #[test] - fn eggs3_floors5() { - assert_eq!(egg_drop(3, 5), 3); - } - - #[test] - fn eggs2_floors10() { - assert_eq!(egg_drop(2, 10), 4); - } - - #[test] - fn eggs2_floors36() { - assert_eq!(egg_drop(2, 36), 8); + use super::*; + + macro_rules! egg_drop_tests { + ($($name:ident: $test_cases:expr,)*) => { + $( + #[test] + fn $name() { + let (eggs, floors, expected) = $test_cases; + assert_eq!(egg_drop(eggs, floors), expected); + } + )* + } } - #[test] - fn large_floors() { - assert_eq!(egg_drop(2, 100), 14); + egg_drop_tests! { + test_no_floors: (5, 0, Some(0)), + test_one_egg_multiple_floors: (1, 8, Some(8)), + test_multiple_eggs_one_floor: (5, 1, Some(1)), + test_two_eggs_two_floors: (2, 2, Some(2)), + test_three_eggs_five_floors: (3, 5, Some(3)), + test_two_eggs_ten_floors: (2, 10, Some(4)), + test_two_eggs_thirty_six_floors: (2, 36, Some(8)), + test_many_eggs_one_floor: (100, 1, Some(1)), + test_many_eggs_few_floors: (100, 5, Some(3)), + test_few_eggs_many_floors: (2, 1000, Some(45)), + test_zero_eggs: (0, 10, None::), + test_no_eggs_no_floors: (0, 0, None::), + test_one_egg_no_floors: (1, 0, Some(0)), + test_one_egg_one_floor: (1, 1, Some(1)), + test_maximum_floors_one_egg: (1, usize::MAX, Some(usize::MAX)), } } diff --git a/src/dynamic_programming/fibonacci.rs b/src/dynamic_programming/fibonacci.rs index b53e6be0b57..f1a55ce77f1 100644 --- a/src/dynamic_programming/fibonacci.rs +++ b/src/dynamic_programming/fibonacci.rs @@ -123,12 +123,169 @@ fn _memoized_fibonacci(n: u32, cache: &mut HashMap) -> u128 { *f } +/// matrix_fibonacci(n) returns the nth fibonacci number +/// This function uses the definition of Fibonacci where: +/// F(0) = 0, F(1) = 1 and F(n+1) = F(n) + F(n-1) for n>0 +/// +/// Matrix formula: +/// [F(n + 2)] = [1, 1] * [F(n + 1)] +/// [F(n + 1)] [1, 0] [F(n) ] +/// +/// Warning: This will overflow the 128-bit unsigned integer at n=186 +pub fn matrix_fibonacci(n: u32) -> u128 { + let multiplier: Vec> = vec![vec![1, 1], vec![1, 0]]; + + let multiplier = matrix_power(&multiplier, n); + let initial_fib_matrix: Vec> = vec![vec![1], vec![0]]; + + let res = matrix_multiply(&multiplier, &initial_fib_matrix); + + res[1][0] +} + +fn matrix_power(base: &Vec>, power: u32) -> Vec> { + let identity_matrix: Vec> = vec![vec![1, 0], vec![0, 1]]; + + vec![base; power as usize] + .iter() + .fold(identity_matrix, |acc, x| matrix_multiply(&acc, x)) +} + +// Copied from matrix_ops since u128 is required instead of i32 +#[allow(clippy::needless_range_loop)] +fn matrix_multiply(multiplier: &[Vec], multiplicand: &[Vec]) -> Vec> { + // Multiply two matching matrices. The multiplier needs to have the same amount + // of columns as the multiplicand has rows. + let mut result: Vec> = vec![]; + let mut temp; + // Using variable to compare lengths of rows in multiplicand later + let row_right_length = multiplicand[0].len(); + for row_left in 0..multiplier.len() { + if multiplier[row_left].len() != multiplicand.len() { + panic!("Matrix dimensions do not match"); + } + result.push(vec![]); + for column_right in 0..multiplicand[0].len() { + temp = 0; + for row_right in 0..multiplicand.len() { + if row_right_length != multiplicand[row_right].len() { + // If row is longer than a previous row cancel operation with error + panic!("Matrix dimensions do not match"); + } + temp += multiplier[row_left][row_right] * multiplicand[row_right][column_right]; + } + result[row_left].push(temp); + } + } + result +} + +/// Binary lifting fibonacci +/// +/// Following properties of F(n) could be deduced from the matrix formula above: +/// +/// F(2n) = F(n) * (2F(n+1) - F(n)) +/// F(2n+1) = F(n+1)^2 + F(n)^2 +/// +/// Therefore F(n) and F(n+1) can be derived from F(n>>1) and F(n>>1 + 1), which +/// has a smaller constant in both time and space compared to matrix fibonacci. +pub fn binary_lifting_fibonacci(n: u32) -> u128 { + // the state always stores F(k), F(k+1) for some k, initially F(0), F(1) + let mut state = (0u128, 1u128); + + for i in (0..u32::BITS - n.leading_zeros()).rev() { + // compute F(2k), F(2k+1) from F(k), F(k+1) + state = ( + state.0 * (2 * state.1 - state.0), + state.0 * state.0 + state.1 * state.1, + ); + if n & (1 << i) != 0 { + state = (state.1, state.0 + state.1); + } + } + + state.0 +} + +/// nth_fibonacci_number_modulo_m(n, m) returns the nth fibonacci number modulo the specified m +/// i.e. F(n) % m +pub fn nth_fibonacci_number_modulo_m(n: i64, m: i64) -> i128 { + let (length, pisano_sequence) = get_pisano_sequence_and_period(m); + + let remainder = n % length as i64; + pisano_sequence[remainder as usize].to_owned() +} + +/// get_pisano_sequence_and_period(m) returns the Pisano Sequence and period for the specified integer m. +/// The pisano period is the period with which the sequence of Fibonacci numbers taken modulo m repeats. +/// The pisano sequence is the numbers in pisano period. +fn get_pisano_sequence_and_period(m: i64) -> (i128, Vec) { + let mut a = 0; + let mut b = 1; + let mut length: i128 = 0; + let mut pisano_sequence: Vec = vec![a, b]; + + // Iterating through all the fib numbers to get the sequence + for _i in 0..=(m * m) { + let c = (a + b) % m as i128; + + // adding number into the sequence + pisano_sequence.push(c); + + a = b; + b = c; + + if a == 0 && b == 1 { + // Remove the last two elements from the sequence + // This is a less elegant way to do it. + pisano_sequence.pop(); + pisano_sequence.pop(); + length = pisano_sequence.len() as i128; + break; + } + } + + (length, pisano_sequence) +} + +/// last_digit_of_the_sum_of_nth_fibonacci_number(n) returns the last digit of the sum of n fibonacci numbers. +/// The function uses the definition of Fibonacci where: +/// F(0) = 0, F(1) = 1 and F(n+1) = F(n) + F(n-1) for n > 2 +/// +/// The sum of the Fibonacci numbers are: +/// F(0) + F(1) + F(2) + ... + F(n) +pub fn last_digit_of_the_sum_of_nth_fibonacci_number(n: i64) -> i64 { + if n < 2 { + return n; + } + + // the pisano period of mod 10 is 60 + let n = ((n + 2) % 60) as usize; + let mut fib = vec![0; n + 1]; + fib[0] = 0; + fib[1] = 1; + + for i in 2..=n { + fib[i] = (fib[i - 1] % 10 + fib[i - 2] % 10) % 10; + } + + if fib[n] == 0 { + return 9; + } + + fib[n] % 10 - 1 +} + #[cfg(test)] mod tests { + use super::binary_lifting_fibonacci; use super::classical_fibonacci; use super::fibonacci; + use super::last_digit_of_the_sum_of_nth_fibonacci_number; use super::logarithmic_fibonacci; + use super::matrix_fibonacci; use super::memoized_fibonacci; + use super::nth_fibonacci_number_modulo_m; use super::recursive_fibonacci; #[test] @@ -199,7 +356,7 @@ mod tests { } #[test] - /// Check that the itterative and recursive fibonacci + /// Check that the iterative and recursive fibonacci /// produce the same value. Both are combinatorial ( F(0) = F(1) = 1 ) fn test_iterative_and_recursive_equivalence() { assert_eq!(fibonacci(0), recursive_fibonacci(0)); @@ -250,4 +407,68 @@ mod tests { 127127879743834334146972278486287885163 ); } + + #[test] + fn test_matrix_fibonacci() { + assert_eq!(matrix_fibonacci(0), 0); + assert_eq!(matrix_fibonacci(1), 1); + assert_eq!(matrix_fibonacci(2), 1); + assert_eq!(matrix_fibonacci(3), 2); + assert_eq!(matrix_fibonacci(4), 3); + assert_eq!(matrix_fibonacci(5), 5); + assert_eq!(matrix_fibonacci(10), 55); + assert_eq!(matrix_fibonacci(20), 6765); + assert_eq!(matrix_fibonacci(21), 10946); + assert_eq!(matrix_fibonacci(100), 354224848179261915075); + assert_eq!( + matrix_fibonacci(184), + 127127879743834334146972278486287885163 + ); + } + + #[test] + fn test_binary_lifting_fibonacci() { + assert_eq!(binary_lifting_fibonacci(0), 0); + assert_eq!(binary_lifting_fibonacci(1), 1); + assert_eq!(binary_lifting_fibonacci(2), 1); + assert_eq!(binary_lifting_fibonacci(3), 2); + assert_eq!(binary_lifting_fibonacci(4), 3); + assert_eq!(binary_lifting_fibonacci(5), 5); + assert_eq!(binary_lifting_fibonacci(10), 55); + assert_eq!(binary_lifting_fibonacci(20), 6765); + assert_eq!(binary_lifting_fibonacci(21), 10946); + assert_eq!(binary_lifting_fibonacci(100), 354224848179261915075); + assert_eq!( + binary_lifting_fibonacci(184), + 127127879743834334146972278486287885163 + ); + } + + #[test] + fn test_nth_fibonacci_number_modulo_m() { + assert_eq!(nth_fibonacci_number_modulo_m(5, 10), 5); + assert_eq!(nth_fibonacci_number_modulo_m(10, 7), 6); + assert_eq!(nth_fibonacci_number_modulo_m(20, 100), 65); + assert_eq!(nth_fibonacci_number_modulo_m(1, 5), 1); + assert_eq!(nth_fibonacci_number_modulo_m(0, 15), 0); + assert_eq!(nth_fibonacci_number_modulo_m(50, 1000), 25); + assert_eq!(nth_fibonacci_number_modulo_m(100, 37), 7); + assert_eq!(nth_fibonacci_number_modulo_m(15, 2), 0); + assert_eq!(nth_fibonacci_number_modulo_m(8, 1_000_000), 21); + assert_eq!(nth_fibonacci_number_modulo_m(1000, 997), 996); + assert_eq!(nth_fibonacci_number_modulo_m(200, 123), 0); + } + + #[test] + fn test_last_digit_of_the_sum_of_nth_fibonacci_number() { + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(0), 0); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(1), 1); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(2), 2); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(3), 4); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(4), 7); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(5), 2); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(25), 7); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(50), 8); + assert_eq!(last_digit_of_the_sum_of_nth_fibonacci_number(100), 5); + } } diff --git a/src/dynamic_programming/fractional_knapsack.rs b/src/dynamic_programming/fractional_knapsack.rs new file mode 100644 index 00000000000..b05a1b18f2f --- /dev/null +++ b/src/dynamic_programming/fractional_knapsack.rs @@ -0,0 +1,98 @@ +pub fn fractional_knapsack(mut capacity: f64, weights: Vec, values: Vec) -> f64 { + // vector of tuple of weights and their value/weight ratio + let mut weights: Vec<(f64, f64)> = weights + .iter() + .zip(values.iter()) + .map(|(&w, &v)| (w, v / w)) + .collect(); + + // sort in decreasing order by value/weight ratio + weights.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).expect("Encountered NaN")); + dbg!(&weights); + + // value to compute + let mut knapsack_value: f64 = 0.0; + + // iterate through our vector. + for w in weights { + // w.0 is weight and w.1 value/weight ratio + if w.0 < capacity { + capacity -= w.0; // our sack is filling + knapsack_value += w.0 * w.1; + dbg!(&w.0, &knapsack_value); + } else { + // Multiply with capacity and not w.0 + dbg!(&w.0, &knapsack_value); + knapsack_value += capacity * w.1; + break; + } + } + + knapsack_value +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test() { + let capacity = 50.0; + let values = vec![60.0, 100.0, 120.0]; + let weights = vec![10.0, 20.0, 30.0]; + assert_eq!(fractional_knapsack(capacity, weights, values), 240.0); + } + + #[test] + fn test2() { + let capacity = 60.0; + let values = vec![280.0, 100.0, 120.0, 120.0]; + let weights = vec![40.0, 10.0, 20.0, 24.0]; + assert_eq!(fractional_knapsack(capacity, weights, values), 440.0); + } + + #[test] + fn test3() { + let capacity = 50.0; + let values = vec![60.0, 100.0, 120.0]; + let weights = vec![20.0, 50.0, 30.0]; + assert_eq!(fractional_knapsack(capacity, weights, values), 180.0); + } + + #[test] + fn test4() { + let capacity = 60.0; + let values = vec![30.0, 40.0, 45.0, 77.0, 90.0]; + let weights = vec![5.0, 10.0, 15.0, 22.0, 25.0]; + assert_eq!(fractional_knapsack(capacity, weights, values), 230.0); + } + + #[test] + fn test5() { + let capacity = 10.0; + let values = vec![500.0]; + let weights = vec![30.0]; + assert_eq!( + format!("{:.2}", fractional_knapsack(capacity, weights, values)), + String::from("166.67") + ); + } + + #[test] + fn test6() { + let capacity = 36.0; + let values = vec![25.0, 25.0, 25.0, 6.0, 2.0]; + let weights = vec![10.0, 10.0, 10.0, 4.0, 2.0]; + assert_eq!(fractional_knapsack(capacity, weights, values), 83.0); + } + + #[test] + #[should_panic] + fn test_nan() { + let capacity = 36.0; + // 2nd element is NaN + let values = vec![25.0, f64::NAN, 25.0, 6.0, 2.0]; + let weights = vec![10.0, 10.0, 10.0, 4.0, 2.0]; + assert_eq!(fractional_knapsack(capacity, weights, values), 83.0); + } +} diff --git a/src/dynamic_programming/is_subsequence.rs b/src/dynamic_programming/is_subsequence.rs index 689c19c8c63..22b43c387b1 100644 --- a/src/dynamic_programming/is_subsequence.rs +++ b/src/dynamic_programming/is_subsequence.rs @@ -1,39 +1,71 @@ -// Given two strings str1 and str2, return true if str1 is a subsequence of str2, or false otherwise. -// A subsequence of a string is a new string that is formed from the original string -// by deleting some (can be none) of the characters without disturbing the relative -// positions of the remaining characters. -// (i.e., "ace" is a subsequence of "abcde" while "aec" is not). -pub fn is_subsequence(str1: String, str2: String) -> bool { - let mut it1 = 0; - let mut it2 = 0; +//! A module for checking if one string is a subsequence of another string. +//! +//! A subsequence is formed by deleting some (can be none) of the characters +//! from the original string without disturbing the relative positions of the +//! remaining characters. This module provides a function to determine if +//! a given string is a subsequence of another string. - let byte1 = str1.as_bytes(); - let byte2 = str2.as_bytes(); +/// Checks if `sub` is a subsequence of `main`. +/// +/// # Arguments +/// +/// * `sub` - A string slice that may be a subsequence. +/// * `main` - A string slice that is checked against. +/// +/// # Returns +/// +/// Returns `true` if `sub` is a subsequence of `main`, otherwise returns `false`. +pub fn is_subsequence(sub: &str, main: &str) -> bool { + let mut sub_iter = sub.chars().peekable(); + let mut main_iter = main.chars(); - while it1 < str1.len() && it2 < str2.len() { - if byte1[it1] == byte2[it2] { - it1 += 1; + while let Some(&sub_char) = sub_iter.peek() { + match main_iter.next() { + Some(main_char) if main_char == sub_char => { + sub_iter.next(); + } + None => return false, + _ => {} } - - it2 += 1; } - it1 == str1.len() + true } #[cfg(test)] mod tests { use super::*; - #[test] - fn test() { - assert_eq!( - is_subsequence(String::from("abc"), String::from("ahbgdc")), - true - ); - assert_eq!( - is_subsequence(String::from("axc"), String::from("ahbgdc")), - false - ); + macro_rules! subsequence_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (sub, main, expected) = $test_case; + assert_eq!(is_subsequence(sub, main), expected); + } + )* + }; + } + + subsequence_tests! { + test_empty_subsequence: ("", "ahbgdc", true), + test_empty_strings: ("", "", true), + test_non_empty_sub_empty_main: ("abc", "", false), + test_subsequence_found: ("abc", "ahbgdc", true), + test_subsequence_not_found: ("axc", "ahbgdc", false), + test_longer_sub: ("abcd", "abc", false), + test_single_character_match: ("a", "ahbgdc", true), + test_single_character_not_match: ("x", "ahbgdc", false), + test_subsequence_at_start: ("abc", "abchello", true), + test_subsequence_at_end: ("cde", "abcde", true), + test_same_characters: ("aaa", "aaaaa", true), + test_interspersed_subsequence: ("ace", "abcde", true), + test_different_chars_in_subsequence: ("aceg", "abcdef", false), + test_single_character_in_main_not_match: ("a", "b", false), + test_single_character_in_main_match: ("b", "b", true), + test_subsequence_with_special_chars: ("a1!c", "a1!bcd", true), + test_case_sensitive: ("aBc", "abc", false), + test_subsequence_with_whitespace: ("hello world", "h e l l o w o r l d", true), } } diff --git a/src/dynamic_programming/knapsack.rs b/src/dynamic_programming/knapsack.rs index b9edf1bda88..36876a15bd8 100644 --- a/src/dynamic_programming/knapsack.rs +++ b/src/dynamic_programming/knapsack.rs @@ -1,148 +1,363 @@ -//! Solves the knapsack problem -use std::cmp::max; +//! This module provides functionality to solve the knapsack problem using dynamic programming. +//! It includes structures for items and solutions, and functions to compute the optimal solution. -/// knapsack_table(w, weights, values) returns the knapsack table (`n`, `m`) with maximum values, where `n` is number of items -/// -/// Arguments: -/// * `w` - knapsack capacity -/// * `weights` - set of weights for each item -/// * `values` - set of values for each item -fn knapsack_table(w: &usize, weights: &[usize], values: &[usize]) -> Vec> { - // Initialize `n` - number of items - let n: usize = weights.len(); - // Initialize `m` - // m[i, w] - the maximum value that can be attained with weight less that or equal to `w` using items up to `i` - let mut m: Vec> = vec![vec![0; w + 1]; n + 1]; +use std::cmp::Ordering; - for i in 0..=n { - for j in 0..=*w { - // m[i, j] compiled according to the following rule: - if i == 0 || j == 0 { - m[i][j] = 0; - } else if weights[i - 1] <= j { - // If `i` is in the knapsack - // Then m[i, j] is equal to the maximum value of the knapsack, - // where the weight `j` is reduced by the weight of the `i-th` item and the set of admissible items plus the value `k` - m[i][j] = max(values[i - 1] + m[i - 1][j - weights[i - 1]], m[i - 1][j]); - } else { - // If the item `i` did not get into the knapsack - // Then m[i, j] is equal to the maximum cost of a knapsack with the same capacity and a set of admissible items - m[i][j] = m[i - 1][j] - } - } - } - m +/// Represents an item with a weight and a value. +#[derive(Debug, PartialEq, Eq)] +pub struct Item { + weight: usize, + value: usize, +} + +/// Represents the solution to the knapsack problem. +#[derive(Debug, PartialEq, Eq)] +pub struct KnapsackSolution { + /// The optimal profit obtained. + optimal_profit: usize, + /// The total weight of items included in the solution. + total_weight: usize, + /// The indices of items included in the solution. Indices might not be unique. + item_indices: Vec, } -/// knapsack_items(weights, m, i, j) returns the indices of the items of the optimal knapsack (from 1 to `n`) +/// Solves the knapsack problem and returns the optimal profit, total weight, and indices of items included. /// -/// Arguments: -/// * `weights` - set of weights for each item -/// * `m` - knapsack table with maximum values -/// * `i` - include items 1 through `i` in knapsack (for the initial value, use `n`) -/// * `j` - maximum weight of the knapsack -fn knapsack_items(weights: &[usize], m: &[Vec], i: usize, j: usize) -> Vec { - if i == 0 { - return vec![]; - } - if m[i][j] > m[i - 1][j] { - let mut knap: Vec = knapsack_items(weights, m, i - 1, j - weights[i - 1]); - knap.push(i); - knap - } else { - knapsack_items(weights, m, i - 1, j) +/// # Arguments: +/// * `capacity` - The maximum weight capacity of the knapsack. +/// * `items` - A vector of `Item` structs, each representing an item with weight and value. +/// +/// # Returns: +/// A `KnapsackSolution` struct containing: +/// - `optimal_profit` - The maximum profit achievable with the given capacity and items. +/// - `total_weight` - The total weight of items included in the solution. +/// - `item_indices` - Indices of items included in the solution. Indices might not be unique. +/// +/// # Note: +/// The indices of items in the solution might not be unique. +/// This function assumes that `items` is non-empty. +/// +/// # Complexity: +/// - Time complexity: O(num_items * capacity) +/// - Space complexity: O(num_items * capacity) +/// +/// where `num_items` is the number of items and `capacity` is the knapsack capacity. +pub fn knapsack(capacity: usize, items: Vec) -> KnapsackSolution { + let num_items = items.len(); + let item_weights: Vec = items.iter().map(|item| item.weight).collect(); + let item_values: Vec = items.iter().map(|item| item.value).collect(); + + let knapsack_matrix = generate_knapsack_matrix(capacity, &item_weights, &item_values); + let items_included = + retrieve_knapsack_items(&item_weights, &knapsack_matrix, num_items, capacity); + + let total_weight = items_included + .iter() + .map(|&index| item_weights[index - 1]) + .sum(); + + KnapsackSolution { + optimal_profit: knapsack_matrix[num_items][capacity], + total_weight, + item_indices: items_included, } } -/// knapsack(w, weights, values) returns the tuple where first value is `optimal profit`, -/// second value is `knapsack optimal weight` and the last value is `indices of items`, that we got (from 1 to `n`) +/// Generates the knapsack matrix (`num_items`, `capacity`) with maximum values. /// -/// Arguments: -/// * `w` - knapsack capacity -/// * `weights` - set of weights for each item -/// * `values` - set of values for each item +/// # Arguments: +/// * `capacity` - knapsack capacity +/// * `item_weights` - weights of each item +/// * `item_values` - values of each item +fn generate_knapsack_matrix( + capacity: usize, + item_weights: &[usize], + item_values: &[usize], +) -> Vec> { + let num_items = item_weights.len(); + + (0..=num_items).fold( + vec![vec![0; capacity + 1]; num_items + 1], + |mut matrix, item_index| { + (0..=capacity).for_each(|current_capacity| { + matrix[item_index][current_capacity] = if item_index == 0 || current_capacity == 0 { + 0 + } else if item_weights[item_index - 1] <= current_capacity { + usize::max( + item_values[item_index - 1] + + matrix[item_index - 1] + [current_capacity - item_weights[item_index - 1]], + matrix[item_index - 1][current_capacity], + ) + } else { + matrix[item_index - 1][current_capacity] + }; + }); + matrix + }, + ) +} + +/// Retrieves the indices of items included in the optimal knapsack solution. /// -/// Complexity -/// - time complexity: O(nw), -/// - space complexity: O(nw), +/// # Arguments: +/// * `item_weights` - weights of each item +/// * `knapsack_matrix` - knapsack matrix with maximum values +/// * `item_index` - number of items to consider (initially the total number of items) +/// * `remaining_capacity` - remaining capacity of the knapsack /// -/// where `n` and `w` are `number of items` and `knapsack capacity` -pub fn knapsack(w: usize, weights: Vec, values: Vec) -> (usize, usize, Vec) { - // Checks if the number of items in the list of weights is the same as the number of items in the list of values - assert_eq!(weights.len(), values.len(), "Number of items in the list of weights doesn't match the number of items in the list of values!"); - // Initialize `n` - number of items - let n: usize = weights.len(); - // Find the knapsack table - let m: Vec> = knapsack_table(&w, &weights, &values); - // Find the indices of the items - let items: Vec = knapsack_items(&weights, &m, n, w); - // Find the total weight of optimal knapsack - let mut total_weight: usize = 0; - for i in items.iter() { - total_weight += weights[i - 1]; +/// # Returns +/// A vector of item indices included in the optimal solution. The indices might not be unique. +fn retrieve_knapsack_items( + item_weights: &[usize], + knapsack_matrix: &[Vec], + item_index: usize, + remaining_capacity: usize, +) -> Vec { + match item_index { + 0 => vec![], + _ => { + let current_value = knapsack_matrix[item_index][remaining_capacity]; + let previous_value = knapsack_matrix[item_index - 1][remaining_capacity]; + + match current_value.cmp(&previous_value) { + Ordering::Greater => { + let mut knap = retrieve_knapsack_items( + item_weights, + knapsack_matrix, + item_index - 1, + remaining_capacity - item_weights[item_index - 1], + ); + knap.push(item_index); + knap + } + Ordering::Equal | Ordering::Less => retrieve_knapsack_items( + item_weights, + knapsack_matrix, + item_index - 1, + remaining_capacity, + ), + } + } } - // Return result - (m[n][w], total_weight, items) } #[cfg(test)] mod tests { - // Took test datasets from https://people.sc.fsu.edu/~jburkardt/datasets/bin_packing/bin_packing.html use super::*; - #[test] - fn test_p02() { - assert_eq!( - (51, 26, vec![2, 3, 4]), - knapsack(26, vec![12, 7, 11, 8, 9], vec![24, 13, 23, 15, 16]) - ); - } - - #[test] - fn test_p04() { - assert_eq!( - (150, 190, vec![1, 2, 5]), - knapsack( - 190, - vec![56, 59, 80, 64, 75, 17], - vec![50, 50, 64, 46, 50, 5] - ) - ); - } - - #[test] - fn test_p01() { - assert_eq!( - (309, 165, vec![1, 2, 3, 4, 6]), - knapsack( - 165, - vec![23, 31, 29, 44, 53, 38, 63, 85, 89, 82], - vec![92, 57, 49, 68, 60, 43, 67, 84, 87, 72] - ) - ); - } - - #[test] - fn test_p06() { - assert_eq!( - (1735, 169, vec![2, 4, 7]), - knapsack( - 170, - vec![41, 50, 49, 59, 55, 57, 60], - vec![442, 525, 511, 593, 546, 564, 617] - ) - ); + macro_rules! knapsack_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (capacity, items, expected) = $test_case; + assert_eq!(expected, knapsack(capacity, items)); + } + )* + } } - #[test] - fn test_p07() { - assert_eq!( - (1458, 749, vec![1, 3, 5, 7, 8, 9, 14, 15]), - knapsack( - 750, - vec![70, 73, 77, 80, 82, 87, 90, 94, 98, 106, 110, 113, 115, 118, 120], - vec![135, 139, 149, 150, 156, 163, 173, 184, 192, 201, 210, 214, 221, 229, 240] - ) - ); + knapsack_tests! { + test_basic_knapsack_small: ( + 165, + vec![ + Item { weight: 23, value: 92 }, + Item { weight: 31, value: 57 }, + Item { weight: 29, value: 49 }, + Item { weight: 44, value: 68 }, + Item { weight: 53, value: 60 }, + Item { weight: 38, value: 43 }, + Item { weight: 63, value: 67 }, + Item { weight: 85, value: 84 }, + Item { weight: 89, value: 87 }, + Item { weight: 82, value: 72 } + ], + KnapsackSolution { + optimal_profit: 309, + total_weight: 165, + item_indices: vec![1, 2, 3, 4, 6] + } + ), + test_basic_knapsack_tiny: ( + 26, + vec![ + Item { weight: 12, value: 24 }, + Item { weight: 7, value: 13 }, + Item { weight: 11, value: 23 }, + Item { weight: 8, value: 15 }, + Item { weight: 9, value: 16 } + ], + KnapsackSolution { + optimal_profit: 51, + total_weight: 26, + item_indices: vec![2, 3, 4] + } + ), + test_basic_knapsack_medium: ( + 190, + vec![ + Item { weight: 56, value: 50 }, + Item { weight: 59, value: 50 }, + Item { weight: 80, value: 64 }, + Item { weight: 64, value: 46 }, + Item { weight: 75, value: 50 }, + Item { weight: 17, value: 5 } + ], + KnapsackSolution { + optimal_profit: 150, + total_weight: 190, + item_indices: vec![1, 2, 5] + } + ), + test_diverse_weights_values_small: ( + 50, + vec![ + Item { weight: 31, value: 70 }, + Item { weight: 10, value: 20 }, + Item { weight: 20, value: 39 }, + Item { weight: 19, value: 37 }, + Item { weight: 4, value: 7 }, + Item { weight: 3, value: 5 }, + Item { weight: 6, value: 10 } + ], + KnapsackSolution { + optimal_profit: 107, + total_weight: 50, + item_indices: vec![1, 4] + } + ), + test_diverse_weights_values_medium: ( + 104, + vec![ + Item { weight: 25, value: 350 }, + Item { weight: 35, value: 400 }, + Item { weight: 45, value: 450 }, + Item { weight: 5, value: 20 }, + Item { weight: 25, value: 70 }, + Item { weight: 3, value: 8 }, + Item { weight: 2, value: 5 }, + Item { weight: 2, value: 5 } + ], + KnapsackSolution { + optimal_profit: 900, + total_weight: 104, + item_indices: vec![1, 3, 4, 5, 7, 8] + } + ), + test_high_value_items: ( + 170, + vec![ + Item { weight: 41, value: 442 }, + Item { weight: 50, value: 525 }, + Item { weight: 49, value: 511 }, + Item { weight: 59, value: 593 }, + Item { weight: 55, value: 546 }, + Item { weight: 57, value: 564 }, + Item { weight: 60, value: 617 } + ], + KnapsackSolution { + optimal_profit: 1735, + total_weight: 169, + item_indices: vec![2, 4, 7] + } + ), + test_large_knapsack: ( + 750, + vec![ + Item { weight: 70, value: 135 }, + Item { weight: 73, value: 139 }, + Item { weight: 77, value: 149 }, + Item { weight: 80, value: 150 }, + Item { weight: 82, value: 156 }, + Item { weight: 87, value: 163 }, + Item { weight: 90, value: 173 }, + Item { weight: 94, value: 184 }, + Item { weight: 98, value: 192 }, + Item { weight: 106, value: 201 }, + Item { weight: 110, value: 210 }, + Item { weight: 113, value: 214 }, + Item { weight: 115, value: 221 }, + Item { weight: 118, value: 229 }, + Item { weight: 120, value: 240 } + ], + KnapsackSolution { + optimal_profit: 1458, + total_weight: 749, + item_indices: vec![1, 3, 5, 7, 8, 9, 14, 15] + } + ), + test_zero_capacity: ( + 0, + vec![ + Item { weight: 1, value: 1 }, + Item { weight: 2, value: 2 }, + Item { weight: 3, value: 3 } + ], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_very_small_capacity: ( + 1, + vec![ + Item { weight: 10, value: 1 }, + Item { weight: 20, value: 2 }, + Item { weight: 30, value: 3 } + ], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_no_items: ( + 1, + vec![], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_item_too_heavy: ( + 1, + vec![ + Item { weight: 2, value: 100 } + ], + KnapsackSolution { + optimal_profit: 0, + total_weight: 0, + item_indices: vec![] + } + ), + test_greedy_algorithm_does_not_work: ( + 10, + vec![ + Item { weight: 10, value: 15 }, + Item { weight: 6, value: 7 }, + Item { weight: 4, value: 9 } + ], + KnapsackSolution { + optimal_profit: 16, + total_weight: 10, + item_indices: vec![2, 3] + } + ), + test_greedy_algorithm_does_not_work_weight_smaller_than_capacity: ( + 10, + vec![ + Item { weight: 10, value: 15 }, + Item { weight: 1, value: 9 }, + Item { weight: 2, value: 7 } + ], + KnapsackSolution { + optimal_profit: 16, + total_weight: 3, + item_indices: vec![2, 3] + } + ), } } diff --git a/src/dynamic_programming/longest_common_subsequence.rs b/src/dynamic_programming/longest_common_subsequence.rs index a92ad50e26e..58f82714f93 100644 --- a/src/dynamic_programming/longest_common_subsequence.rs +++ b/src/dynamic_programming/longest_common_subsequence.rs @@ -1,73 +1,116 @@ -/// Longest common subsequence via Dynamic Programming +//! This module implements the Longest Common Subsequence (LCS) algorithm. +//! The LCS problem is finding the longest subsequence common to two sequences. +//! It differs from the problem of finding common substrings: unlike substrings, subsequences +//! are not required to occupy consecutive positions within the original sequences. +//! This implementation handles Unicode strings efficiently and correctly, ensuring +//! that multi-byte characters are managed properly. -/// longest_common_subsequence(a, b) returns the longest common subsequence -/// between the strings a and b. -pub fn longest_common_subsequence(a: &str, b: &str) -> String { - let a: Vec<_> = a.chars().collect(); - let b: Vec<_> = b.chars().collect(); - let (na, nb) = (a.len(), b.len()); +/// Computes the longest common subsequence of two input strings. +/// +/// The longest common subsequence (LCS) of two strings is the longest sequence that can +/// be derived from both strings by deleting some elements without changing the order of +/// the remaining elements. +/// +/// ## Note +/// The function may return different LCSs for the same pair of strings depending on the +/// order of the inputs and the nature of the sequences. This is due to the way the dynamic +/// programming algorithm resolves ties when multiple common subsequences of the same length +/// exist. The order of the input strings can influence the specific path taken through the +/// DP table, resulting in different valid LCS outputs. +/// +/// For example: +/// `longest_common_subsequence("hello, world!", "world, hello!")` returns `"hello!"` +/// but +/// `longest_common_subsequence("world, hello!", "hello, world!")` returns `"world!"` +/// +/// This difference arises because the dynamic programming table is filled differently based +/// on the input order, leading to different tie-breaking decisions and thus different LCS results. +pub fn longest_common_subsequence(first_seq: &str, second_seq: &str) -> String { + let first_seq_chars = first_seq.chars().collect::>(); + let second_seq_chars = second_seq.chars().collect::>(); - // solutions[i][j] is the length of the longest common subsequence - // between a[0..i-1] and b[0..j-1] - let mut solutions = vec![vec![0; nb + 1]; na + 1]; + let lcs_lengths = initialize_lcs_lengths(&first_seq_chars, &second_seq_chars); + let lcs_chars = reconstruct_lcs(&first_seq_chars, &second_seq_chars, &lcs_lengths); - for (i, ci) in a.iter().enumerate() { - for (j, cj) in b.iter().enumerate() { - // if ci == cj, there is a new common character; - // otherwise, take the best of the two solutions - // at (i-1,j) and (i,j-1) - solutions[i + 1][j + 1] = if ci == cj { - solutions[i][j] + 1 + lcs_chars.into_iter().collect() +} + +fn initialize_lcs_lengths(first_seq_chars: &[char], second_seq_chars: &[char]) -> Vec> { + let first_seq_len = first_seq_chars.len(); + let second_seq_len = second_seq_chars.len(); + + let mut lcs_lengths = vec![vec![0; second_seq_len + 1]; first_seq_len + 1]; + + // Populate the LCS lengths table + (1..=first_seq_len).for_each(|i| { + (1..=second_seq_len).for_each(|j| { + lcs_lengths[i][j] = if first_seq_chars[i - 1] == second_seq_chars[j - 1] { + lcs_lengths[i - 1][j - 1] + 1 } else { - solutions[i][j + 1].max(solutions[i + 1][j]) - } - } - } + lcs_lengths[i - 1][j].max(lcs_lengths[i][j - 1]) + }; + }); + }); - // reconstitute the solution string from the lengths - let mut result: Vec = Vec::new(); - let (mut i, mut j) = (na, nb); + lcs_lengths +} + +fn reconstruct_lcs( + first_seq_chars: &[char], + second_seq_chars: &[char], + lcs_lengths: &[Vec], +) -> Vec { + let mut lcs_chars = Vec::new(); + let mut i = first_seq_chars.len(); + let mut j = second_seq_chars.len(); while i > 0 && j > 0 { - if a[i - 1] == b[j - 1] { - result.push(a[i - 1]); + if first_seq_chars[i - 1] == second_seq_chars[j - 1] { + lcs_chars.push(first_seq_chars[i - 1]); i -= 1; j -= 1; - } else if solutions[i - 1][j] > solutions[i][j - 1] { + } else if lcs_lengths[i - 1][j] >= lcs_lengths[i][j - 1] { i -= 1; } else { j -= 1; } } - result.reverse(); - result.iter().collect() + lcs_chars.reverse(); + lcs_chars } #[cfg(test)] mod tests { - use super::longest_common_subsequence; - - #[test] - fn test_longest_common_subsequence() { - // empty case - assert_eq!(&longest_common_subsequence("", ""), ""); - assert_eq!(&longest_common_subsequence("", "abcd"), ""); - assert_eq!(&longest_common_subsequence("abcd", ""), ""); + use super::*; - // simple cases - assert_eq!(&longest_common_subsequence("abcd", "c"), "c"); - assert_eq!(&longest_common_subsequence("abcd", "d"), "d"); - assert_eq!(&longest_common_subsequence("abcd", "e"), ""); - assert_eq!(&longest_common_subsequence("abcdefghi", "acegi"), "acegi"); - - // less simple cases - assert_eq!(&longest_common_subsequence("abcdgh", "aedfhr"), "adh"); - assert_eq!(&longest_common_subsequence("aggtab", "gxtxayb"), "gtab"); + macro_rules! longest_common_subsequence_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (first_seq, second_seq, expected_lcs) = $test_case; + assert_eq!(longest_common_subsequence(&first_seq, &second_seq), expected_lcs); + } + )* + }; + } - // unicode - assert_eq!( - &longest_common_subsequence("你好,世界", "再见世界"), - "世界" - ); + longest_common_subsequence_tests! { + empty_case: ("", "", ""), + one_empty: ("", "abcd", ""), + identical_strings: ("abcd", "abcd", "abcd"), + completely_different: ("abcd", "efgh", ""), + single_character: ("a", "a", "a"), + different_length: ("abcd", "abc", "abc"), + special_characters: ("$#%&", "#@!%", "#%"), + long_strings: ("abcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgh", + "bcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgha", + "bcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefghabcdefgh"), + unicode_characters: ("你好,世界", "再见,世界", ",世界"), + spaces_and_punctuation_0: ("hello, world!", "world, hello!", "hello!"), + spaces_and_punctuation_1: ("hello, world!", "world, hello!", "hello!"), // longest_common_subsequence is not symmetric + random_case_1: ("abcdef", "xbcxxxe", "bce"), + random_case_2: ("xyz", "abc", ""), + random_case_3: ("abracadabra", "avadakedavra", "aaadara"), } } diff --git a/src/dynamic_programming/longest_common_substring.rs b/src/dynamic_programming/longest_common_substring.rs new file mode 100644 index 00000000000..52b858e5008 --- /dev/null +++ b/src/dynamic_programming/longest_common_substring.rs @@ -0,0 +1,70 @@ +//! This module provides a function to find the length of the longest common substring +//! between two strings using dynamic programming. + +/// Finds the length of the longest common substring between two strings using dynamic programming. +/// +/// The algorithm uses a 2D dynamic programming table where each cell represents +/// the length of the longest common substring ending at the corresponding indices in +/// the two input strings. The maximum value in the DP table is the result, i.e., the +/// length of the longest common substring. +/// +/// The time complexity is `O(n * m)`, where `n` and `m` are the lengths of the two strings. +/// # Arguments +/// +/// * `s1` - The first input string. +/// * `s2` - The second input string. +/// +/// # Returns +/// +/// Returns the length of the longest common substring between `s1` and `s2`. +pub fn longest_common_substring(s1: &str, s2: &str) -> usize { + let mut substr_len = vec![vec![0; s2.len() + 1]; s1.len() + 1]; + let mut max_len = 0; + + s1.as_bytes().iter().enumerate().for_each(|(i, &c1)| { + s2.as_bytes().iter().enumerate().for_each(|(j, &c2)| { + if c1 == c2 { + substr_len[i + 1][j + 1] = substr_len[i][j] + 1; + max_len = max_len.max(substr_len[i + 1][j + 1]); + } + }); + }); + + max_len +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_longest_common_substring { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (s1, s2, expected) = $inputs; + assert_eq!(longest_common_substring(s1, s2), expected); + assert_eq!(longest_common_substring(s2, s1), expected); + } + )* + } + } + + test_longest_common_substring! { + test_empty_strings: ("", "", 0), + test_one_empty_string: ("", "a", 0), + test_identical_single_char: ("a", "a", 1), + test_different_single_char: ("a", "b", 0), + test_common_substring_at_start: ("abcdef", "abc", 3), + test_common_substring_at_middle: ("abcdef", "bcd", 3), + test_common_substring_at_end: ("abcdef", "def", 3), + test_no_common_substring: ("abc", "xyz", 0), + test_overlapping_substrings: ("abcdxyz", "xyzabcd", 4), + test_special_characters: ("@abc#def$", "#def@", 4), + test_case_sensitive: ("abcDEF", "ABCdef", 0), + test_full_string_match: ("GeeksforGeeks", "GeeksforGeeks", 13), + test_substring_with_repeated_chars: ("aaaaaaaaaaaaa", "aaa", 3), + test_longer_strings_with_common_substring: ("OldSite:GeeksforGeeks.org", "NewSite:GeeksQuiz.com", 10), + test_no_common_substring_with_special_chars: ("!!!", "???", 0), + } +} diff --git a/src/dynamic_programming/longest_continuous_increasing_subsequence.rs b/src/dynamic_programming/longest_continuous_increasing_subsequence.rs index 0ca9d803371..3d47b433ae6 100644 --- a/src/dynamic_programming/longest_continuous_increasing_subsequence.rs +++ b/src/dynamic_programming/longest_continuous_increasing_subsequence.rs @@ -1,74 +1,93 @@ -pub fn longest_continuous_increasing_subsequence(input_array: &[T]) -> &[T] { - let length: usize = input_array.len(); +use std::cmp::Ordering; - //Handle the base cases - if length <= 1 { - return input_array; +/// Finds the longest continuous increasing subsequence in a slice. +/// +/// Given a slice of elements, this function returns a slice representing +/// the longest continuous subsequence where each element is strictly +/// less than the following element. +/// +/// # Arguments +/// +/// * `arr` - A reference to a slice of elements +/// +/// # Returns +/// +/// A subslice of the input, representing the longest continuous increasing subsequence. +/// If there are multiple subsequences of the same length, the function returns the first one found. +pub fn longest_continuous_increasing_subsequence(arr: &[T]) -> &[T] { + if arr.len() <= 1 { + return arr; } - //Create the array to store the longest subsequence at each location - let mut tracking_vec = vec![1; length]; + let mut start = 0; + let mut max_start = 0; + let mut max_len = 1; + let mut curr_len = 1; - //Iterate through the input and store longest subsequences at each location in the vector - for i in (0..length - 1).rev() { - if input_array[i] < input_array[i + 1] { - tracking_vec[i] = tracking_vec[i + 1] + 1; + for i in 1..arr.len() { + match arr[i - 1].cmp(&arr[i]) { + // include current element is greater than or equal to the previous + // one elements in the current increasing sequence + Ordering::Less | Ordering::Equal => { + curr_len += 1; + } + // reset when a strictly decreasing element is found + Ordering::Greater => { + if curr_len > max_len { + max_len = curr_len; + max_start = start; + } + // reset start to the current position + start = i; + // reset current length + curr_len = 1; + } } } - //Find the longest subsequence - let mut max_index: usize = 0; - let mut max_value: i32 = 0; - for (index, value) in tracking_vec.iter().enumerate() { - if value > &max_value { - max_value = *value; - max_index = index; - } + // final check for the last sequence + if curr_len > max_len { + max_len = curr_len; + max_start = start; } - &input_array[max_index..max_index + max_value as usize] + &arr[max_start..max_start + max_len] } #[cfg(test)] mod tests { - use super::longest_continuous_increasing_subsequence; - - #[test] - fn test_longest_increasing_subsequence() { - //Base Cases - let base_case_array: [i32; 0] = []; - assert_eq!( - &longest_continuous_increasing_subsequence(&base_case_array), - &[] - ); - assert_eq!(&longest_continuous_increasing_subsequence(&[1]), &[1]); + use super::*; - //Normal i32 Cases - assert_eq!( - &longest_continuous_increasing_subsequence(&[1, 2, 3, 4]), - &[1, 2, 3, 4] - ); - assert_eq!( - &longest_continuous_increasing_subsequence(&[1, 2, 2, 3, 4, 2]), - &[2, 3, 4] - ); - assert_eq!( - &longest_continuous_increasing_subsequence(&[5, 4, 3, 2, 1]), - &[5] - ); - assert_eq!( - &longest_continuous_increasing_subsequence(&[5, 4, 3, 4, 2, 1]), - &[3, 4] - ); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(longest_continuous_increasing_subsequence(input), expected); + } + )* + }; + } - //Non-Numeric case - assert_eq!( - &longest_continuous_increasing_subsequence(&['a', 'b', 'c']), - &['a', 'b', 'c'] - ); - assert_eq!( - &longest_continuous_increasing_subsequence(&['d', 'c', 'd']), - &['c', 'd'] - ); + test_cases! { + empty_array: (&[] as &[isize], &[] as &[isize]), + single_element: (&[1], &[1]), + all_increasing: (&[1, 2, 3, 4, 5], &[1, 2, 3, 4, 5]), + all_decreasing: (&[5, 4, 3, 2, 1], &[5]), + with_equal_elements: (&[1, 2, 2, 3, 4, 2], &[1, 2, 2, 3, 4]), + increasing_with_plateau: (&[1, 2, 2, 2, 3, 3, 4], &[1, 2, 2, 2, 3, 3, 4]), + mixed_elements: (&[5, 4, 3, 4, 2, 1], &[3, 4]), + alternating_increase_decrease: (&[1, 2, 1, 2, 1, 2], &[1, 2]), + zigzag: (&[1, 3, 2, 4, 3, 5], &[1, 3]), + single_negative_element: (&[-1], &[-1]), + negative_and_positive_mixed: (&[-2, -1, 0, 1, 2, 3], &[-2, -1, 0, 1, 2, 3]), + increasing_then_decreasing: (&[1, 2, 3, 4, 3, 2, 1], &[1, 2, 3, 4]), + single_increasing_subsequence_later: (&[3, 2, 1, 1, 2, 3, 4], &[1, 1, 2, 3, 4]), + longer_subsequence_at_start: (&[5, 6, 7, 8, 9, 2, 3, 4, 5], &[5, 6, 7, 8, 9]), + longer_subsequence_at_end: (&[2, 3, 4, 10, 5, 6, 7, 8, 9], &[5, 6, 7, 8, 9]), + longest_subsequence_at_start: (&[2, 3, 4, 5, 1, 0], &[2, 3, 4, 5]), + longest_subsequence_at_end: (&[1, 7, 2, 3, 4, 5,], &[2, 3, 4, 5]), + repeated_elements: (&[1, 1, 1, 1, 1], &[1, 1, 1, 1, 1]), } } diff --git a/src/dynamic_programming/longest_increasing_subsequence.rs b/src/dynamic_programming/longest_increasing_subsequence.rs index fedefe07463..ed58500135a 100644 --- a/src/dynamic_programming/longest_increasing_subsequence.rs +++ b/src/dynamic_programming/longest_increasing_subsequence.rs @@ -52,13 +52,13 @@ mod tests { #[test] /// Need to specify generic type T in order to function fn test_empty_vec() { - assert_eq!(longest_increasing_subsequence::(&vec![]), vec![]); + assert_eq!(longest_increasing_subsequence::(&[]), vec![]); } #[test] fn test_example_1() { assert_eq!( - longest_increasing_subsequence(&vec![10, 9, 2, 5, 3, 7, 101, 18]), + longest_increasing_subsequence(&[10, 9, 2, 5, 3, 7, 101, 18]), vec![2, 3, 7, 18] ); } @@ -66,7 +66,7 @@ mod tests { #[test] fn test_example_2() { assert_eq!( - longest_increasing_subsequence(&vec![0, 1, 0, 3, 2, 3]), + longest_increasing_subsequence(&[0, 1, 0, 3, 2, 3]), vec![0, 1, 2, 3] ); } @@ -74,7 +74,7 @@ mod tests { #[test] fn test_example_3() { assert_eq!( - longest_increasing_subsequence(&vec![7, 7, 7, 7, 7, 7, 7]), + longest_increasing_subsequence(&[7, 7, 7, 7, 7, 7, 7]), vec![7] ); } @@ -104,6 +104,6 @@ mod tests { #[test] fn test_negative_elements() { - assert_eq!(longest_increasing_subsequence(&vec![-2, -1]), vec![-2, -1]); + assert_eq!(longest_increasing_subsequence(&[-2, -1]), vec![-2, -1]); } } diff --git a/src/dynamic_programming/matrix_chain_multiply.rs b/src/dynamic_programming/matrix_chain_multiply.rs new file mode 100644 index 00000000000..410aec741e5 --- /dev/null +++ b/src/dynamic_programming/matrix_chain_multiply.rs @@ -0,0 +1,91 @@ +//! This module implements a dynamic programming solution to find the minimum +//! number of multiplications needed to multiply a chain of matrices with given dimensions. +//! +//! The algorithm uses a dynamic programming approach with tabulation to calculate the minimum +//! number of multiplications required for matrix chain multiplication. +//! +//! # Time Complexity +//! +//! The algorithm runs in O(n^3) time complexity and O(n^2) space complexity, where n is the +//! number of matrices. + +/// Custom error types for matrix chain multiplication +#[derive(Debug, PartialEq)] +pub enum MatrixChainMultiplicationError { + EmptyDimensions, + InsufficientDimensions, +} + +/// Calculates the minimum number of scalar multiplications required to multiply a chain +/// of matrices with given dimensions. +/// +/// # Arguments +/// +/// * `dimensions`: A vector where each element represents the dimensions of consecutive matrices +/// in the chain. For example, [1, 2, 3, 4] represents matrices of dimensions (1x2), (2x3), and (3x4). +/// +/// # Returns +/// +/// The minimum number of scalar multiplications needed to compute the product of the matrices +/// in the optimal order. +/// +/// # Errors +/// +/// Returns an error if the input is invalid (i.e., empty or length less than 2). +pub fn matrix_chain_multiply( + dimensions: Vec, +) -> Result { + if dimensions.is_empty() { + return Err(MatrixChainMultiplicationError::EmptyDimensions); + } + + if dimensions.len() == 1 { + return Err(MatrixChainMultiplicationError::InsufficientDimensions); + } + + let mut min_operations = vec![vec![0; dimensions.len()]; dimensions.len()]; + + (2..dimensions.len()).for_each(|chain_len| { + (0..dimensions.len() - chain_len).for_each(|start| { + let end = start + chain_len; + min_operations[start][end] = (start + 1..end) + .map(|split| { + min_operations[start][split] + + min_operations[split][end] + + dimensions[start] * dimensions[split] * dimensions[end] + }) + .min() + .unwrap_or(usize::MAX); + }); + }); + + Ok(min_operations[0][dimensions.len() - 1]) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(matrix_chain_multiply(input.clone()), expected); + assert_eq!(matrix_chain_multiply(input.into_iter().rev().collect()), expected); + } + )* + }; + } + + test_cases! { + basic_chain_of_matrices: (vec![1, 2, 3, 4], Ok(18)), + chain_of_large_matrices: (vec![40, 20, 30, 10, 30], Ok(26000)), + long_chain_of_matrices: (vec![1, 2, 3, 4, 3, 5, 7, 6, 10], Ok(182)), + complex_chain_of_matrices: (vec![4, 10, 3, 12, 20, 7], Ok(1344)), + empty_dimensions_input: (vec![], Err(MatrixChainMultiplicationError::EmptyDimensions)), + single_dimensions_input: (vec![10], Err(MatrixChainMultiplicationError::InsufficientDimensions)), + single_matrix_input: (vec![10, 20], Ok(0)), + } +} diff --git a/src/dynamic_programming/maximal_square.rs b/src/dynamic_programming/maximal_square.rs index 6c4d1c2be63..706d0b9aeb8 100644 --- a/src/dynamic_programming/maximal_square.rs +++ b/src/dynamic_programming/maximal_square.rs @@ -2,14 +2,16 @@ use std::cmp::max; use std::cmp::min; /// Maximal Square -/// Given an m x n binary matrix filled with 0's and 1's, find the largest square containing only 1's and return its area. -/// https://leetcode.com/problems/maximal-square/ /// -/// Arguments: -/// * `matrix` - an array of integer array -/// Complexity -/// - time complexity: O(n^2), -/// - space complexity: O(n), +/// Given an `m` * `n` binary matrix filled with 0's and 1's, find the largest square containing only 1's and return its area.\ +/// +/// +/// # Arguments: +/// * `matrix` - an array of integer array +/// +/// # Complexity +/// - time complexity: O(n^2), +/// - space complexity: O(n), pub fn maximal_square(matrix: &mut [Vec]) -> i32 { if matrix.is_empty() { return 0; @@ -45,7 +47,7 @@ mod tests { #[test] fn test() { - assert_eq!(maximal_square(&mut vec![]), 0); + assert_eq!(maximal_square(&mut []), 0); let mut matrix = vec![vec![0, 1], vec![1, 0]]; assert_eq!(maximal_square(&mut matrix), 1); diff --git a/src/dynamic_programming/maximum_subarray.rs b/src/dynamic_programming/maximum_subarray.rs index efcbec402d5..740f8009d60 100644 --- a/src/dynamic_programming/maximum_subarray.rs +++ b/src/dynamic_programming/maximum_subarray.rs @@ -1,62 +1,82 @@ -/// ## maximum subarray via Dynamic Programming +//! This module provides a function to find the largest sum of the subarray +//! in a given array of integers using dynamic programming. It also includes +//! tests to verify the correctness of the implementation. -/// maximum_subarray(array) find the subarray (containing at least one number) which has the largest sum -/// and return its sum. +/// Custom error type for maximum subarray +#[derive(Debug, PartialEq)] +pub enum MaximumSubarrayError { + EmptyArray, +} + +/// Finds the subarray (containing at least one number) which has the largest sum +/// and returns its sum. /// /// A subarray is a contiguous part of an array. /// -/// Arguments: -/// * `array` - an integer array -/// Complexity -/// - time complexity: O(array.length), -/// - space complexity: O(array.length), -pub fn maximum_subarray(array: &[i32]) -> i32 { - let mut dp = vec![0; array.len()]; - dp[0] = array[0]; - let mut result = dp[0]; - - for i in 1..array.len() { - if dp[i - 1] > 0 { - dp[i] = dp[i - 1] + array[i]; - } else { - dp[i] = array[i]; - } - result = result.max(dp[i]); +/// # Arguments +/// +/// * `array` - A slice of integers. +/// +/// # Returns +/// +/// A `Result` which is: +/// * `Ok(isize)` representing the largest sum of a contiguous subarray. +/// * `Err(MaximumSubarrayError)` if the array is empty. +/// +/// # Complexity +/// +/// * Time complexity: `O(array.len())` +/// * Space complexity: `O(1)` +pub fn maximum_subarray(array: &[isize]) -> Result { + if array.is_empty() { + return Err(MaximumSubarrayError::EmptyArray); } - result + let mut cur_sum = array[0]; + let mut max_sum = cur_sum; + + for &x in &array[1..] { + cur_sum = (cur_sum + x).max(x); + max_sum = max_sum.max(cur_sum); + } + + Ok(max_sum) } #[cfg(test)] mod tests { use super::*; - #[test] - fn non_negative() { - //the maximum value: 1 + 0 + 5 + 8 = 14 - let array = vec![1, 0, 5, 8]; - assert_eq!(maximum_subarray(&array), 14); - } - - #[test] - fn negative() { - //the maximum value: -1 - let array = vec![-3, -1, -8, -2]; - assert_eq!(maximum_subarray(&array), -1); - } - - #[test] - fn normal() { - //the maximum value: 3 + (-2) + 5 = 6 - let array = vec![-4, 3, -2, 5, -8]; - assert_eq!(maximum_subarray(&array), 6); + macro_rules! maximum_subarray_tests { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (array, expected) = $tc; + assert_eq!(maximum_subarray(&array), expected); + } + )* + } } - #[test] - fn single_element() { - let array = vec![6]; - assert_eq!(maximum_subarray(&array), 6); - let array = vec![-6]; - assert_eq!(maximum_subarray(&array), -6); + maximum_subarray_tests! { + test_all_non_negative: (vec![1, 0, 5, 8], Ok(14)), + test_all_negative: (vec![-3, -1, -8, -2], Ok(-1)), + test_mixed_negative_and_positive: (vec![-4, 3, -2, 5, -8], Ok(6)), + test_single_element_positive: (vec![6], Ok(6)), + test_single_element_negative: (vec![-6], Ok(-6)), + test_mixed_elements: (vec![-2, 1, -3, 4, -1, 2, 1, -5, 4], Ok(6)), + test_empty_array: (vec![], Err(MaximumSubarrayError::EmptyArray)), + test_all_zeroes: (vec![0, 0, 0, 0], Ok(0)), + test_single_zero: (vec![0], Ok(0)), + test_alternating_signs: (vec![3, -2, 5, -1], Ok(6)), + test_all_negatives_with_one_positive: (vec![-3, -4, 1, -7, -2], Ok(1)), + test_all_positives_with_one_negative: (vec![3, 4, -1, 7, 2], Ok(15)), + test_all_positives: (vec![2, 3, 1, 5], Ok(11)), + test_large_values: (vec![1000, -500, 1000, -500, 1000], Ok(2000)), + test_large_array: ((0..1000).collect::>(), Ok(499500)), + test_large_negative_array: ((0..1000).map(|x| -x).collect::>(), Ok(0)), + test_single_large_positive: (vec![1000000], Ok(1000000)), + test_single_large_negative: (vec![-1000000], Ok(-1000000)), } } diff --git a/src/dynamic_programming/minimum_cost_path.rs b/src/dynamic_programming/minimum_cost_path.rs new file mode 100644 index 00000000000..e06481199cf --- /dev/null +++ b/src/dynamic_programming/minimum_cost_path.rs @@ -0,0 +1,177 @@ +use std::cmp::min; + +/// Represents possible errors that can occur when calculating the minimum cost path in a matrix. +#[derive(Debug, PartialEq, Eq)] +pub enum MatrixError { + /// Error indicating that the matrix is empty or has empty rows. + EmptyMatrix, + /// Error indicating that the matrix is not rectangular in shape. + NonRectangularMatrix, +} + +/// Computes the minimum cost path from the top-left to the bottom-right +/// corner of a matrix, where movement is restricted to right and down directions. +/// +/// # Arguments +/// +/// * `matrix` - A 2D vector of positive integers, where each element represents +/// the cost to step on that cell. +/// +/// # Returns +/// +/// * `Ok(usize)` - The minimum path cost to reach the bottom-right corner from +/// the top-left corner of the matrix. +/// * `Err(MatrixError)` - An error if the matrix is empty or improperly formatted. +/// +/// # Complexity +/// +/// * Time complexity: `O(m * n)`, where `m` is the number of rows +/// and `n` is the number of columns in the input matrix. +/// * Space complexity: `O(n)`, as only a single row of cumulative costs +/// is stored at any time. +pub fn minimum_cost_path(matrix: Vec>) -> Result { + // Check if the matrix is rectangular + if !matrix.iter().all(|row| row.len() == matrix[0].len()) { + return Err(MatrixError::NonRectangularMatrix); + } + + // Check if the matrix is empty or contains empty rows + if matrix.is_empty() || matrix.iter().all(|row| row.is_empty()) { + return Err(MatrixError::EmptyMatrix); + } + + // Initialize the first row of the cost vector + let mut cost = matrix[0] + .iter() + .scan(0, |acc, &val| { + *acc += val; + Some(*acc) + }) + .collect::>(); + + // Process each row from the second to the last + for row in matrix.iter().skip(1) { + // Update the first element of cost for this row + cost[0] += row[0]; + + // Update the rest of the elements in the current row of cost + for col in 1..matrix[0].len() { + cost[col] = row[col] + min(cost[col - 1], cost[col]); + } + } + + // The last element in cost contains the minimum path cost to the bottom-right corner + Ok(cost[matrix[0].len() - 1]) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! minimum_cost_path_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (matrix, expected) = $test_case; + assert_eq!(minimum_cost_path(matrix), expected); + } + )* + }; + } + + minimum_cost_path_tests! { + basic: ( + vec![ + vec![2, 1, 4], + vec![2, 1, 3], + vec![3, 2, 1] + ], + Ok(7) + ), + single_element: ( + vec![ + vec![5] + ], + Ok(5) + ), + single_row: ( + vec![ + vec![1, 3, 2, 1, 5] + ], + Ok(12) + ), + single_column: ( + vec![ + vec![1], + vec![3], + vec![2], + vec![1], + vec![5] + ], + Ok(12) + ), + large_matrix: ( + vec![ + vec![1, 3, 1, 5], + vec![2, 1, 4, 2], + vec![3, 2, 1, 3], + vec![4, 3, 2, 1] + ], + Ok(10) + ), + uniform_matrix: ( + vec![ + vec![1, 1, 1], + vec![1, 1, 1], + vec![1, 1, 1] + ], + Ok(5) + ), + increasing_values: ( + vec![ + vec![1, 2, 3], + vec![4, 5, 6], + vec![7, 8, 9] + ], + Ok(21) + ), + high_cost_path: ( + vec![ + vec![1, 100, 1], + vec![1, 100, 1], + vec![1, 1, 1] + ], + Ok(5) + ), + complex_matrix: ( + vec![ + vec![5, 9, 6, 8], + vec![1, 4, 7, 3], + vec![2, 1, 8, 2], + vec![3, 6, 9, 4] + ], + Ok(23) + ), + empty_matrix: ( + vec![], + Err(MatrixError::EmptyMatrix) + ), + empty_row: ( + vec![ + vec![], + vec![], + vec![] + ], + Err(MatrixError::EmptyMatrix) + ), + non_rectangular: ( + vec![ + vec![1, 2, 3], + vec![4, 5], + vec![6, 7, 8] + ], + Err(MatrixError::NonRectangularMatrix) + ), + } +} diff --git a/src/dynamic_programming/mod.rs b/src/dynamic_programming/mod.rs index fd496b40c7f..f18c1847479 100644 --- a/src/dynamic_programming/mod.rs +++ b/src/dynamic_programming/mod.rs @@ -1,31 +1,49 @@ mod coin_change; -mod edit_distance; mod egg_dropping; mod fibonacci; +mod fractional_knapsack; mod is_subsequence; mod knapsack; mod longest_common_subsequence; +mod longest_common_substring; mod longest_continuous_increasing_subsequence; mod longest_increasing_subsequence; +mod matrix_chain_multiply; mod maximal_square; mod maximum_subarray; +mod minimum_cost_path; +mod optimal_bst; mod rod_cutting; mod snail; +mod subset_generation; +mod trapped_rainwater; +mod word_break; pub use self::coin_change::coin_change; -pub use self::edit_distance::{edit_distance, edit_distance_se}; pub use self::egg_dropping::egg_drop; +pub use self::fibonacci::binary_lifting_fibonacci; pub use self::fibonacci::classical_fibonacci; pub use self::fibonacci::fibonacci; +pub use self::fibonacci::last_digit_of_the_sum_of_nth_fibonacci_number; pub use self::fibonacci::logarithmic_fibonacci; +pub use self::fibonacci::matrix_fibonacci; pub use self::fibonacci::memoized_fibonacci; +pub use self::fibonacci::nth_fibonacci_number_modulo_m; pub use self::fibonacci::recursive_fibonacci; +pub use self::fractional_knapsack::fractional_knapsack; pub use self::is_subsequence::is_subsequence; pub use self::knapsack::knapsack; pub use self::longest_common_subsequence::longest_common_subsequence; +pub use self::longest_common_substring::longest_common_substring; pub use self::longest_continuous_increasing_subsequence::longest_continuous_increasing_subsequence; pub use self::longest_increasing_subsequence::longest_increasing_subsequence; +pub use self::matrix_chain_multiply::matrix_chain_multiply; pub use self::maximal_square::maximal_square; pub use self::maximum_subarray::maximum_subarray; +pub use self::minimum_cost_path::minimum_cost_path; +pub use self::optimal_bst::optimal_search_tree; pub use self::rod_cutting::rod_cut; pub use self::snail::snail; +pub use self::subset_generation::list_subset; +pub use self::trapped_rainwater::trapped_rainwater; +pub use self::word_break::word_break; diff --git a/src/dynamic_programming/optimal_bst.rs b/src/dynamic_programming/optimal_bst.rs new file mode 100644 index 00000000000..162351a21c6 --- /dev/null +++ b/src/dynamic_programming/optimal_bst.rs @@ -0,0 +1,93 @@ +// Optimal Binary Search Tree Algorithm in Rust +// Time Complexity: O(n^3) with prefix sum optimization +// Space Complexity: O(n^2) for the dp table and prefix sum array + +/// Constructs an Optimal Binary Search Tree from a list of key frequencies. +/// The goal is to minimize the expected search cost given key access frequencies. +/// +/// # Arguments +/// * `freq` - A slice of integers representing the frequency of key access +/// +/// # Returns +/// * An integer representing the minimum cost of the optimal BST +pub fn optimal_search_tree(freq: &[i32]) -> i32 { + let n = freq.len(); + if n == 0 { + return 0; + } + + // dp[i][j] stores the cost of optimal BST that can be formed from keys[i..=j] + let mut dp = vec![vec![0; n]; n]; + + // prefix_sum[i] stores sum of freq[0..i] + let mut prefix_sum = vec![0; n + 1]; + for i in 0..n { + prefix_sum[i + 1] = prefix_sum[i] + freq[i]; + } + + // Base case: Trees with only one key + for i in 0..n { + dp[i][i] = freq[i]; + } + + // Build chains of increasing length l (from 2 to n) + for l in 2..=n { + for i in 0..=n - l { + let j = i + l - 1; + dp[i][j] = i32::MAX; + + // Compute the total frequency sum in the range [i..=j] using prefix sum + let fsum = prefix_sum[j + 1] - prefix_sum[i]; + + // Try making each key in freq[i..=j] the root of the tree + for r in i..=j { + // Cost of left subtree + let left = if r > i { dp[i][r - 1] } else { 0 }; + // Cost of right subtree + let right = if r < j { dp[r + 1][j] } else { 0 }; + + // Total cost = left + right + sum of frequencies (fsum) + let cost = left + right + fsum; + + // Choose the minimum among all possible roots + if cost < dp[i][j] { + dp[i][j] = cost; + } + } + } + } + + // Minimum cost of the optimal BST storing all keys + dp[0][n - 1] +} + +#[cfg(test)] +mod tests { + use super::*; + + // Macro to generate multiple test cases for the optimal_search_tree function + macro_rules! optimal_bst_tests { + ($($name:ident: $input:expr => $expected:expr,)*) => { + $( + #[test] + fn $name() { + let freq = $input; + assert_eq!(optimal_search_tree(freq), $expected); + } + )* + }; + } + + optimal_bst_tests! { + // Common test cases + test_case_1: &[34, 10, 8, 50] => 180, + test_case_2: &[10, 12] => 32, + test_case_3: &[10, 12, 20] => 72, + test_case_4: &[25, 10, 20] => 95, + test_case_5: &[4, 2, 6, 3] => 26, + + // Edge test cases + test_case_single: &[42] => 42, + test_case_empty: &[] => 0, + } +} diff --git a/src/dynamic_programming/rod_cutting.rs b/src/dynamic_programming/rod_cutting.rs index 015e26d46a2..e56d482fdf7 100644 --- a/src/dynamic_programming/rod_cutting.rs +++ b/src/dynamic_programming/rod_cutting.rs @@ -1,55 +1,65 @@ -//! Solves the rod-cutting problem +//! This module provides functions for solving the rod-cutting problem using dynamic programming. use std::cmp::max; -/// `rod_cut(p)` returns the maximum possible profit if a rod of length `n` = `p.len()` -/// is cut into up to `n` pieces, where the profit gained from each piece of length -/// `l` is determined by `p[l - 1]` and the total profit is the sum of the profit -/// gained from each piece. +/// Calculates the maximum possible profit from cutting a rod into pieces of varying lengths. /// -/// # Arguments -/// - `p` - profit for rods of length 1 to n inclusive +/// Returns the maximum profit achievable by cutting a rod into pieces such that the profit from each +/// piece is determined by its length and predefined prices. /// /// # Complexity -/// - time complexity: O(n^2), -/// - space complexity: O(n^2), +/// - Time complexity: `O(n^2)` +/// - Space complexity: `O(n)` /// -/// where n is the length of `p`. -pub fn rod_cut(p: &[usize]) -> usize { - let n = p.len(); - // f is the dynamic programming table - let mut f = vec![0; n]; - - for i in 0..n { - let mut max_price = p[i]; - for j in 1..=i { - max_price = max(max_price, p[j - 1] + f[i - j]); - } - f[i] = max_price; +/// where `n` is the number of different rod lengths considered. +pub fn rod_cut(prices: &[usize]) -> usize { + if prices.is_empty() { + return 0; } - // accomodate for input with length zero - if n != 0 { - f[n - 1] - } else { - 0 - } + (1..=prices.len()).fold(vec![0; prices.len() + 1], |mut max_profit, rod_length| { + max_profit[rod_length] = (1..=rod_length) + .map(|cut_position| prices[cut_position - 1] + max_profit[rod_length - cut_position]) + .fold(prices[rod_length - 1], |max_price, current_price| { + max(max_price, current_price) + }); + max_profit + })[prices.len()] } #[cfg(test)] mod tests { - use super::rod_cut; + use super::*; + + macro_rules! rod_cut_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected_output) = $test_case; + assert_eq!(expected_output, rod_cut(input)); + } + )* + }; + } - #[test] - fn test_rod_cut() { - assert_eq!(0, rod_cut(&[])); - assert_eq!(15, rod_cut(&[5, 8, 2])); - assert_eq!(10, rod_cut(&[1, 5, 8, 9])); - assert_eq!(25, rod_cut(&[5, 8, 2, 1, 7])); - assert_eq!(87, rod_cut(&[0, 0, 0, 0, 0, 87])); - assert_eq!(49, rod_cut(&[7, 6, 5, 4, 3, 2, 1])); - assert_eq!(22, rod_cut(&[1, 5, 8, 9, 10, 17, 17, 20])); - assert_eq!(60, rod_cut(&[6, 4, 8, 2, 5, 8, 2, 3, 7, 11])); - assert_eq!(30, rod_cut(&[1, 5, 8, 9, 10, 17, 17, 20, 24, 30])); - assert_eq!(12, rod_cut(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])); + rod_cut_tests! { + test_empty_prices: (&[], 0), + test_example_with_three_prices: (&[5, 8, 2], 15), + test_example_with_four_prices: (&[1, 5, 8, 9], 10), + test_example_with_five_prices: (&[5, 8, 2, 1, 7], 25), + test_all_zeros_except_last: (&[0, 0, 0, 0, 0, 87], 87), + test_descending_prices: (&[7, 6, 5, 4, 3, 2, 1], 49), + test_varied_prices: (&[1, 5, 8, 9, 10, 17, 17, 20], 22), + test_complex_prices: (&[6, 4, 8, 2, 5, 8, 2, 3, 7, 11], 60), + test_increasing_prices: (&[1, 5, 8, 9, 10, 17, 17, 20, 24, 30], 30), + test_large_range_prices: (&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], 12), + test_single_length_price: (&[5], 5), + test_zero_length_price: (&[0], 0), + test_repeated_prices: (&[5, 5, 5, 5], 20), + test_no_profit: (&[0, 0, 0, 0], 0), + test_large_input: (&[1; 1000], 1000), + test_all_zero_input: (&[0; 100], 0), + test_very_large_prices: (&[1000000, 2000000, 3000000], 3000000), + test_greedy_does_not_work: (&[2, 5, 7, 8], 10), } } diff --git a/src/dynamic_programming/snail.rs b/src/dynamic_programming/snail.rs index cf9673b15f6..8b1a8358f0b 100644 --- a/src/dynamic_programming/snail.rs +++ b/src/dynamic_programming/snail.rs @@ -111,7 +111,7 @@ mod test { #[test] fn test_empty() { let empty: &[Vec] = &[vec![]]; - assert_eq!(snail(&empty), vec![]); + assert_eq!(snail(empty), vec![]); } #[test] diff --git a/src/dynamic_programming/subset_generation.rs b/src/dynamic_programming/subset_generation.rs new file mode 100644 index 00000000000..39dac5d293e --- /dev/null +++ b/src/dynamic_programming/subset_generation.rs @@ -0,0 +1,116 @@ +// list all subset combinations of n element in given set of r element. +// This is a recursive function that collects all subsets of the set of size n +// with the given set of size r. +pub fn list_subset( + set: &[i32], + n: usize, + r: usize, + index: usize, + data: &mut [i32], + i: usize, +) -> Vec> { + let mut res = Vec::new(); + + // Current subset is ready to be added to the list + if i == r { + let mut subset = Vec::new(); + for j in data.iter().take(r) { + subset.push(*j); + } + res.push(subset); + return res; + } + + // When no more elements are there to put in data[] + if index >= n { + return res; + } + + // current is included, put next at next location + data[i] = set[index]; + res.append(&mut list_subset(set, n, r, index + 1, data, i + 1)); + + // current is excluded, replace it with next (Note that + // i+1 is passed, but index is not changed) + res.append(&mut list_subset(set, n, r, index + 1, data, i)); + + res +} + +// Test module +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_print_subset3() { + let set = [1, 2, 3, 4, 5]; + let n = set.len(); + const R: usize = 3; + let mut data = [0; R]; + + let res = list_subset(&set, n, R, 0, &mut data, 0); + + assert_eq!( + res, + vec![ + vec![1, 2, 3], + vec![1, 2, 4], + vec![1, 2, 5], + vec![1, 3, 4], + vec![1, 3, 5], + vec![1, 4, 5], + vec![2, 3, 4], + vec![2, 3, 5], + vec![2, 4, 5], + vec![3, 4, 5] + ] + ); + } + + #[test] + fn test_print_subset4() { + let set = [1, 2, 3, 4, 5]; + let n = set.len(); + const R: usize = 4; + let mut data = [0; R]; + + let res = list_subset(&set, n, R, 0, &mut data, 0); + + assert_eq!( + res, + vec![ + vec![1, 2, 3, 4], + vec![1, 2, 3, 5], + vec![1, 2, 4, 5], + vec![1, 3, 4, 5], + vec![2, 3, 4, 5] + ] + ); + } + + #[test] + fn test_print_subset5() { + let set = [1, 2, 3, 4, 5]; + let n = set.len(); + const R: usize = 5; + let mut data = [0; R]; + + let res = list_subset(&set, n, R, 0, &mut data, 0); + + assert_eq!(res, vec![vec![1, 2, 3, 4, 5]]); + } + + #[test] + fn test_print_incorrect_subset() { + let set = [1, 2, 3, 4, 5]; + let n = set.len(); + const R: usize = 6; + let mut data = [0; R]; + + let res = list_subset(&set, n, R, 0, &mut data, 0); + + let result_set: Vec> = Vec::new(); + assert_eq!(res, result_set); + } +} diff --git a/src/dynamic_programming/trapped_rainwater.rs b/src/dynamic_programming/trapped_rainwater.rs new file mode 100644 index 00000000000..b220754ca23 --- /dev/null +++ b/src/dynamic_programming/trapped_rainwater.rs @@ -0,0 +1,125 @@ +//! Module to calculate trapped rainwater in an elevation map. + +/// Computes the total volume of trapped rainwater in a given elevation map. +/// +/// # Arguments +/// +/// * `elevation_map` - A slice containing the heights of the terrain elevations. +/// +/// # Returns +/// +/// The total volume of trapped rainwater. +pub fn trapped_rainwater(elevation_map: &[u32]) -> u32 { + let left_max = calculate_max_values(elevation_map, false); + let right_max = calculate_max_values(elevation_map, true); + let mut water_trapped = 0; + // Calculate trapped water + for i in 0..elevation_map.len() { + water_trapped += left_max[i].min(right_max[i]) - elevation_map[i]; + } + water_trapped +} + +/// Determines the maximum heights from either direction in the elevation map. +/// +/// # Arguments +/// +/// * `elevation_map` - A slice representing the heights of the terrain elevations. +/// * `reverse` - A boolean that indicates the direction of calculation. +/// - `false` for left-to-right. +/// - `true` for right-to-left. +/// +/// # Returns +/// +/// A vector containing the maximum heights encountered up to each position. +fn calculate_max_values(elevation_map: &[u32], reverse: bool) -> Vec { + let mut max_values = vec![0; elevation_map.len()]; + let mut current_max = 0; + for i in create_iter(elevation_map.len(), reverse) { + current_max = current_max.max(elevation_map[i]); + max_values[i] = current_max; + } + max_values +} + +/// Creates an iterator for the given length, optionally reversing it. +/// +/// # Arguments +/// +/// * `len` - The length of the iterator. +/// * `reverse` - A boolean that determines the order of iteration. +/// - `false` for forward iteration. +/// - `true` for reverse iteration. +/// +/// # Returns +/// +/// A boxed iterator that iterates over the range of indices. +fn create_iter(len: usize, reverse: bool) -> Box> { + if reverse { + Box::new((0..len).rev()) + } else { + Box::new(0..len) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! trapped_rainwater_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (elevation_map, expected_trapped_water) = $test_case; + assert_eq!(trapped_rainwater(&elevation_map), expected_trapped_water); + let elevation_map_rev: Vec = elevation_map.iter().rev().cloned().collect(); + assert_eq!(trapped_rainwater(&elevation_map_rev), expected_trapped_water); + } + )* + }; + } + + trapped_rainwater_tests! { + test_trapped_rainwater_basic: ( + [0, 1, 0, 2, 1, 0, 1, 3, 2, 1, 2, 1], + 6 + ), + test_trapped_rainwater_peak_under_water: ( + [3, 0, 2, 0, 4], + 7, + ), + test_bucket: ( + [5, 1, 5], + 4 + ), + test_skewed_bucket: ( + [4, 1, 5], + 3 + ), + test_trapped_rainwater_empty: ( + [], + 0 + ), + test_trapped_rainwater_flat: ( + [0, 0, 0, 0, 0], + 0 + ), + test_trapped_rainwater_no_trapped_water: ( + [1, 1, 2, 4, 0, 0, 0], + 0 + ), + test_trapped_rainwater_single_elevation_map: ( + [5], + 0 + ), + test_trapped_rainwater_two_point_elevation_map: ( + [5, 1], + 0 + ), + test_trapped_rainwater_large_elevation_map_difference: ( + [5, 1, 6, 1, 7, 1, 8], + 15 + ), + } +} diff --git a/src/dynamic_programming/word_break.rs b/src/dynamic_programming/word_break.rs new file mode 100644 index 00000000000..f3153525f6e --- /dev/null +++ b/src/dynamic_programming/word_break.rs @@ -0,0 +1,89 @@ +use crate::data_structures::Trie; + +/// Checks if a string can be segmented into a space-separated sequence +/// of one or more words from the given dictionary. +/// +/// # Arguments +/// * `s` - The input string to be segmented. +/// * `word_dict` - A slice of words forming the dictionary. +/// +/// # Returns +/// * `bool` - `true` if the string can be segmented, `false` otherwise. +pub fn word_break(s: &str, word_dict: &[&str]) -> bool { + let mut trie = Trie::new(); + for &word in word_dict { + trie.insert(word.chars(), true); + } + + // Memoization vector: one extra space to handle out-of-bound end case. + let mut memo = vec![None; s.len() + 1]; + search(&trie, s, 0, &mut memo) +} + +/// Recursively checks if the substring starting from `start` can be segmented +/// using words in the trie and memoizes the results. +/// +/// # Arguments +/// * `trie` - The Trie containing the dictionary words. +/// * `s` - The input string. +/// * `start` - The starting index for the current substring. +/// * `memo` - A vector for memoization to store intermediate results. +/// +/// # Returns +/// * `bool` - `true` if the substring can be segmented, `false` otherwise. +fn search(trie: &Trie, s: &str, start: usize, memo: &mut Vec>) -> bool { + if start == s.len() { + return true; + } + + if let Some(res) = memo[start] { + return res; + } + + for end in start + 1..=s.len() { + if trie.get(s[start..end].chars()).is_some() && search(trie, s, end, memo) { + memo[start] = Some(true); + return true; + } + } + + memo[start] = Some(false); + false +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, dict, expected) = $test_case; + assert_eq!(word_break(input, &dict), expected); + } + )* + } + } + + test_cases! { + typical_case_1: ("applepenapple", vec!["apple", "pen"], true), + typical_case_2: ("catsandog", vec!["cats", "dog", "sand", "and", "cat"], false), + typical_case_3: ("cars", vec!["car", "ca", "rs"], true), + edge_case_empty_string: ("", vec!["apple", "pen"], true), + edge_case_empty_dict: ("apple", vec![], false), + edge_case_single_char_in_dict: ("a", vec!["a"], true), + edge_case_single_char_not_in_dict: ("b", vec!["a"], false), + edge_case_all_words_larger_than_input: ("a", vec!["apple", "banana"], false), + edge_case_no_solution_large_string: ("abcdefghijklmnoqrstuv", vec!["a", "bc", "def", "ghij", "klmno", "pqrst"], false), + successful_segmentation_large_string: ("abcdefghijklmnopqrst", vec!["a", "bc", "def", "ghij", "klmno", "pqrst"], true), + long_string_repeated_pattern: (&"ab".repeat(100), vec!["a", "b", "ab"], true), + long_string_no_solution: (&"a".repeat(100), vec!["b"], false), + mixed_size_dict_1: ("pineapplepenapple", vec!["apple", "pen", "applepen", "pine", "pineapple"], true), + mixed_size_dict_2: ("catsandog", vec!["cats", "dog", "sand", "and", "cat"], false), + mixed_size_dict_3: ("abcd", vec!["a", "abc", "b", "cd"], true), + performance_stress_test_large_valid: (&"abc".repeat(1000), vec!["a", "ab", "abc"], true), + performance_stress_test_large_invalid: (&"x".repeat(1000), vec!["a", "ab", "abc"], false), + } +} diff --git a/src/financial/mod.rs b/src/financial/mod.rs new file mode 100644 index 00000000000..89b36bfa5e0 --- /dev/null +++ b/src/financial/mod.rs @@ -0,0 +1,2 @@ +mod present_value; +pub use present_value::present_value; diff --git a/src/financial/present_value.rs b/src/financial/present_value.rs new file mode 100644 index 00000000000..5294b71758c --- /dev/null +++ b/src/financial/present_value.rs @@ -0,0 +1,91 @@ +/// In economics and finance, present value (PV), also known as present discounted value, +/// is the value of an expected income stream determined as of the date of valuation. +/// +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Present_value + +#[derive(PartialEq, Eq, Debug)] +pub enum PresentValueError { + NegetiveDiscount, + EmptyCashFlow, +} + +pub fn present_value(discount_rate: f64, cash_flows: Vec) -> Result { + if discount_rate < 0.0 { + return Err(PresentValueError::NegetiveDiscount); + } + if cash_flows.is_empty() { + return Err(PresentValueError::EmptyCashFlow); + } + + let present_value = cash_flows + .iter() + .enumerate() + .map(|(i, &cash_flow)| cash_flow / (1.0 + discount_rate).powi(i as i32)) + .sum::(); + + Ok(round(present_value)) +} + +fn round(value: f64) -> f64 { + (value * 100.0).round() / 100.0 +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_present_value { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let ((discount_rate,cash_flows), expected) = $inputs; + assert_eq!(present_value(discount_rate,cash_flows).unwrap(), expected); + } + )* + } + } + + macro_rules! test_present_value_Err { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let ((discount_rate,cash_flows), expected) = $inputs; + assert_eq!(present_value(discount_rate,cash_flows).unwrap_err(), expected); + } + )* + } + } + + macro_rules! test_round { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $inputs; + assert_eq!(round(input), expected); + } + )* + } + } + + test_present_value! { + general_inputs1:((0.13, vec![10.0, 20.70, -293.0, 297.0]),4.69), + general_inputs2:((0.07, vec![-109129.39, 30923.23, 15098.93, 29734.0, 39.0]),-42739.63), + general_inputs3:((0.07, vec![109129.39, 30923.23, 15098.93, 29734.0, 39.0]), 175519.15), + zero_input:((0.0, vec![109129.39, 30923.23, 15098.93, 29734.0, 39.0]), 184924.55), + + } + + test_present_value_Err! { + negative_discount_rate:((-1.0, vec![10.0, 20.70, -293.0, 297.0]), PresentValueError::NegetiveDiscount), + empty_cash_flow:((1.0, vec![]), PresentValueError::EmptyCashFlow), + + } + test_round! { + test1:(0.55434, 0.55), + test2:(10.453, 10.45), + test3:(1111_f64, 1111_f64), + } +} diff --git a/src/general/convex_hull.rs b/src/general/convex_hull.rs index 272004adee2..e2a073b4e57 100644 --- a/src/general/convex_hull.rs +++ b/src/general/convex_hull.rs @@ -5,9 +5,9 @@ fn sort_by_min_angle(pts: &[(f64, f64)], min: &(f64, f64)) -> Vec<(f64, f64)> { .iter() .map(|x| { ( - ((x.1 - min.1) as f64).atan2((x.0 - min.0) as f64), + (x.1 - min.1).atan2(x.0 - min.0), // angle - ((x.1 - min.1) as f64).hypot((x.0 - min.0) as f64), + (x.1 - min.1).hypot(x.0 - min.0), // distance (we want the closest to be first) *x, ) @@ -70,7 +70,7 @@ mod tests { #[test] fn empty() { - assert_eq!(convex_hull_graham(&vec![]), vec![]); + assert_eq!(convex_hull_graham(&[]), vec![]); } #[test] diff --git a/src/general/fisher_yates_shuffle.rs b/src/general/fisher_yates_shuffle.rs new file mode 100644 index 00000000000..8ba448b2105 --- /dev/null +++ b/src/general/fisher_yates_shuffle.rs @@ -0,0 +1,25 @@ +use std::time::{SystemTime, UNIX_EPOCH}; + +use crate::math::PCG32; + +const DEFAULT: u64 = 4294967296; + +fn gen_range(range: usize, generator: &mut PCG32) -> usize { + generator.get_u64() as usize % range +} + +pub fn fisher_yates_shuffle(array: &mut [i32]) { + let seed = match SystemTime::now().duration_since(UNIX_EPOCH) { + Ok(duration) => duration.as_millis() as u64, + Err(_) => DEFAULT, + }; + + let mut random_generator = PCG32::new_default(seed); + + let len = array.len(); + + for i in 0..(len - 2) { + let r = gen_range(len - i, &mut random_generator); + array.swap(i, i + r); + } +} diff --git a/src/general/genetic.rs b/src/general/genetic.rs new file mode 100644 index 00000000000..43221989a23 --- /dev/null +++ b/src/general/genetic.rs @@ -0,0 +1,461 @@ +use std::cmp::Ordering; +use std::collections::BTreeSet; +use std::fmt::Debug; + +/// The goal is to showcase how Genetic algorithms generically work +/// See: https://en.wikipedia.org/wiki/Genetic_algorithm for concepts + +/// This is the definition of a Chromosome for a genetic algorithm +/// We can picture this as "one contending solution to our problem" +/// It is generic over: +/// * Eval, which could be a float, or any other totally ordered type, so that we can rank solutions to our problem +/// * Rng: a random number generator (could be thread rng, etc.) +pub trait Chromosome { + /// Mutates this Chromosome, changing its genes + fn mutate(&mut self, rng: &mut Rng); + + /// Mixes this chromosome with another one + fn crossover(&self, other: &Self, rng: &mut Rng) -> Self; + + /// How well this chromosome fits the problem we're trying to solve + /// **The smaller the better it fits** (we could use abs(... - expected_value) for instance + fn fitness(&self) -> Eval; +} + +pub trait SelectionStrategy { + fn new(rng: Rng) -> Self; + + /// Selects a portion of the population for reproduction + /// Could be totally random ones or the ones that fit best, etc. + /// This assumes the population is sorted by how it fits the solution (the first the better) + fn select<'a, Eval: Into, C: Chromosome>( + &mut self, + population: &'a [C], + ) -> (&'a C, &'a C); +} + +/// A roulette wheel selection strategy +/// https://en.wikipedia.org/wiki/Fitness_proportionate_selection +pub struct RouletteWheel { + rng: Rng, +} +impl SelectionStrategy for RouletteWheel { + fn new(rng: Rng) -> Self { + Self { rng } + } + + fn select<'a, Eval: Into, C: Chromosome>( + &mut self, + population: &'a [C], + ) -> (&'a C, &'a C) { + // We will assign a probability for every item in the population, based on its proportion towards the sum of all fitness + // This would work well for an increasing fitness function, but not in our case of a fitness function for which "lower is better" + // We thus need to take the reciprocal + let mut parents = Vec::with_capacity(2); + let fitnesses: Vec = population + .iter() + .filter_map(|individual| { + let fitness = individual.fitness().into(); + if individual.fitness().into() == 0.0 { + parents.push(individual); + None + } else { + Some(1.0 / fitness) + } + }) + .collect(); + if parents.len() == 2 { + return (parents[0], parents[1]); + } + let sum: f64 = fitnesses.iter().sum(); + let mut spin = self.rng.random_range(0.0..=sum); + for individual in population { + let fitness: f64 = individual.fitness().into(); + if spin <= fitness { + parents.push(individual); + if parents.len() == 2 { + return (parents[0], parents[1]); + } + } else { + spin -= fitness; + } + } + panic!("Could not select parents"); + } +} + +pub struct Tournament { + rng: Rng, +} +impl SelectionStrategy for Tournament { + fn new(rng: Rng) -> Self { + Self { rng } + } + + fn select<'a, Eval, C: Chromosome>( + &mut self, + population: &'a [C], + ) -> (&'a C, &'a C) { + if K < 2 { + panic!("K must be > 2"); + } + // This strategy is defined as the following: pick K chromosomes randomly, use the 2 that fits the best + // We assume the population is sorted + // This means we can draw K random (distinct) numbers between (0..population.len()) and return the chromosomes at the 2 lowest indices + let mut picked_indices = BTreeSet::new(); // will keep indices ordered + while picked_indices.len() < K { + picked_indices.insert(self.rng.random_range(0..population.len())); + } + let mut iter = picked_indices.into_iter(); + ( + &population[iter.next().unwrap()], + &population[iter.next().unwrap()], + ) + } +} + +type Comparator = Box Ordering>; +pub struct GeneticAlgorithm< + Rng: rand::Rng, + Eval: PartialOrd, + C: Chromosome, + Selection: SelectionStrategy, +> { + rng: Rng, // will be used to draw random numbers for initial population, mutations and crossovers + population: Vec, // the set of random solutions (chromosomes) + threshold: Eval, // Any chromosome fitting over this threshold is considered a valid solution + max_generations: usize, // So that we don't loop infinitely + mutation_chance: f64, // what's the probability a chromosome will mutate + crossover_chance: f64, // what's the probability two chromosomes will cross-over and give birth to a new chromosome + compare: Comparator, + selection: Selection, // how we will select parent chromosomes for crossing over, see `SelectionStrategy` +} + +pub struct GenericAlgorithmParams { + max_generations: usize, + mutation_chance: f64, + crossover_chance: f64, +} + +impl< + Rng: rand::Rng, + Eval: Into + PartialOrd + Debug, + C: Chromosome + Clone + Debug, + Selection: SelectionStrategy, + > GeneticAlgorithm +{ + pub fn init( + rng: Rng, + population: Vec, + threshold: Eval, + params: GenericAlgorithmParams, + compare: Comparator, + selection: Selection, + ) -> Self { + let GenericAlgorithmParams { + max_generations, + mutation_chance, + crossover_chance, + } = params; + Self { + rng, + population, + threshold, + max_generations, + mutation_chance, + crossover_chance, + compare, + selection, + } + } + + pub fn solve(&mut self) -> Option { + let mut generations = 1; // 1st generation is our initial population + while generations <= self.max_generations { + // 1. Sort the population by fitness score, remember: the lower the better (so natural ordering) + self.population + .sort_by(|c1: &C, c2: &C| (self.compare)(&c1.fitness(), &c2.fitness())); + + // 2. Stop condition: we might have found a good solution + if let Some(solution) = self.population.first() { + if solution.fitness() <= self.threshold { + return Some(solution).cloned(); + } + } + + // 3. Apply random mutations to the whole population + for chromosome in self.population.iter_mut() { + if self.rng.random::() <= self.mutation_chance { + chromosome.mutate(&mut self.rng); + } + } + // 4. Select parents that will be mating to create new chromosomes + let mut new_population = Vec::with_capacity(self.population.len() + 1); + while new_population.len() < self.population.len() { + let (p1, p2) = self.selection.select(&self.population); + if self.rng.random::() <= self.crossover_chance { + let child = p1.crossover(p2, &mut self.rng); + new_population.push(child); + } else { + // keep parents + new_population.extend([p1.clone(), p2.clone()]); + } + } + if new_population.len() > self.population.len() { + // We might have added 2 parents + new_population.pop(); + } + self.population = new_population; + // 5. Rinse & Repeat until we find a proper solution or we reach the maximum number of generations + generations += 1; + } + None + } +} + +#[cfg(test)] +mod tests { + use crate::general::genetic::{ + Chromosome, GenericAlgorithmParams, GeneticAlgorithm, RouletteWheel, SelectionStrategy, + Tournament, + }; + use rand::rngs::ThreadRng; + use rand::{rng, Rng}; + use std::collections::HashMap; + use std::fmt::{Debug, Formatter}; + use std::ops::RangeInclusive; + + #[test] + #[ignore] // Too long and not deterministic enough to be part of CI, more of an example than a test + fn find_secret() { + let chars = 'a'..='z'; + let secret = "thisistopsecret".to_owned(); + // Note: we'll pick genes (a, b, c) in the range -10, 10 + #[derive(Clone)] + struct TestString { + chars: RangeInclusive, + secret: String, + genes: Vec, + } + impl TestString { + fn new(rng: &mut ThreadRng, secret: String, chars: RangeInclusive) -> Self { + let current = (0..secret.len()) + .map(|_| rng.random_range(chars.clone())) + .collect::>(); + + Self { + chars, + secret, + genes: current, + } + } + } + impl Debug for TestString { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.genes.iter().collect::()) + } + } + impl Chromosome for TestString { + fn mutate(&mut self, rng: &mut ThreadRng) { + // let's assume mutations happen completely randomly, one "gene" at a time (i.e. one char at a time) + let gene_idx = rng.random_range(0..self.secret.len()); + let new_char = rng.random_range(self.chars.clone()); + self.genes[gene_idx] = new_char; + } + + fn crossover(&self, other: &Self, rng: &mut ThreadRng) -> Self { + // Let's not assume anything here, simply mixing random genes from both parents + let genes = (0..self.secret.len()) + .map(|idx| { + if rng.random_bool(0.5) { + // pick gene from self + self.genes[idx] + } else { + // pick gene from other parent + other.genes[idx] + } + }) + .collect(); + Self { + chars: self.chars.clone(), + secret: self.secret.clone(), + genes, + } + } + + fn fitness(&self) -> i32 { + // We are just counting how many chars are distinct from secret + self.genes + .iter() + .zip(self.secret.chars()) + .filter(|(char, expected)| expected != *char) + .count() as i32 + } + } + let mut rng = rng(); + let pop_count = 1_000; + let mut population = Vec::with_capacity(pop_count); + for _ in 0..pop_count { + population.push(TestString::new(&mut rng, secret.clone(), chars.clone())); + } + let selection: Tournament<100, ThreadRng> = Tournament::new(rng.clone()); + let params = GenericAlgorithmParams { + max_generations: 100, + mutation_chance: 0.2, + crossover_chance: 0.4, + }; + let mut solver = + GeneticAlgorithm::init(rng, population, 0, params, Box::new(i32::cmp), selection); + let res = solver.solve(); + assert!(res.is_some()); + assert_eq!(res.unwrap().genes, secret.chars().collect::>()) + } + + #[test] + #[ignore] // Too long and not deterministic enough to be part of CI, more of an example than a test + fn solve_mastermind() { + #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] + enum ColoredPeg { + Red, + Yellow, + Green, + Blue, + White, + Black, + } + struct GuessAnswer { + right_pos: i32, // right color at the right pos + wrong_pos: i32, // right color, but at wrong pos + } + #[derive(Clone, Debug)] + struct CodeMaker { + // the player coming up with a secret code + code: [ColoredPeg; 4], + count_by_color: HashMap, + } + impl CodeMaker { + fn new(code: [ColoredPeg; 4]) -> Self { + let mut count_by_color = HashMap::with_capacity(4); + for peg in &code { + *count_by_color.entry(*peg).or_insert(0) += 1; + } + Self { + code, + count_by_color, + } + } + fn eval(&self, guess: &[ColoredPeg; 4]) -> GuessAnswer { + let mut right_pos = 0; + let mut wrong_pos = 0; + let mut idx_by_colors = self.count_by_color.clone(); + for (idx, color) in guess.iter().enumerate() { + if self.code[idx] == *color { + right_pos += 1; + let count = idx_by_colors.get_mut(color).unwrap(); + *count -= 1; // don't reuse to say "right color but wrong pos" + if *count == 0 { + idx_by_colors.remove(color); + } + } + } + for (idx, color) in guess.iter().enumerate() { + if self.code[idx] != *color { + // try to use another color + if let Some(count) = idx_by_colors.get_mut(color) { + *count -= 1; + if *count == 0 { + idx_by_colors.remove(color); + } + wrong_pos += 1; + } + } + } + GuessAnswer { + right_pos, + wrong_pos, + } + } + } + + #[derive(Clone)] + struct CodeBreaker { + maker: CodeMaker, // so that we can ask the code maker if our guess is good or not + guess: [ColoredPeg; 4], + } + impl Debug for CodeBreaker { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("{:?}", self.guess).as_str()) + } + } + fn random_color(rng: &mut ThreadRng) -> ColoredPeg { + match rng.random_range(0..=5) { + 0 => ColoredPeg::Red, + 1 => ColoredPeg::Yellow, + 2 => ColoredPeg::Green, + 3 => ColoredPeg::Blue, + 4 => ColoredPeg::White, + _ => ColoredPeg::Black, + } + } + fn random_guess(rng: &mut ThreadRng) -> [ColoredPeg; 4] { + std::array::from_fn(|_| random_color(rng)) + } + impl Chromosome for CodeBreaker { + fn mutate(&mut self, rng: &mut ThreadRng) { + // change one random color + let idx = rng.random_range(0..4); + self.guess[idx] = random_color(rng); + } + + fn crossover(&self, other: &Self, rng: &mut ThreadRng) -> Self { + Self { + maker: self.maker.clone(), + guess: std::array::from_fn(|i| { + if rng.random::() < 0.5 { + self.guess[i] + } else { + other.guess[i] + } + }), + } + } + + fn fitness(&self) -> i32 { + // Ask the code maker for the result + let answer = self.maker.eval(&self.guess); + // Remember: we need to have fitness return 0 if the guess is good, and the higher number we return, the further we are from a proper solution + let mut res = 32; // worst case scenario, everything is wrong + res -= answer.right_pos * 8; // count 8 points for the right item at the right spot + res -= answer.wrong_pos; // count 1 point for having a right color + res + } + } + let code = [ + ColoredPeg::Red, + ColoredPeg::Red, + ColoredPeg::White, + ColoredPeg::Blue, + ]; + let maker = CodeMaker::new(code); + let population_count = 10; + let params = GenericAlgorithmParams { + max_generations: 100, + mutation_chance: 0.5, + crossover_chance: 0.3, + }; + let mut rng = rng(); + let mut initial_pop = Vec::with_capacity(population_count); + for _ in 0..population_count { + initial_pop.push(CodeBreaker { + maker: maker.clone(), + guess: random_guess(&mut rng), + }); + } + let selection = RouletteWheel { rng: rng.clone() }; + let mut solver = + GeneticAlgorithm::init(rng, initial_pop, 0, params, Box::new(i32::cmp), selection); + let res = solver.solve(); + assert!(res.is_some()); + assert_eq!(code, res.unwrap().guess); + } +} diff --git a/src/general/huffman_encoding.rs b/src/general/huffman_encoding.rs index 5332592446a..fc26d3cb5ee 100644 --- a/src/general/huffman_encoding.rs +++ b/src/general/huffman_encoding.rs @@ -28,7 +28,7 @@ impl PartialEq for HuffmanNode { impl PartialOrd for HuffmanNode { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.frequency.cmp(&other.frequency).reverse()) + Some(self.cmp(other)) } } @@ -43,7 +43,6 @@ impl Ord for HuffmanNode { impl HuffmanNode { /// Turn the tree into the map that can be used in encoding pub fn get_alphabet( - &self, height: u32, path: u64, node: &HuffmanNode, @@ -60,8 +59,8 @@ impl HuffmanNode { ); } None => { - self.get_alphabet(height + 1, path, node.left.as_ref().unwrap(), map); - self.get_alphabet( + Self::get_alphabet(height + 1, path, node.left.as_ref().unwrap(), map); + Self::get_alphabet( height + 1, path | (1 << height), node.right.as_ref().unwrap(), @@ -103,7 +102,7 @@ impl HuffmanDictionary { }); } let root = queue.pop().unwrap(); - root.get_alphabet(0, 0, &root, &mut alph); + HuffmanNode::get_alphabet(0, 0, &root, &mut alph); HuffmanDictionary { alphabet: alph, root, @@ -112,7 +111,7 @@ impl HuffmanDictionary { pub fn encode(&self, data: &[T]) -> HuffmanEncoding { let mut result = HuffmanEncoding::new(); data.iter() - .for_each(|value| result.add_data(*self.alphabet.get(value).unwrap())); + .for_each(|value| result.add_data(self.alphabet[value])); result } } @@ -156,9 +155,10 @@ impl HuffmanEncoding { result.push(state.symbol.unwrap()); state = &dict.root; } - match self.get_bit(i) { - false => state = state.left.as_ref().unwrap(), - true => state = state.right.as_ref().unwrap(), + state = if self.get_bit(i) { + state.right.as_ref().unwrap() + } else { + state.left.as_ref().unwrap() } } if self.num_bits > 0 { diff --git a/src/general/kadane_algorithm.rs b/src/general/kadane_algorithm.rs new file mode 100644 index 00000000000..a7452b62d23 --- /dev/null +++ b/src/general/kadane_algorithm.rs @@ -0,0 +1,86 @@ +/** + * @file + * @brief Find the maximum subarray sum using Kadane's algorithm.(https://en.wikipedia.org/wiki/Maximum_subarray_problem) + * + * @details + * This program provides a function to find the maximum subarray sum in an array of integers + * using Kadane's algorithm. + * + * @param arr A slice of integers representing the array. + * @return The maximum subarray sum. + * + * @author [Gyandeep] (https://github.com/Gyan172004) + * @see Wikipedia - Maximum subarray problem + */ + +/** + * Find the maximum subarray sum using Kadane's algorithm. + * @param arr A slice of integers representing the array. + * @return The maximum subarray sum. + */ +pub fn max_sub_array(nums: Vec) -> i32 { + if nums.is_empty() { + return 0; + } + + let mut max_current = nums[0]; + let mut max_global = nums[0]; + + nums.iter().skip(1).for_each(|&item| { + max_current = std::cmp::max(item, max_current + item); + if max_current > max_global { + max_global = max_current; + } + }); + max_global +} + +#[cfg(test)] +mod tests { + use super::*; + + /** + * Test case for Kadane's algorithm with positive numbers. + */ + #[test] + fn test_kadanes_algorithm_positive() { + let arr = [1, 2, 3, 4, 5]; + assert_eq!(max_sub_array(arr.to_vec()), 15); + } + + /** + * Test case for Kadane's algorithm with negative numbers. + */ + #[test] + fn test_kadanes_algorithm_negative() { + let arr = [-2, -3, -4, -1, -2]; + assert_eq!(max_sub_array(arr.to_vec()), -1); + } + + /** + * Test case for Kadane's algorithm with mixed numbers. + */ + #[test] + fn test_kadanes_algorithm_mixed() { + let arr = [-2, 1, -3, 4, -1, 2, 1, -5, 4]; + assert_eq!(max_sub_array(arr.to_vec()), 6); + } + + /** + * Test case for Kadane's algorithm with an empty array. + */ + #[test] + fn test_kadanes_algorithm_empty() { + let arr: [i32; 0] = []; + assert_eq!(max_sub_array(arr.to_vec()), 0); + } + + /** + * Test case for Kadane's algorithm with a single positive number. + */ + #[test] + fn test_kadanes_algorithm_single_positive() { + let arr = [10]; + assert_eq!(max_sub_array(arr.to_vec()), 10); + } +} diff --git a/src/general/kmeans.rs b/src/general/kmeans.rs index 1d9162ae35e..54022c36dd7 100644 --- a/src/general/kmeans.rs +++ b/src/general/kmeans.rs @@ -4,7 +4,6 @@ macro_rules! impl_kmeans { ($kind: ty, $modname: ident) => { // Since we can't overload methods in rust, we have to use namespacing pub mod $modname { - use std::$modname::INFINITY; /// computes sum of squared deviation between two identically sized vectors /// `x`, and `y`. @@ -22,7 +21,7 @@ macro_rules! impl_kmeans { // Find the argmin by folding using a tuple containing the argmin // and the minimum distance. let (argmin, _) = centroids.iter().enumerate().fold( - (0_usize, INFINITY), + (0_usize, <$kind>::INFINITY), |(min_ix, min_dist), (ix, ci)| { let dist = distance(xi, ci); if dist < min_dist { @@ -89,9 +88,8 @@ macro_rules! impl_kmeans { { // We need to use `return` to break out of the `loop` return clustering; - } else { - clustering = new_clustering; } + clustering = new_clustering; } } } diff --git a/src/general/mex.rs b/src/general/mex.rs new file mode 100644 index 00000000000..a0514a35c54 --- /dev/null +++ b/src/general/mex.rs @@ -0,0 +1,79 @@ +use std::collections::BTreeSet; + +// Find minimum excluded number from a set of given numbers using a set +/// Finds the MEX of the values provided in `arr` +/// Uses [`BTreeSet`](std::collections::BTreeSet) +/// O(nlog(n)) implementation +pub fn mex_using_set(arr: &[i64]) -> i64 { + let mut s: BTreeSet = BTreeSet::new(); + for i in 0..=arr.len() { + s.insert(i as i64); + } + for x in arr { + s.remove(x); + } + // TODO: change the next 10 lines to *s.first().unwrap() when merged into stable + // set should never have 0 elements + if let Some(x) = s.into_iter().next() { + x + } else { + panic!("Some unknown error in mex_using_set") + } +} +/// Finds the MEX of the values provided in `arr` +/// Uses sorting +/// O(nlog(n)) implementation +pub fn mex_using_sort(arr: &[i64]) -> i64 { + let mut arr = arr.to_vec(); + arr.sort(); + let mut mex = 0; + for x in arr { + if x == mex { + mex += 1; + } + } + mex +} + +#[cfg(test)] +mod tests { + use super::*; + struct MexTests { + test_arrays: Vec>, + outputs: Vec, + } + impl MexTests { + fn new() -> Self { + Self { + test_arrays: vec![ + vec![-1, 0, 1, 2, 3], + vec![-100, 0, 1, 2, 3, 5], + vec![-1000000, 0, 1, 2, 5], + vec![2, 0, 1, 2, 4], + vec![1, 2, 3, 0, 4], + vec![0, 1, 5, 2, 4, 3], + vec![0, 1, 2, 3, 4, 5, 6], + vec![0, 1, 2, 3, 4, 5, 6, 7], + vec![0, 1, 2, 3, 4, 5, 6, 7, 8], + ], + outputs: vec![4, 4, 3, 3, 5, 6, 7, 8, 9], + } + } + fn test_function(&self, f: fn(&[i64]) -> i64) { + for (nums, output) in self.test_arrays.iter().zip(self.outputs.iter()) { + assert_eq!(f(nums), *output); + } + } + } + #[test] + fn test_mex_using_set() { + let tests = MexTests::new(); + mex_using_set(&[1, 23, 3]); + tests.test_function(mex_using_set); + } + #[test] + fn test_mex_using_sort() { + let tests = MexTests::new(); + tests.test_function(mex_using_sort); + } +} diff --git a/src/general/mod.rs b/src/general/mod.rs index eded2921566..3572b146f4a 100644 --- a/src/general/mod.rs +++ b/src/general/mod.rs @@ -1,14 +1,25 @@ mod convex_hull; +mod fisher_yates_shuffle; +mod genetic; mod hanoi; mod huffman_encoding; +mod kadane_algorithm; mod kmeans; -mod nqueens; +mod mex; +mod permutations; mod two_sum; pub use self::convex_hull::convex_hull_graham; +pub use self::fisher_yates_shuffle::fisher_yates_shuffle; +pub use self::genetic::GeneticAlgorithm; pub use self::hanoi::hanoi; pub use self::huffman_encoding::{HuffmanDictionary, HuffmanEncoding}; +pub use self::kadane_algorithm::max_sub_array; pub use self::kmeans::f32::kmeans as kmeans_f32; pub use self::kmeans::f64::kmeans as kmeans_f64; -pub use self::nqueens::nqueens; +pub use self::mex::mex_using_set; +pub use self::mex::mex_using_sort; +pub use self::permutations::{ + heap_permute, permute, permute_unique, steinhaus_johnson_trotter_permute, +}; pub use self::two_sum::two_sum; diff --git a/src/general/nqueens.rs b/src/general/nqueens.rs deleted file mode 100644 index 6eb91756fcc..00000000000 --- a/src/general/nqueens.rs +++ /dev/null @@ -1,148 +0,0 @@ -#[allow(unused_imports)] -use std::env::args; - -#[allow(dead_code)] -fn main() { - let mut board_width = 0; - - for arg in args() { - board_width = match arg.parse() { - Ok(x) => x, - _ => 0, - }; - - if board_width != 0 { - break; - } - } - - if board_width < 4 { - println!( - "Running algorithm with 8 as a default. Specify an alternative Chess board size for \ - N-Queens as a command line argument.\n" - ); - board_width = 8; - } - - let board = match nqueens(board_width) { - Ok(success) => success, - Err(err) => panic!("{}", err), - }; - - println!("N-Queens {} by {} board result:", board_width, board_width); - print_board(&board); -} - -/* -The n-Queens search is a backtracking algorithm. Each row of the Chess board where a Queen is -placed is dependent on all earlier rows. As only one Queen can fit per row, a one-dimensional -integer array is used to represent the Queen's offset on each row. -*/ -pub fn nqueens(board_width: i64) -> Result, &'static str> { - let mut board_rows = vec![0; board_width as usize]; - let mut conflict; - let mut current_row = 0; - - //Process by row up to the current active row - loop { - conflict = false; - - //Column review of previous rows - for review_index in 0..current_row { - //Calculate the diagonals of earlier rows where a Queen would be a conflict - let left = board_rows[review_index] - (current_row as i64 - review_index as i64); - let right = board_rows[review_index] + (current_row as i64 - review_index as i64); - - if board_rows[current_row] == board_rows[review_index] - || (left >= 0 && left == board_rows[current_row]) - || (right < board_width as i64 && right == board_rows[current_row]) - { - conflict = true; - break; - } - } - - match conflict { - true => { - board_rows[current_row] += 1; - - if current_row == 0 && board_rows[current_row] == board_width { - return Err("No solution exists for specificed board size."); - } - - while board_rows[current_row] == board_width { - board_rows[current_row] = 0; - - if current_row == 0 { - return Err("No solution exists for specificed board size."); - } - - current_row -= 1; - board_rows[current_row] += 1; - } - } - _ => { - current_row += 1; - - if current_row as i64 == board_width { - break; - } - } - } - } - - Ok(board_rows) -} - -fn print_board(board: &[i64]) { - for row in 0..board.len() { - print!("{}\t", board[row as usize]); - - for column in 0..board.len() as i64 { - if board[row as usize] == column { - print!("Q"); - } else { - print!("."); - } - } - println!(); - } -} - -#[cfg(test)] -mod test { - use super::*; - - fn check_board(board: &Vec) -> bool { - for current_row in 0..board.len() { - //Column review - for review_index in 0..current_row { - //Look for any conflict. - let left = board[review_index] - (current_row as i64 - review_index as i64); - let right = board[review_index] + (current_row as i64 - review_index as i64); - - if board[current_row] == board[review_index] - || (left >= 0 && left == board[current_row]) - || (right < board.len() as i64 && right == board[current_row]) - { - return false; - } - } - } - true - } - - #[test] - fn test_board_size_4() { - let board = nqueens(4).expect("Error propagated."); - assert_eq!(board, vec![1, 3, 0, 2]); - assert!(check_board(&board)); - } - - #[test] - fn test_board_size_7() { - let board = nqueens(7).expect("Error propagated."); - assert_eq!(board, vec![0, 2, 4, 6, 1, 3, 5]); - assert!(check_board(&board)); - } -} diff --git a/src/general/permutations/heap.rs b/src/general/permutations/heap.rs new file mode 100644 index 00000000000..b1c3b38d198 --- /dev/null +++ b/src/general/permutations/heap.rs @@ -0,0 +1,66 @@ +use std::fmt::Debug; + +/// Computes all permutations of an array using Heap's algorithm +/// Read `recurse_naive` first, since we're building on top of the same intuition +pub fn heap_permute(arr: &[T]) -> Vec> { + if arr.is_empty() { + return vec![vec![]]; + } + let n = arr.len(); + let mut collector = Vec::with_capacity((1..=n).product()); // collects the permuted arrays + let mut arr = arr.to_owned(); // Heap's algorithm needs to mutate the array + heap_recurse(&mut arr, n, &mut collector); + collector +} + +fn heap_recurse(arr: &mut [T], k: usize, collector: &mut Vec>) { + if k == 1 { + // same base-case as in the naive version + collector.push((*arr).to_owned()); + return; + } + // Remember the naive recursion. We did the following: swap(i, last), recurse, swap back(i, last) + // Heap's algorithm has a more clever way of permuting the elements so that we never need to swap back! + for i in 0..k { + // now deal with [a, b] + let swap_idx = if k % 2 == 0 { i } else { 0 }; + arr.swap(swap_idx, k - 1); + heap_recurse(arr, k - 1, collector); + } +} + +#[cfg(test)] +mod tests { + use quickcheck_macros::quickcheck; + + use crate::general::permutations::heap_permute; + use crate::general::permutations::tests::{ + assert_permutations, assert_valid_permutation, NotTooBigVec, + }; + + #[test] + fn test_3_different_values() { + let original = vec![1, 2, 3]; + let res = heap_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[test] + fn test_3_times_the_same_value() { + let original = vec![1, 1, 1]; + let res = heap_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[quickcheck] + fn test_some_elements(NotTooBigVec { inner: original }: NotTooBigVec) { + let permutations = heap_permute(&original); + assert_permutations(&original, &permutations) + } +} diff --git a/src/general/permutations/mod.rs b/src/general/permutations/mod.rs new file mode 100644 index 00000000000..0893eacb8ec --- /dev/null +++ b/src/general/permutations/mod.rs @@ -0,0 +1,93 @@ +mod heap; +mod naive; +mod steinhaus_johnson_trotter; + +pub use self::heap::heap_permute; +pub use self::naive::{permute, permute_unique}; +pub use self::steinhaus_johnson_trotter::steinhaus_johnson_trotter_permute; + +#[cfg(test)] +mod tests { + use quickcheck::{Arbitrary, Gen}; + use std::collections::HashMap; + + pub fn assert_permutations(original: &[i32], permutations: &[Vec]) { + if original.is_empty() { + assert_eq!(vec![vec![] as Vec], permutations); + return; + } + let n = original.len(); + assert_eq!((1..=n).product::(), permutations.len()); // n! + for permut in permutations { + assert_valid_permutation(original, permut); + } + } + + pub fn assert_valid_permutation(original: &[i32], permuted: &[i32]) { + assert_eq!(original.len(), permuted.len()); + let mut indices = HashMap::with_capacity(original.len()); + for value in original { + *indices.entry(*value).or_insert(0) += 1; + } + for permut_value in permuted { + let count = indices.get_mut(permut_value).unwrap_or_else(|| { + panic!("Value {permut_value} appears too many times in permutation") + }); + *count -= 1; // use this value + if *count == 0 { + indices.remove(permut_value); // so that we can simply check every value has been removed properly + } + } + assert!(indices.is_empty()) + } + + #[test] + fn test_valid_permutations() { + assert_valid_permutation(&[1, 2, 3], &[1, 2, 3]); + assert_valid_permutation(&[1, 2, 3], &[1, 3, 2]); + assert_valid_permutation(&[1, 2, 3], &[2, 1, 3]); + assert_valid_permutation(&[1, 2, 3], &[2, 3, 1]); + assert_valid_permutation(&[1, 2, 3], &[3, 1, 2]); + assert_valid_permutation(&[1, 2, 3], &[3, 2, 1]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_1() { + assert_valid_permutation(&[1, 2, 3], &[4, 2, 3]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_2() { + assert_valid_permutation(&[1, 2, 3], &[1, 4, 3]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_3() { + assert_valid_permutation(&[1, 2, 3], &[1, 2, 4]); + } + + #[test] + #[should_panic] + fn test_invalid_permutation_repeat() { + assert_valid_permutation(&[1, 2, 3], &[1, 2, 2]); + } + + /// A Data Structure for testing permutations + /// Holds a Vec with just a few items, so that it's not too long to compute permutations + #[derive(Debug, Clone)] + pub struct NotTooBigVec { + pub(crate) inner: Vec, // opaque type alias so that we can implement Arbitrary + } + + const MAX_SIZE: usize = 8; // 8! ~= 40k permutations already + impl Arbitrary for NotTooBigVec { + fn arbitrary(g: &mut Gen) -> Self { + let size = usize::arbitrary(g) % MAX_SIZE; + let res = (0..size).map(|_| i32::arbitrary(g)).collect(); + NotTooBigVec { inner: res } + } + } +} diff --git a/src/general/permutations/naive.rs b/src/general/permutations/naive.rs new file mode 100644 index 00000000000..748d763555e --- /dev/null +++ b/src/general/permutations/naive.rs @@ -0,0 +1,135 @@ +use std::collections::HashSet; +use std::fmt::Debug; +use std::hash::Hash; + +/// Here's a basic (naive) implementation for generating permutations +pub fn permute(arr: &[T]) -> Vec> { + if arr.is_empty() { + return vec![vec![]]; + } + let n = arr.len(); + let count = (1..=n).product(); // n! permutations + let mut collector = Vec::with_capacity(count); // collects the permuted arrays + let mut arr = arr.to_owned(); // we'll need to mutate the array + + // the idea is the following: imagine [a, b, c] + // always swap an item with the last item, then generate all permutations from the first k characters + // permute_recurse(arr, k - 1, collector); // leave the last character alone, and permute the first k-1 characters + permute_recurse(&mut arr, n, &mut collector); + collector +} + +fn permute_recurse(arr: &mut Vec, k: usize, collector: &mut Vec>) { + if k == 1 { + collector.push(arr.to_owned()); + return; + } + for i in 0..k { + arr.swap(i, k - 1); // swap i with the last character + permute_recurse(arr, k - 1, collector); // collect the permutations of the rest + arr.swap(i, k - 1); // swap back to original + } +} + +/// A common variation of generating permutations is to generate only unique permutations +/// Of course, we could use the version above together with a Set as collector instead of a Vec. +/// But let's try something different: how can we avoid to generate duplicated permutations in the first place, can we tweak the algorithm above? +pub fn permute_unique(arr: &[T]) -> Vec> { + if arr.is_empty() { + return vec![vec![]]; + } + let n = arr.len(); + let count = (1..=n).product(); // n! permutations + let mut collector = Vec::with_capacity(count); // collects the permuted arrays + let mut arr = arr.to_owned(); // Heap's algorithm needs to mutate the array + permute_recurse_unique(&mut arr, n, &mut collector); + collector +} + +fn permute_recurse_unique( + arr: &mut Vec, + k: usize, + collector: &mut Vec>, +) { + // We have the same base-case as previously, whenever we reach the first element in the array, collect the result + if k == 1 { + collector.push(arr.to_owned()); + return; + } + // We'll keep the same idea (swap with last item, and generate all permutations for the first k - 1) + // But we'll have to be careful though: how would we generate duplicates? + // Basically if, when swapping i with k-1, we generate the exact same array as in a previous iteration + // Imagine [a, a, b] + // i = 0: + // Swap (a, b) => [b, a, a], fix 'a' as last, and generate all permutations of [b, a] => [b, a, a], [a, b, a] + // Swap Back to [a, a, b] + // i = 1: + // Swap(a, b) => [b, a, a], we've done that already!! + let mut swapped = HashSet::with_capacity(k); + for i in 0..k { + if swapped.contains(&arr[i]) { + continue; + } + swapped.insert(arr[i]); + arr.swap(i, k - 1); // swap i with the last character + permute_recurse_unique(arr, k - 1, collector); // collect the permutations + arr.swap(i, k - 1); // go back to original + } +} + +#[cfg(test)] +mod tests { + use crate::general::permutations::naive::{permute, permute_unique}; + use crate::general::permutations::tests::{ + assert_permutations, assert_valid_permutation, NotTooBigVec, + }; + use quickcheck_macros::quickcheck; + use std::collections::HashSet; + + #[test] + fn test_3_different_values() { + let original = vec![1, 2, 3]; + let res = permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[test] + fn empty_array() { + let empty: std::vec::Vec = vec![]; + assert_eq!(permute(&empty), vec![vec![]]); + assert_eq!(permute_unique(&empty), vec![vec![]]); + } + + #[test] + fn test_3_times_the_same_value() { + let original = vec![1, 1, 1]; + let res = permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[quickcheck] + fn test_some_elements(NotTooBigVec { inner: original }: NotTooBigVec) { + let permutations = permute(&original); + assert_permutations(&original, &permutations) + } + + #[test] + fn test_unique_values() { + let original = vec![1, 1, 2, 2]; + let unique_permutations = permute_unique(&original); + let every_permutation = permute(&original); + for unique_permutation in &unique_permutations { + assert!(every_permutation.contains(unique_permutation)); + } + assert_eq!( + unique_permutations.len(), + every_permutation.iter().collect::>().len() + ) + } +} diff --git a/src/general/permutations/steinhaus_johnson_trotter.rs b/src/general/permutations/steinhaus_johnson_trotter.rs new file mode 100644 index 00000000000..4784a927ccc --- /dev/null +++ b/src/general/permutations/steinhaus_johnson_trotter.rs @@ -0,0 +1,61 @@ +/// +pub fn steinhaus_johnson_trotter_permute(array: &[T]) -> Vec> { + let len = array.len(); + let mut array = array.to_owned(); + let mut inversion_vector = vec![0; len]; + let mut i = 1; + let mut res = Vec::with_capacity((1..=len).product()); + res.push(array.clone()); + while i < len { + if inversion_vector[i] < i { + if i % 2 == 0 { + array.swap(0, i); + } else { + array.swap(inversion_vector[i], i); + } + res.push(array.to_vec()); + inversion_vector[i] += 1; + i = 1; + } else { + inversion_vector[i] = 0; + i += 1; + } + } + res +} + +#[cfg(test)] +mod tests { + use quickcheck_macros::quickcheck; + + use crate::general::permutations::steinhaus_johnson_trotter::steinhaus_johnson_trotter_permute; + use crate::general::permutations::tests::{ + assert_permutations, assert_valid_permutation, NotTooBigVec, + }; + + #[test] + fn test_3_different_values() { + let original = vec![1, 2, 3]; + let res = steinhaus_johnson_trotter_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[test] + fn test_3_times_the_same_value() { + let original = vec![1, 1, 1]; + let res = steinhaus_johnson_trotter_permute(&original); + assert_eq!(res.len(), 6); // 3! + for permut in res { + assert_valid_permutation(&original, &permut) + } + } + + #[quickcheck] + fn test_some_elements(NotTooBigVec { inner: original }: NotTooBigVec) { + let permutations = steinhaus_johnson_trotter_permute(&original); + assert_permutations(&original, &permutations) + } +} diff --git a/src/geometry/closest_points.rs b/src/geometry/closest_points.rs index 9cbe1f55d2c..e92dc562501 100644 --- a/src/geometry/closest_points.rs +++ b/src/geometry/closest_points.rs @@ -1,10 +1,10 @@ -type Point = (f64, f64); +use crate::geometry::Point; use std::cmp::Ordering; -fn point_cmp((a1, a2): &Point, (b1, b2): &Point) -> Ordering { - let acmp = f64_cmp(a1, b1); +fn cmp_x(p1: &Point, p2: &Point) -> Ordering { + let acmp = f64_cmp(&p1.x, &p2.x); match acmp { - Ordering::Equal => f64_cmp(a2, b2), + Ordering::Equal => f64_cmp(&p1.y, &p2.y), _ => acmp, } } @@ -16,21 +16,19 @@ fn f64_cmp(a: &f64, b: &f64) -> Ordering { /// returns the two closest points /// or None if there are zero or one point pub fn closest_points(points: &[Point]) -> Option<(Point, Point)> { - let mut points: Vec = points.to_vec(); - points.sort_by(point_cmp); + let mut points_x: Vec = points.to_vec(); + points_x.sort_by(cmp_x); + let mut points_y = points_x.clone(); + points_y.sort_by(|p1: &Point, p2: &Point| -> Ordering { p1.y.partial_cmp(&p2.y).unwrap() }); - closest_points_aux(&points, 0, points.len()) -} - -fn sqr_dist((x1, y1): &Point, (x2, y2): &Point) -> f64 { - let dx = *x1 - *x2; - let dy = *y1 - *y2; - - dx * dx + dy * dy + closest_points_aux(&points_x, points_y, 0, points_x.len()) } +// We maintain two vectors with the same points, one sort by x coordinates and one sorted by y +// coordinates. fn closest_points_aux( - points: &[Point], + points_x: &[Point], + points_y: Vec, mut start: usize, mut end: usize, ) -> Option<(Point, Point)> { @@ -42,62 +40,71 @@ fn closest_points_aux( if n <= 3 { // bruteforce - let mut min = sqr_dist(&points[0], &points[1]); - let mut pair = (points[0], points[1]); + let mut min = points_x[0].euclidean_distance(&points_x[1]); + let mut pair = (points_x[0].clone(), points_x[1].clone()); for i in 1..n { for j in (i + 1)..n { - let new = sqr_dist(&points[i], &points[j]); + let new = points_x[i].euclidean_distance(&points_x[j]); if new < min { min = new; - pair = (points[i], points[j]); + pair = (points_x[i].clone(), points_x[j].clone()); } } } return Some(pair); } - let mid = (start + end) / 2; - let left = closest_points_aux(points, start, mid); - let right = closest_points_aux(points, mid, end); + let mid = start + (end - start) / 2; + let mid_x = points_x[mid].x; + + // Separate points into y_left and y_right vectors based on their x-coordinate. Since y is + // already sorted by y_axis, y_left and y_right will also be sorted. + let mut y_left = vec![]; + let mut y_right = vec![]; + for point in &points_y { + if point.x < mid_x { + y_left.push(point.clone()); + } else { + y_right.push(point.clone()); + } + } + + let left = closest_points_aux(points_x, y_left, start, mid); + let right = closest_points_aux(points_x, y_right, mid, end); let (mut min_sqr_dist, mut pair) = match (left, right) { (Some((l1, l2)), Some((r1, r2))) => { - let dl = sqr_dist(&l1, &l2); - let dr = sqr_dist(&r1, &r2); + let dl = l1.euclidean_distance(&l2); + let dr = r1.euclidean_distance(&r2); if dl < dr { (dl, (l1, l2)) } else { (dr, (r1, r2)) } } - (Some((a, b)), None) => (sqr_dist(&a, &b), (a, b)), - (None, Some((a, b))) => (sqr_dist(&a, &b), (a, b)), + (Some((a, b)), None) | (None, Some((a, b))) => (a.euclidean_distance(&b), (a, b)), (None, None) => unreachable!(), }; - let mid_x = points[mid].0; - let dist = min_sqr_dist.sqrt(); - while points[start].0 < mid_x - dist { + let dist = min_sqr_dist; + while points_x[start].x < mid_x - dist { start += 1; } - while points[end - 1].0 > mid_x + dist { + while points_x[end - 1].x > mid_x + dist { end -= 1; } - let mut mids: Vec<&Point> = points[start..end].iter().collect(); - mids.sort_by(|a, b| f64_cmp(&a.1, &b.1)); - - for (i, e) in mids.iter().enumerate() { + for (i, e) in points_y.iter().enumerate() { for k in 1..8 { - if i + k >= mids.len() { + if i + k >= points_y.len() { break; } - let new = sqr_dist(e, mids[i + k]); + let new = e.euclidean_distance(&points_y[i + k]); if new < min_sqr_dist { min_sqr_dist = new; - pair = (**e, *mids[i + k]); + pair = ((*e).clone(), points_y[i + k].clone()); } } } @@ -137,86 +144,101 @@ mod tests { #[test] fn one_points() { - let vals = [(0., 0.)]; + let vals = [Point::new(0., 0.)]; assert_display!(closest_points(&vals), None::<(Point, Point)>); } #[test] fn two_points() { - let vals = [(0., 0.), (1., 1.)]; - assert_display!(closest_points(&vals), Some(((0., 0.), (1., 1.)))); + let vals = [Point::new(0., 0.), Point::new(1., 1.)]; + assert_display!( + closest_points(&vals), + Some((vals[0].clone(), vals[1].clone())) + ); } #[test] fn three_points() { - let vals = [(0., 0.), (1., 1.), (3., 3.)]; - assert_display!(closest_points(&vals), Some(((0., 0.), (1., 1.)))); + let vals = [Point::new(0., 0.), Point::new(1., 1.), Point::new(3., 3.)]; + assert_display!( + closest_points(&vals), + Some((vals[0].clone(), vals[1].clone())) + ); } #[test] fn list_1() { let vals = [ - (0., 0.), - (2., 1.), - (5., 2.), - (2., 3.), - (4., 0.), - (0., 4.), - (5., 6.), - (4., 4.), - (7., 3.), - (-1., 2.), - (2., 6.), + Point::new(0., 0.), + Point::new(2., 1.), + Point::new(5., 2.), + Point::new(2., 3.), + Point::new(4., 0.), + Point::new(0., 4.), + Point::new(5., 6.), + Point::new(4., 4.), + Point::new(7., 3.), + Point::new(-1., 2.), + Point::new(2., 6.), ]; - assert_display!(closest_points(&vals), Some(((2., 1.), (2., 3.)))); + assert_display!( + closest_points(&vals), + Some((Point::new(2., 1.), Point::new(2., 3.))) + ); } #[test] fn list_2() { let vals = [ - (1., 3.), - (4., 6.), - (8., 8.), - (7., 5.), - (5., 3.), - (10., 3.), - (7., 1.), - (8., 3.), - (4., 9.), - (4., 12.), - (4., 15.), - (7., 14.), - (8., 12.), - (6., 10.), - (4., 14.), - (2., 7.), - (3., 8.), - (5., 8.), - (6., 7.), - (8., 10.), - (6., 12.), + Point::new(1., 3.), + Point::new(4., 6.), + Point::new(8., 8.), + Point::new(7., 5.), + Point::new(5., 3.), + Point::new(10., 3.), + Point::new(7., 1.), + Point::new(8., 3.), + Point::new(4., 9.), + Point::new(4., 12.), + Point::new(4., 15.), + Point::new(7., 14.), + Point::new(8., 12.), + Point::new(6., 10.), + Point::new(4., 14.), + Point::new(2., 7.), + Point::new(3., 8.), + Point::new(5., 8.), + Point::new(6., 7.), + Point::new(8., 10.), + Point::new(6., 12.), ]; - assert_display!(closest_points(&vals), Some(((4., 14.), (4., 15.)))); + assert_display!( + closest_points(&vals), + Some((Point::new(4., 14.), Point::new(4., 15.))) + ); } #[test] fn vertical_points() { let vals = [ - (0., 0.), - (0., 50.), - (0., -25.), - (0., 40.), - (0., 42.), - (0., 100.), - (0., 17.), - (0., 29.), - (0., -50.), - (0., 37.), - (0., 34.), - (0., 8.), - (0., 3.), - (0., 46.), + Point::new(0., 0.), + Point::new(0., 50.), + Point::new(0., -25.), + Point::new(0., 40.), + Point::new(0., 42.), + Point::new(0., 100.), + Point::new(0., 17.), + Point::new(0., 29.), + Point::new(0., -50.), + Point::new(0., 37.), + Point::new(0., 34.), + Point::new(0., 8.), + Point::new(0., 3.), + Point::new(0., 46.), ]; - assert_display!(closest_points(&vals), Some(((0., 40.), (0., 42.)))); + assert_display!( + closest_points(&vals), + Some((Point::new(0., 40.), Point::new(0., 42.))) + ); } } diff --git a/src/geometry/graham_scan.rs b/src/geometry/graham_scan.rs new file mode 100644 index 00000000000..cdd512e7510 --- /dev/null +++ b/src/geometry/graham_scan.rs @@ -0,0 +1,206 @@ +use crate::geometry::Point; +use std::cmp::Ordering; + +fn point_min(a: &&Point, b: &&Point) -> Ordering { + // Find the bottom-most point. In the case of a tie, find the left-most. + if a.y == b.y { + a.x.partial_cmp(&b.x).unwrap() + } else { + a.y.partial_cmp(&b.y).unwrap() + } +} + +// Returns a Vec of Points that make up the convex hull of `points`. Returns an empty Vec if there +// is no convex hull. +pub fn graham_scan(mut points: Vec) -> Vec { + if points.len() <= 2 { + return vec![]; + } + + let min_point = points.iter().min_by(point_min).unwrap().clone(); + points.retain(|p| p != &min_point); + if points.is_empty() { + // edge case where all the points are the same + return vec![]; + } + + let point_cmp = |a: &Point, b: &Point| -> Ordering { + // Sort points in counter-clockwise direction relative to the min point. We can this by + // checking the orientation of consecutive vectors (min_point, a) and (a, b). + let orientation = min_point.consecutive_orientation(a, b); + if orientation < 0.0 { + Ordering::Greater + } else if orientation > 0.0 { + Ordering::Less + } else { + let a_dist = min_point.euclidean_distance(a); + let b_dist = min_point.euclidean_distance(b); + // When two points have the same relative angle to the min point, we should only + // include the further point in the convex hull. We sort further points into a lower + // index, and in the algorithm, remove all consecutive points with the same relative + // angle. + b_dist.partial_cmp(&a_dist).unwrap() + } + }; + points.sort_by(point_cmp); + let mut convex_hull: Vec = vec![]; + + // We always add the min_point, and the first two points in the sorted vec. + convex_hull.push(min_point.clone()); + convex_hull.push(points[0].clone()); + let mut top = 1; + for point in points.iter().skip(1) { + if min_point.consecutive_orientation(point, &convex_hull[top]) == 0.0 { + // Remove consecutive points with the same angle. We make sure include the furthest + // point in the convex hull in the sort comparator. + continue; + } + loop { + // In this loop, we remove points that we determine are no longer part of the convex + // hull. + if top <= 1 { + break; + } + // If there is a segment(i+1, i+2) turns right relative to segment(i, i+1), point(i+1) + // is not part of the convex hull. + let orientation = + convex_hull[top - 1].consecutive_orientation(&convex_hull[top], point); + if orientation <= 0.0 { + top -= 1; + convex_hull.pop(); + } else { + break; + } + } + convex_hull.push(point.clone()); + top += 1; + } + if convex_hull.len() <= 2 { + return vec![]; + } + convex_hull +} + +#[cfg(test)] +mod tests { + use super::graham_scan; + use super::Point; + + fn test_graham(convex_hull: Vec, others: Vec) { + let mut points = convex_hull.clone(); + points.append(&mut others.clone()); + let graham = graham_scan(points); + for point in convex_hull { + assert!(graham.contains(&point)); + } + for point in others { + assert!(!graham.contains(&point)); + } + } + + #[test] + fn too_few_points() { + test_graham(vec![], vec![]); + test_graham(vec![], vec![Point::new(0.0, 0.0)]); + } + + #[test] + fn duplicate_point() { + let p = Point::new(0.0, 0.0); + test_graham(vec![], vec![p.clone(), p.clone(), p.clone(), p.clone(), p]); + } + + #[test] + fn points_same_line() { + let p1 = Point::new(1.0, 0.0); + let p2 = Point::new(2.0, 0.0); + let p3 = Point::new(3.0, 0.0); + let p4 = Point::new(4.0, 0.0); + let p5 = Point::new(5.0, 0.0); + // let p6 = Point::new(1.0, 1.0); + test_graham(vec![], vec![p1, p2, p3, p4, p5]); + } + + #[test] + fn triangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let points = vec![p1, p2, p3]; + test_graham(points, vec![]); + } + + #[test] + fn rectangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let points = vec![p1, p2, p3, p4]; + test_graham(points, vec![]); + } + + #[test] + fn triangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let p4 = Point::new(1.5, 1.5); + let p5 = Point::new(1.2, 1.3); + let p6 = Point::new(1.8, 1.2); + let p7 = Point::new(1.5, 1.9); + let hull = vec![p1, p2, p3]; + let others = vec![p4, p5, p6, p7]; + test_graham(hull, others); + } + + #[test] + fn rectangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.5); + let p6 = Point::new(1.2, 1.3); + let p7 = Point::new(1.8, 1.2); + let p8 = Point::new(1.9, 1.7); + let p9 = Point::new(1.4, 1.9); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8, p9]; + test_graham(hull, others); + } + + #[test] + fn star() { + // A single stroke star shape (kind of). Only the tips(p1-5) are part of the convex hull. The + // other points would create angles >180 degrees if they were part of the polygon. + let p1 = Point::new(-5.0, 6.0); + let p2 = Point::new(-11.0, 0.0); + let p3 = Point::new(-9.0, -8.0); + let p4 = Point::new(4.0, 4.0); + let p5 = Point::new(6.0, -7.0); + let p6 = Point::new(-7.0, -2.0); + let p7 = Point::new(-2.0, -4.0); + let p8 = Point::new(0.0, 1.0); + let p9 = Point::new(1.0, 0.0); + let p10 = Point::new(-6.0, 1.0); + let hull = vec![p1, p2, p3, p4, p5]; + let others = vec![p6, p7, p8, p9, p10]; + test_graham(hull, others); + } + + #[test] + fn rectangle_with_points_on_same_line() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.0); + let p6 = Point::new(1.0, 1.5); + let p7 = Point::new(2.0, 1.5); + let p8 = Point::new(1.5, 2.0); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8]; + test_graham(hull, others); + } +} diff --git a/src/geometry/jarvis_scan.rs b/src/geometry/jarvis_scan.rs new file mode 100644 index 00000000000..c8999c910ae --- /dev/null +++ b/src/geometry/jarvis_scan.rs @@ -0,0 +1,193 @@ +use crate::geometry::Point; +use crate::geometry::Segment; + +// Returns a Vec of Points that make up the convex hull of `points`. Returns an empty Vec if there +// is no convex hull. +pub fn jarvis_march(points: Vec) -> Vec { + if points.len() <= 2 { + return vec![]; + } + + let mut convex_hull = vec![]; + let mut left_point = 0; + for i in 1..points.len() { + // Find the initial point, which is the leftmost point. In the case of a tie, we take the + // bottom-most point. This helps prevent adding colinear points on the last segment to the hull. + if points[i].x < points[left_point].x + || (points[i].x == points[left_point].x && points[i].y < points[left_point].y) + { + left_point = i; + } + } + convex_hull.push(points[left_point].clone()); + + let mut p = left_point; + loop { + // Find the next counter-clockwise point. + let mut next_p = (p + 1) % points.len(); + for i in 0..points.len() { + let orientation = points[p].consecutive_orientation(&points[i], &points[next_p]); + if orientation > 0.0 { + next_p = i; + } + } + + if next_p == left_point { + // Completed constructing the hull. Exit the loop. + break; + } + p = next_p; + + let last = convex_hull.len() - 1; + if convex_hull.len() > 1 + && Segment::from_points(points[p].clone(), convex_hull[last - 1].clone()) + .on_segment(&convex_hull[last]) + { + // If the last point lies on the segment with the new point and the second to last + // point, we can remove the last point from the hull. + convex_hull[last] = points[p].clone(); + } else { + convex_hull.push(points[p].clone()); + } + } + + if convex_hull.len() <= 2 { + return vec![]; + } + let last = convex_hull.len() - 1; + if Segment::from_points(convex_hull[0].clone(), convex_hull[last - 1].clone()) + .on_segment(&convex_hull[last]) + { + // Check for the edge case where the last point lies on the segment with the zero'th and + // second the last point. In this case, we remove the last point from the hull. + convex_hull.pop(); + if convex_hull.len() == 2 { + return vec![]; + } + } + convex_hull +} + +#[cfg(test)] +mod tests { + use super::jarvis_march; + use super::Point; + + fn test_jarvis(convex_hull: Vec, others: Vec) { + let mut points = others.clone(); + points.append(&mut convex_hull.clone()); + let jarvis = jarvis_march(points); + for point in convex_hull { + assert!(jarvis.contains(&point)); + } + for point in others { + assert!(!jarvis.contains(&point)); + } + } + + #[test] + fn too_few_points() { + test_jarvis(vec![], vec![]); + test_jarvis(vec![], vec![Point::new(0.0, 0.0)]); + } + + #[test] + fn duplicate_point() { + let p = Point::new(0.0, 0.0); + test_jarvis(vec![], vec![p.clone(), p.clone(), p.clone(), p.clone(), p]); + } + + #[test] + fn points_same_line() { + let p1 = Point::new(1.0, 0.0); + let p2 = Point::new(2.0, 0.0); + let p3 = Point::new(3.0, 0.0); + let p4 = Point::new(4.0, 0.0); + let p5 = Point::new(5.0, 0.0); + // let p6 = Point::new(1.0, 1.0); + test_jarvis(vec![], vec![p1, p2, p3, p4, p5]); + } + + #[test] + fn triangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let points = vec![p1, p2, p3]; + test_jarvis(points, vec![]); + } + + #[test] + fn rectangle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let points = vec![p1, p2, p3, p4]; + test_jarvis(points, vec![]); + } + + #[test] + fn triangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(1.5, 2.0); + let p4 = Point::new(1.5, 1.5); + let p5 = Point::new(1.2, 1.3); + let p6 = Point::new(1.8, 1.2); + let p7 = Point::new(1.5, 1.9); + let hull = vec![p1, p2, p3]; + let others = vec![p4, p5, p6, p7]; + test_jarvis(hull, others); + } + + #[test] + fn rectangle_with_points_in_middle() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.5); + let p6 = Point::new(1.2, 1.3); + let p7 = Point::new(1.8, 1.2); + let p8 = Point::new(1.9, 1.7); + let p9 = Point::new(1.4, 1.9); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8, p9]; + test_jarvis(hull, others); + } + + #[test] + fn star() { + // A single stroke star shape (kind of). Only the tips(p1-5) are part of the convex hull. The + // other points would create angles >180 degrees if they were part of the polygon. + let p1 = Point::new(-5.0, 6.0); + let p2 = Point::new(-11.0, 0.0); + let p3 = Point::new(-9.0, -8.0); + let p4 = Point::new(4.0, 4.0); + let p5 = Point::new(6.0, -7.0); + let p6 = Point::new(-7.0, -2.0); + let p7 = Point::new(-2.0, -4.0); + let p8 = Point::new(0.0, 1.0); + let p9 = Point::new(1.0, 0.0); + let p10 = Point::new(-6.0, 1.0); + let hull = vec![p1, p2, p3, p4, p5]; + let others = vec![p6, p7, p8, p9, p10]; + test_jarvis(hull, others); + } + + #[test] + fn rectangle_with_points_on_same_line() { + let p1 = Point::new(1.0, 1.0); + let p2 = Point::new(2.0, 1.0); + let p3 = Point::new(2.0, 2.0); + let p4 = Point::new(1.0, 2.0); + let p5 = Point::new(1.5, 1.0); + let p6 = Point::new(1.0, 1.5); + let p7 = Point::new(2.0, 1.5); + let p8 = Point::new(1.5, 2.0); + let hull = vec![p1, p2, p3, p4]; + let others = vec![p5, p6, p7, p8]; + test_jarvis(hull, others); + } +} diff --git a/src/geometry/mod.rs b/src/geometry/mod.rs index f20b7c9b5b5..e883cc004bc 100644 --- a/src/geometry/mod.rs +++ b/src/geometry/mod.rs @@ -1,3 +1,15 @@ mod closest_points; +mod graham_scan; +mod jarvis_scan; +mod point; +mod polygon_points; +mod ramer_douglas_peucker; +mod segment; pub use self::closest_points::closest_points; +pub use self::graham_scan::graham_scan; +pub use self::jarvis_scan::jarvis_march; +pub use self::point::Point; +pub use self::polygon_points::lattice_points; +pub use self::ramer_douglas_peucker::ramer_douglas_peucker; +pub use self::segment::Segment; diff --git a/src/geometry/point.rs b/src/geometry/point.rs new file mode 100644 index 00000000000..4b4bd8db501 --- /dev/null +++ b/src/geometry/point.rs @@ -0,0 +1,38 @@ +use std::ops::Sub; + +#[derive(Clone, Debug, PartialEq)] +pub struct Point { + pub x: f64, + pub y: f64, +} + +impl Point { + pub fn new(x: f64, y: f64) -> Point { + Point { x, y } + } + + // Returns the orientation of consecutive segments ab and bc. + pub fn consecutive_orientation(&self, b: &Point, c: &Point) -> f64 { + let p1 = b - self; + let p2 = c - self; + p1.cross_prod(&p2) + } + + pub fn cross_prod(&self, other: &Point) -> f64 { + self.x * other.y - self.y * other.x + } + + pub fn euclidean_distance(&self, other: &Point) -> f64 { + ((self.x - other.x).powi(2) + (self.y - other.y).powi(2)).sqrt() + } +} + +impl Sub for &Point { + type Output = Point; + + fn sub(self, other: Self) -> Point { + let x = self.x - other.x; + let y = self.y - other.y; + Point::new(x, y) + } +} diff --git a/src/geometry/polygon_points.rs b/src/geometry/polygon_points.rs new file mode 100644 index 00000000000..7d492681cb7 --- /dev/null +++ b/src/geometry/polygon_points.rs @@ -0,0 +1,84 @@ +type Ll = i64; +type Pll = (Ll, Ll); + +fn cross(x1: Ll, y1: Ll, x2: Ll, y2: Ll) -> Ll { + x1 * y2 - x2 * y1 +} + +pub fn polygon_area(pts: &[Pll]) -> Ll { + let mut ats = 0; + for i in 2..pts.len() { + ats += cross( + pts[i].0 - pts[0].0, + pts[i].1 - pts[0].1, + pts[i - 1].0 - pts[0].0, + pts[i - 1].1 - pts[0].1, + ); + } + Ll::abs(ats / 2) +} + +fn gcd(mut a: Ll, mut b: Ll) -> Ll { + while b != 0 { + let temp = b; + b = a % b; + a = temp; + } + a +} + +fn boundary(pts: &[Pll]) -> Ll { + let mut ats = pts.len() as Ll; + for i in 0..pts.len() { + let deltax = pts[i].0 - pts[(i + 1) % pts.len()].0; + let deltay = pts[i].1 - pts[(i + 1) % pts.len()].1; + ats += Ll::abs(gcd(deltax, deltay)) - 1; + } + ats +} + +pub fn lattice_points(pts: &[Pll]) -> Ll { + let bounds = boundary(pts); + let area = polygon_area(pts); + area + 1 - bounds / 2 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_calculate_cross() { + assert_eq!(cross(1, 2, 3, 4), 4 - 3 * 2); + } + + #[test] + fn test_polygon_3_coordinates() { + let pts = vec![(0, 0), (0, 3), (4, 0)]; + assert_eq!(polygon_area(&pts), 6); + } + + #[test] + fn test_polygon_4_coordinates() { + let pts = vec![(0, 0), (0, 2), (2, 2), (2, 0)]; + assert_eq!(polygon_area(&pts), 4); + } + + #[test] + fn test_gcd_multiple_of_common_factor() { + assert_eq!(gcd(14, 28), 14); + } + + #[test] + fn test_boundary() { + let pts = vec![(0, 0), (0, 3), (0, 4), (2, 2)]; + assert_eq!(boundary(&pts), 8); + } + + #[test] + fn test_lattice_points() { + let pts = vec![(1, 1), (5, 1), (5, 4)]; + let result = lattice_points(&pts); + assert_eq!(result, 3); + } +} diff --git a/src/geometry/ramer_douglas_peucker.rs b/src/geometry/ramer_douglas_peucker.rs new file mode 100644 index 00000000000..ca9d53084b7 --- /dev/null +++ b/src/geometry/ramer_douglas_peucker.rs @@ -0,0 +1,115 @@ +use crate::geometry::Point; + +pub fn ramer_douglas_peucker(points: &[Point], epsilon: f64) -> Vec { + if points.len() < 3 { + return points.to_vec(); + } + let mut dmax = 0.0; + let mut index = 0; + let end = points.len() - 1; + + for i in 1..end { + let d = perpendicular_distance(&points[i], &points[0], &points[end]); + if d > dmax { + index = i; + dmax = d; + } + } + + if dmax > epsilon { + let mut results = ramer_douglas_peucker(&points[..=index], epsilon); + results.pop(); + results.extend(ramer_douglas_peucker(&points[index..], epsilon)); + results + } else { + vec![points[0].clone(), points[end].clone()] + } +} + +fn perpendicular_distance(p: &Point, a: &Point, b: &Point) -> f64 { + let num = (b.y - a.y) * p.x - (b.x - a.x) * p.y + b.x * a.y - b.y * a.x; + let den = a.euclidean_distance(b); + num.abs() / den +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_perpendicular_distance { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (p, a, b, expected) = $test_case; + assert_eq!(perpendicular_distance(&p, &a, &b), expected); + assert_eq!(perpendicular_distance(&p, &b, &a), expected); + } + )* + }; + } + + test_perpendicular_distance! { + basic: (Point::new(4.0, 0.0), Point::new(0.0, 0.0), Point::new(0.0, 3.0), 4.0), + basic_shifted_1: (Point::new(4.0, 1.0), Point::new(0.0, 1.0), Point::new(0.0, 4.0), 4.0), + basic_shifted_2: (Point::new(2.0, 1.0), Point::new(-2.0, 1.0), Point::new(-2.0, 4.0), 4.0), + } + + #[test] + fn test_ramer_douglas_peucker_polygon() { + let a = Point::new(0.0, 0.0); + let b = Point::new(1.0, 0.0); + let c = Point::new(2.0, 0.0); + let d = Point::new(2.0, 1.0); + let e = Point::new(2.0, 2.0); + let f = Point::new(1.0, 2.0); + let g = Point::new(0.0, 2.0); + let h = Point::new(0.0, 1.0); + let polygon = vec![ + a.clone(), + b, + c.clone(), + d, + e.clone(), + f, + g.clone(), + h.clone(), + ]; + let epsilon = 0.7; + let result = ramer_douglas_peucker(&polygon, epsilon); + assert_eq!(result, vec![a, c, e, g, h]); + } + + #[test] + fn test_ramer_douglas_peucker_polygonal_chain() { + let a = Point::new(0., 0.); + let b = Point::new(2., 0.5); + let c = Point::new(3., 3.); + let d = Point::new(6., 3.); + let e = Point::new(8., 4.); + + let points = vec![a.clone(), b, c, d, e.clone()]; + + let epsilon = 3.; // The epsilon is quite large, so the result will be a single line + let result = ramer_douglas_peucker(&points, epsilon); + assert_eq!(result, vec![a, e]); + } + + #[test] + fn test_less_than_three_points() { + let a = Point::new(0., 0.); + let b = Point::new(1., 1.); + + let epsilon = 0.1; + + assert_eq!(ramer_douglas_peucker(&[], epsilon), vec![]); + assert_eq!( + ramer_douglas_peucker(&[a.clone()], epsilon), + vec![a.clone()] + ); + assert_eq!( + ramer_douglas_peucker(&[a.clone(), b.clone()], epsilon), + vec![a, b] + ); + } +} diff --git a/src/geometry/segment.rs b/src/geometry/segment.rs new file mode 100644 index 00000000000..e43162e002f --- /dev/null +++ b/src/geometry/segment.rs @@ -0,0 +1,193 @@ +use super::Point; + +const TOLERANCE: f64 = 0.0001; + +pub struct Segment { + pub a: Point, + pub b: Point, +} + +impl Segment { + pub fn new(x1: f64, y1: f64, x2: f64, y2: f64) -> Segment { + Segment { + a: Point::new(x1, y1), + b: Point::new(x2, y2), + } + } + + pub fn from_points(a: Point, b: Point) -> Segment { + Segment { a, b } + } + + pub fn direction(&self, p: &Point) -> f64 { + let a = Point::new(p.x - self.a.x, p.y - self.a.y); + let b = Point::new(self.b.x - self.a.x, self.b.y - self.a.y); + a.cross_prod(&b) + } + + pub fn is_vertical(&self) -> bool { + self.a.x == self.b.x + } + + // returns (slope, y-intercept) + pub fn get_line_equation(&self) -> (f64, f64) { + let slope = (self.a.y - self.b.y) / (self.a.x - self.b.x); + let y_intercept = self.a.y - slope * self.a.x; + (slope, y_intercept) + } + + // Compute the value of y at x. Uses the line equation, and assumes the segment + // has infinite length. + pub fn compute_y_at_x(&self, x: f64) -> f64 { + let (slope, y_intercept) = self.get_line_equation(); + slope * x + y_intercept + } + + pub fn is_colinear(&self, p: &Point) -> bool { + if self.is_vertical() { + p.x == self.a.x + } else { + (self.compute_y_at_x(p.x) - p.y).abs() < TOLERANCE + } + } + + // p must be colinear with the segment + pub fn colinear_point_on_segment(&self, p: &Point) -> bool { + assert!(self.is_colinear(p), "p must be colinear!"); + let (low_x, high_x) = if self.a.x < self.b.x { + (self.a.x, self.b.x) + } else { + (self.b.x, self.a.x) + }; + let (low_y, high_y) = if self.a.y < self.b.y { + (self.a.y, self.b.y) + } else { + (self.b.y, self.a.y) + }; + + p.x >= low_x && p.x <= high_x && p.y >= low_y && p.y <= high_y + } + + pub fn on_segment(&self, p: &Point) -> bool { + if !self.is_colinear(p) { + return false; + } + self.colinear_point_on_segment(p) + } + + pub fn intersects(&self, other: &Segment) -> bool { + let direction1 = self.direction(&other.a); + let direction2 = self.direction(&other.b); + let direction3 = other.direction(&self.a); + let direction4 = other.direction(&self.b); + + // If the segments saddle each others' endpoints, they intersect + if ((direction1 > 0.0 && direction2 < 0.0) || (direction1 < 0.0 && direction2 > 0.0)) + && ((direction3 > 0.0 && direction4 < 0.0) || (direction3 < 0.0 && direction4 > 0.0)) + { + return true; + } + + // Edge cases where an endpoint lies on a segment + (direction1 == 0.0 && self.colinear_point_on_segment(&other.a)) + || (direction2 == 0.0 && self.colinear_point_on_segment(&other.b)) + || (direction3 == 0.0 && other.colinear_point_on_segment(&self.a)) + || (direction4 == 0.0 && other.colinear_point_on_segment(&self.b)) + } +} + +#[cfg(test)] +mod tests { + use super::Point; + use super::Segment; + + #[test] + fn colinear() { + let segment = Segment::new(2.0, 3.0, 6.0, 5.0); + assert_eq!((0.5, 2.0), segment.get_line_equation()); + + assert!(segment.is_colinear(&Point::new(2.0, 3.0))); + assert!(segment.is_colinear(&Point::new(6.0, 5.0))); + assert!(segment.is_colinear(&Point::new(0.0, 2.0))); + assert!(segment.is_colinear(&Point::new(-5.0, -0.5))); + assert!(segment.is_colinear(&Point::new(10.0, 7.0))); + + assert!(!segment.is_colinear(&Point::new(0.0, 0.0))); + assert!(!segment.is_colinear(&Point::new(1.9, 3.0))); + assert!(!segment.is_colinear(&Point::new(2.1, 3.0))); + assert!(!segment.is_colinear(&Point::new(2.0, 2.9))); + assert!(!segment.is_colinear(&Point::new(2.0, 3.1))); + assert!(!segment.is_colinear(&Point::new(5.9, 5.0))); + assert!(!segment.is_colinear(&Point::new(6.1, 5.0))); + assert!(!segment.is_colinear(&Point::new(6.0, 4.9))); + assert!(!segment.is_colinear(&Point::new(6.0, 5.1))); + } + + #[test] + fn colinear_vertical() { + let segment = Segment::new(2.0, 3.0, 2.0, 5.0); + assert!(segment.is_colinear(&Point::new(2.0, 1.0))); + assert!(segment.is_colinear(&Point::new(2.0, 3.0))); + assert!(segment.is_colinear(&Point::new(2.0, 4.0))); + assert!(segment.is_colinear(&Point::new(2.0, 5.0))); + assert!(segment.is_colinear(&Point::new(2.0, 6.0))); + + assert!(!segment.is_colinear(&Point::new(1.0, 3.0))); + assert!(!segment.is_colinear(&Point::new(3.0, 3.0))); + } + + fn test_intersect(s1: &Segment, s2: &Segment, result: bool) { + assert_eq!(s1.intersects(s2), result); + assert_eq!(s2.intersects(s1), result); + } + + #[test] + fn intersects() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(-1.0, 9.0, 10.0, -3.0); + let s3 = Segment::new(-0.0, 10.0, 11.0, -2.0); + let s4 = Segment::new(100.0, 200.0, 40.0, 50.0); + test_intersect(&s1, &s2, true); + test_intersect(&s1, &s3, true); + test_intersect(&s2, &s3, false); + test_intersect(&s1, &s4, false); + test_intersect(&s2, &s4, false); + test_intersect(&s3, &s4, false); + } + + #[test] + fn intersects_endpoint_on_segment() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(4.0, 4.0, -11.0, 20.0); + let s3 = Segment::new(4.0, 4.0, 14.0, -19.0); + test_intersect(&s1, &s2, true); + test_intersect(&s1, &s3, true); + } + + #[test] + fn intersects_self() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(2.0, 3.0, 6.0, 5.0); + test_intersect(&s1, &s2, true); + } + + #[test] + fn too_short_to_intersect() { + let s1 = Segment::new(2.0, 3.0, 6.0, 5.0); + let s2 = Segment::new(-1.0, 10.0, 3.0, 5.0); + let s3 = Segment::new(5.0, 3.0, 10.0, -11.0); + test_intersect(&s1, &s2, false); + test_intersect(&s1, &s3, false); + test_intersect(&s2, &s3, false); + } + + #[test] + fn parallel_segments() { + let s1 = Segment::new(-5.0, 0.0, 5.0, 0.0); + let s2 = Segment::new(-5.0, 1.0, 5.0, 1.0); + let s3 = Segment::new(-5.0, -1.0, 5.0, -1.0); + test_intersect(&s1, &s2, false); + test_intersect(&s1, &s3, false); + test_intersect(&s2, &s3, false); + } +} diff --git a/src/graph/astar.rs b/src/graph/astar.rs new file mode 100644 index 00000000000..a4244c87b8b --- /dev/null +++ b/src/graph/astar.rs @@ -0,0 +1,261 @@ +use std::{ + collections::{BTreeMap, BinaryHeap}, + ops::Add, +}; + +use num_traits::Zero; + +type Graph = BTreeMap>; + +#[derive(Clone, Debug, Eq, PartialEq)] +struct Candidate { + estimated_weight: E, + real_weight: E, + state: V, +} + +impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + // Note the inverted order; we want nodes with lesser weight to have + // higher priority + Some(self.cmp(other)) + } +} + +impl Ord for Candidate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Note the inverted order; we want nodes with lesser weight to have + // higher priority + other.estimated_weight.cmp(&self.estimated_weight) + } +} + +pub fn astar + Zero>( + graph: &Graph, + start: V, + target: V, + heuristic: impl Fn(V) -> E, +) -> Option<(E, Vec)> { + // traversal front + let mut queue = BinaryHeap::new(); + // maps each node to its predecessor in the final path + let mut previous = BTreeMap::new(); + // weights[v] is the accumulated weight from start to v + let mut weights = BTreeMap::new(); + // initialize traversal + weights.insert(start, E::zero()); + queue.push(Candidate { + estimated_weight: heuristic(start), + real_weight: E::zero(), + state: start, + }); + while let Some(Candidate { + real_weight, + state: current, + .. + }) = queue.pop() + { + if current == target { + break; + } + for (&next, &weight) in &graph[¤t] { + let real_weight = real_weight + weight; + if weights + .get(&next) + .is_none_or(|&weight| real_weight < weight) + { + // current allows us to reach next with lower weight (or at all) + // add next to the front + let estimated_weight = real_weight + heuristic(next); + weights.insert(next, real_weight); + queue.push(Candidate { + estimated_weight, + real_weight, + state: next, + }); + previous.insert(next, current); + } + } + } + let weight = if let Some(&weight) = weights.get(&target) { + weight + } else { + // we did not reach target from start + return None; + }; + // build path in reverse + let mut current = target; + let mut path = vec![current]; + while current != start { + let prev = previous + .get(¤t) + .copied() + .expect("We reached the target, but are unable to reconsistute the path"); + current = prev; + path.push(current); + } + path.reverse(); + Some((weight, path)) +} + +#[cfg(test)] +mod tests { + use super::{astar, Graph}; + use num_traits::Zero; + use std::collections::BTreeMap; + + // the null heuristic make A* equivalent to Dijkstra + fn null_heuristic(_v: V) -> E { + E::zero() + } + + fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default(); + } + + #[test] + fn single_vertex() { + let mut graph: Graph = BTreeMap::new(); + graph.insert(0, BTreeMap::new()); + + assert_eq!(astar(&graph, 0, 0, null_heuristic), Some((0, vec![0]))); + assert_eq!(astar(&graph, 0, 1, null_heuristic), None); + } + + #[test] + fn single_edge() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 0, 1, 2); + + assert_eq!(astar(&graph, 0, 1, null_heuristic), Some((2, vec![0, 1]))); + assert_eq!(astar(&graph, 1, 0, null_heuristic), None); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'c', 12); + add_edge(&mut graph, 'a', 'd', 60); + add_edge(&mut graph, 'b', 'a', 10); + add_edge(&mut graph, 'c', 'b', 20); + add_edge(&mut graph, 'c', 'd', 32); + add_edge(&mut graph, 'e', 'a', 7); + + // from a + assert_eq!( + astar(&graph, 'a', 'a', null_heuristic), + Some((0, vec!['a'])) + ); + assert_eq!( + astar(&graph, 'a', 'b', null_heuristic), + Some((32, vec!['a', 'c', 'b'])) + ); + assert_eq!( + astar(&graph, 'a', 'c', null_heuristic), + Some((12, vec!['a', 'c'])) + ); + assert_eq!( + astar(&graph, 'a', 'd', null_heuristic), + Some((12 + 32, vec!['a', 'c', 'd'])) + ); + assert_eq!(astar(&graph, 'a', 'e', null_heuristic), None); + + // from b + assert_eq!( + astar(&graph, 'b', 'a', null_heuristic), + Some((10, vec!['b', 'a'])) + ); + assert_eq!( + astar(&graph, 'b', 'b', null_heuristic), + Some((0, vec!['b'])) + ); + assert_eq!( + astar(&graph, 'b', 'c', null_heuristic), + Some((10 + 12, vec!['b', 'a', 'c'])) + ); + assert_eq!( + astar(&graph, 'b', 'd', null_heuristic), + Some((10 + 12 + 32, vec!['b', 'a', 'c', 'd'])) + ); + assert_eq!(astar(&graph, 'b', 'e', null_heuristic), None); + + // from c + assert_eq!( + astar(&graph, 'c', 'a', null_heuristic), + Some((20 + 10, vec!['c', 'b', 'a'])) + ); + assert_eq!( + astar(&graph, 'c', 'b', null_heuristic), + Some((20, vec!['c', 'b'])) + ); + assert_eq!( + astar(&graph, 'c', 'c', null_heuristic), + Some((0, vec!['c'])) + ); + assert_eq!( + astar(&graph, 'c', 'd', null_heuristic), + Some((32, vec!['c', 'd'])) + ); + assert_eq!(astar(&graph, 'c', 'e', null_heuristic), None); + + // from d + assert_eq!(astar(&graph, 'd', 'a', null_heuristic), None); + assert_eq!(astar(&graph, 'd', 'b', null_heuristic), None); + assert_eq!(astar(&graph, 'd', 'c', null_heuristic), None); + assert_eq!( + astar(&graph, 'd', 'd', null_heuristic), + Some((0, vec!['d'])) + ); + assert_eq!(astar(&graph, 'd', 'e', null_heuristic), None); + + // from e + assert_eq!( + astar(&graph, 'e', 'a', null_heuristic), + Some((7, vec!['e', 'a'])) + ); + assert_eq!( + astar(&graph, 'e', 'b', null_heuristic), + Some((7 + 12 + 20, vec!['e', 'a', 'c', 'b'])) + ); + assert_eq!( + astar(&graph, 'e', 'c', null_heuristic), + Some((7 + 12, vec!['e', 'a', 'c'])) + ); + assert_eq!( + astar(&graph, 'e', 'd', null_heuristic), + Some((7 + 12 + 32, vec!['e', 'a', 'c', 'd'])) + ); + assert_eq!( + astar(&graph, 'e', 'e', null_heuristic), + Some((0, vec!['e'])) + ); + } + + #[test] + fn test_heuristic() { + // make a grid + let mut graph = BTreeMap::new(); + let rows = 100; + let cols = 100; + for row in 0..rows { + for col in 0..cols { + add_edge(&mut graph, (row, col), (row + 1, col), 1); + add_edge(&mut graph, (row, col), (row, col + 1), 1); + add_edge(&mut graph, (row, col), (row + 1, col + 1), 1); + add_edge(&mut graph, (row + 1, col), (row, col), 1); + add_edge(&mut graph, (row + 1, col + 1), (row, col), 1); + } + } + + // Dijkstra would explore most of the 101 × 101 nodes + // the heuristic should allow exploring only about 200 nodes + let now = std::time::Instant::now(); + let res = astar(&graph, (0, 0), (100, 90), |(i, j)| 100 - i + 90 - j); + assert!(now.elapsed() < std::time::Duration::from_millis(10)); + + let (weight, path) = res.unwrap(); + assert_eq!(weight, 100); + assert_eq!(path.len(), 101); + } +} diff --git a/src/graph/bellman_ford.rs b/src/graph/bellman_ford.rs index e16fcce28fe..8faf55860d5 100644 --- a/src/graph/bellman_ford.rs +++ b/src/graph/bellman_ford.rs @@ -90,8 +90,8 @@ mod tests { use std::collections::BTreeMap; fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { - graph.entry(v1).or_insert_with(BTreeMap::new).insert(v2, c); - graph.entry(v2).or_insert_with(BTreeMap::new); + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default(); } #[test] diff --git a/src/graph/bipartite_matching.rs b/src/graph/bipartite_matching.rs new file mode 100644 index 00000000000..48c25e8064f --- /dev/null +++ b/src/graph/bipartite_matching.rs @@ -0,0 +1,270 @@ +// Adjacency List +use std::collections::VecDeque; +type Graph = Vec>; + +pub struct BipartiteMatching { + pub adj: Graph, + pub num_vertices_grp1: usize, + pub num_vertices_grp2: usize, + // mt1[i] = v is the matching of i in grp1 to v in grp2 + pub mt1: Vec, + pub mt2: Vec, + pub used: Vec, +} +impl BipartiteMatching { + pub fn new(num_vertices_grp1: usize, num_vertices_grp2: usize) -> Self { + BipartiteMatching { + adj: vec![vec![]; num_vertices_grp1 + 1], + num_vertices_grp1, + num_vertices_grp2, + mt2: vec![-1; num_vertices_grp2 + 1], + mt1: vec![-1; num_vertices_grp1 + 1], + used: vec![false; num_vertices_grp1 + 1], + } + } + #[inline] + // Add an directed edge u->v in the graph + pub fn add_edge(&mut self, u: usize, v: usize) { + self.adj[u].push(v); + } + + fn try_kuhn(&mut self, cur: usize) -> bool { + if self.used[cur] { + return false; + } + self.used[cur] = true; + for i in 0..self.adj[cur].len() { + let to = self.adj[cur][i]; + if self.mt2[to] == -1 || self.try_kuhn(self.mt2[to] as usize) { + self.mt2[to] = cur as i32; + return true; + } + } + false + } + // Note: It does not modify self.mt1, it only works on self.mt2 + pub fn kuhn(&mut self) { + self.mt2 = vec![-1; self.num_vertices_grp2 + 1]; + for v in 1..=self.num_vertices_grp1 { + self.used = vec![false; self.num_vertices_grp1 + 1]; + self.try_kuhn(v); + } + } + pub fn print_matching(&self) { + for i in 1..=self.num_vertices_grp2 { + if self.mt2[i] == -1 { + continue; + } + println!("Vertex {} in grp1 matched with {} grp2", self.mt2[i], i) + } + } + fn bfs(&self, dist: &mut [i32]) -> bool { + let mut q = VecDeque::new(); + for (u, d_i) in dist + .iter_mut() + .enumerate() + .skip(1) + .take(self.num_vertices_grp1) + { + if self.mt1[u] == 0 { + // u is not matched + *d_i = 0; + q.push_back(u); + } else { + // else set the vertex distance as infinite because it is matched + // this will be considered the next time + + *d_i = i32::MAX; + } + } + dist[0] = i32::MAX; + while !q.is_empty() { + let u = *q.front().unwrap(); + q.pop_front(); + if dist[u] < dist[0] { + for i in 0..self.adj[u].len() { + let v = self.adj[u][i]; + if dist[self.mt2[v] as usize] == i32::MAX { + dist[self.mt2[v] as usize] = dist[u] + 1; + q.push_back(self.mt2[v] as usize); + } + } + } + } + dist[0] != i32::MAX + } + fn dfs(&mut self, u: i32, dist: &mut Vec) -> bool { + if u == 0 { + return true; + } + for i in 0..self.adj[u as usize].len() { + let v = self.adj[u as usize][i]; + if dist[self.mt2[v] as usize] == dist[u as usize] + 1 && self.dfs(self.mt2[v], dist) { + self.mt2[v] = u; + self.mt1[u as usize] = v as i32; + return true; + } + } + dist[u as usize] = i32::MAX; + false + } + pub fn hopcroft_karp(&mut self) -> i32 { + // NOTE: how to use: https://cses.fi/paste/7558dba8d00436a847eab8/ + self.mt2 = vec![0; self.num_vertices_grp2 + 1]; + self.mt1 = vec![0; self.num_vertices_grp1 + 1]; + let mut dist = vec![i32::MAX; self.num_vertices_grp1 + 1]; + let mut res = 0; + while self.bfs(&mut dist) { + for u in 1..=self.num_vertices_grp1 { + if self.mt1[u] == 0 && self.dfs(u as i32, &mut dist) { + res += 1; + } + } + } + // for x in self.mt2 change x to -1 if it is 0 + for x in self.mt2.iter_mut() { + if *x == 0 { + *x = -1; + } + } + for x in self.mt1.iter_mut() { + if *x == 0 { + *x = -1; + } + } + res + } +} +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn small_graph_kuhn() { + let n1 = 6; + let n2 = 6; + let mut g = BipartiteMatching::new(n1, n2); + // vertex 1 in grp1 to vertex 1 in grp 2 + // denote the ith grp2 vertex as n1+i + g.add_edge(1, 2); + g.add_edge(1, 3); + // 2 is not connected to any vertex + g.add_edge(3, 4); + g.add_edge(3, 1); + g.add_edge(4, 3); + g.add_edge(5, 3); + g.add_edge(5, 4); + g.add_edge(6, 6); + g.kuhn(); + g.print_matching(); + let answer: Vec = vec![-1, 2, -1, 1, 3, 4, 6]; + for i in 1..g.mt2.len() { + if g.mt2[i] == -1 { + // 5 in group2 has no pair + assert_eq!(i, 5); + continue; + } + // 2 in group1 has no pair + assert!(g.mt2[i] != 2); + assert_eq!(i as i32, answer[g.mt2[i] as usize]); + } + } + #[test] + fn small_graph_hopcroft() { + let n1 = 6; + let n2 = 6; + let mut g = BipartiteMatching::new(n1, n2); + // vertex 1 in grp1 to vertex 1 in grp 2 + // denote the ith grp2 vertex as n1+i + g.add_edge(1, 2); + g.add_edge(1, 3); + // 2 is not connected to any vertex + g.add_edge(3, 4); + g.add_edge(3, 1); + g.add_edge(4, 3); + g.add_edge(5, 3); + g.add_edge(5, 4); + g.add_edge(6, 6); + let x = g.hopcroft_karp(); + assert_eq!(x, 5); + g.print_matching(); + let answer: Vec = vec![-1, 2, -1, 1, 3, 4, 6]; + for i in 1..g.mt2.len() { + if g.mt2[i] == -1 { + // 5 in group2 has no pair + assert_eq!(i, 5); + continue; + } + // 2 in group1 has no pair + assert!(g.mt2[i] != 2); + assert_eq!(i as i32, answer[g.mt2[i] as usize]); + } + } + #[test] + fn super_small_graph_kuhn() { + let n1 = 1; + let n2 = 1; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + g.kuhn(); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + } + #[test] + fn super_small_graph_hopcroft() { + let n1 = 1; + let n2 = 1; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + let x = g.hopcroft_karp(); + assert_eq!(x, 1); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + assert_eq!(g.mt1[1], 1); + } + + #[test] + fn only_one_vertex_graph_kuhn() { + let n1 = 10; + let n2 = 10; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + g.add_edge(2, 1); + g.add_edge(3, 1); + g.add_edge(4, 1); + g.add_edge(5, 1); + g.add_edge(6, 1); + g.add_edge(7, 1); + g.add_edge(8, 1); + g.add_edge(9, 1); + g.add_edge(10, 1); + g.kuhn(); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + for i in 2..g.mt2.len() { + assert!(g.mt2[i] == -1); + } + } + #[test] + fn only_one_vertex_graph_hopcroft() { + let n1 = 10; + let n2 = 10; + let mut g = BipartiteMatching::new(n1, n2); + g.add_edge(1, 1); + g.add_edge(2, 1); + g.add_edge(3, 1); + g.add_edge(4, 1); + g.add_edge(5, 1); + g.add_edge(6, 1); + g.add_edge(7, 1); + g.add_edge(8, 1); + g.add_edge(9, 1); + g.add_edge(10, 1); + let x = g.hopcroft_karp(); + assert_eq!(x, 1); + g.print_matching(); + assert_eq!(g.mt2[1], 1); + for i in 2..g.mt2.len() { + assert!(g.mt2[i] == -1); + } + } +} diff --git a/src/graph/breadth_first_search.rs b/src/graph/breadth_first_search.rs index 076d6f11002..4b4875ab721 100644 --- a/src/graph/breadth_first_search.rs +++ b/src/graph/breadth_first_search.rs @@ -34,8 +34,7 @@ pub fn breadth_first_search(graph: &Graph, root: Node, target: Node) -> Option]; const IN_DECOMPOSITION: u64 = 1 << 63; + +/// Centroid Decomposition for a tree. +/// +/// Given a tree, it can be recursively decomposed into centroids. Then the +/// parent of a centroid `c` is the previous centroid that splitted its connected +/// component into two or more components. It can be shown that in such +/// decomposition, for each path `p` with starting and ending vertices `u`, `v`, +/// the lowest common ancestor of `u` and `v` in centroid tree is a vertex of `p`. +/// +/// The input tree should have its vertices numbered from 1 to n, and +/// `graph_enumeration.rs` may help to convert other representations. pub struct CentroidDecomposition { /// The root of the centroid tree, should _not_ be set by the user pub root: usize, - /// The result. decomposition[`v`] is the parent of `v` in centroid tree. - /// decomposition[`root`] is 0 + /// The result. `decomposition[v]` is the parent of `v` in centroid tree. + /// `decomposition[root]` is 0 pub decomposition: Vec, /// Used internally to save the big_child of a vertex, and whether it has /// been added to the centroid tree. @@ -123,6 +121,7 @@ mod tests { let mut adj: Vec> = vec![vec![]; len]; adj[1].push(2); adj[15].push(14); + #[allow(clippy::needless_range_loop)] for i in 2..15 { adj[i].push(i + 1); adj[i].push(i - 1); diff --git a/src/graph/decremental_connectivity.rs b/src/graph/decremental_connectivity.rs new file mode 100644 index 00000000000..e5245404866 --- /dev/null +++ b/src/graph/decremental_connectivity.rs @@ -0,0 +1,267 @@ +use std::collections::HashSet; + +/// A data-structure that, given a forest, allows dynamic-connectivity queries. +/// Meaning deletion of an edge (u,v) and checking whether two vertecies are still connected. +/// +/// # Complexity +/// The preprocessing phase runs in O(n) time, where n is the number of vertecies in the forest. +/// Deletion runs in O(log n) and checking for connectivity runs in O(1) time. +/// +/// # Sources +/// used Wikipedia as reference: +pub struct DecrementalConnectivity { + adjacent: Vec>, + component: Vec, + count: usize, + visited: Vec, + dfs_id: usize, +} +impl DecrementalConnectivity { + //expects the parent of a root to be itself + pub fn new(adjacent: Vec>) -> Result { + let n = adjacent.len(); + if !is_forest(&adjacent) { + return Err("input graph is not a forest!".to_string()); + } + let mut tmp = DecrementalConnectivity { + adjacent, + component: vec![0; n], + count: 0, + visited: vec![0; n], + dfs_id: 1, + }; + tmp.component = tmp.calc_component(); + Ok(tmp) + } + + pub fn connected(&self, u: usize, v: usize) -> Option { + match (self.component.get(u), self.component.get(v)) { + (Some(a), Some(b)) => Some(a == b), + _ => None, + } + } + + pub fn delete(&mut self, u: usize, v: usize) { + if !self.adjacent[u].contains(&v) || self.component[u] != self.component[v] { + panic!("delete called on the edge ({u}, {v}) which doesn't exist"); + } + + self.adjacent[u].remove(&v); + self.adjacent[v].remove(&u); + + let mut queue: Vec = Vec::new(); + if self.is_smaller(u, v) { + queue.push(u); + self.dfs_id += 1; + self.visited[v] = self.dfs_id; + } else { + queue.push(v); + self.dfs_id += 1; + self.visited[u] = self.dfs_id; + } + while !queue.is_empty() { + let ¤t = queue.last().unwrap(); + self.dfs_step(&mut queue, self.dfs_id); + self.component[current] = self.count; + } + self.count += 1; + } + + fn calc_component(&mut self) -> Vec { + let mut visited: Vec = vec![false; self.adjacent.len()]; + let mut comp: Vec = vec![0; self.adjacent.len()]; + + for i in 0..self.adjacent.len() { + if visited[i] { + continue; + } + let mut queue: Vec = vec![i]; + while let Some(current) = queue.pop() { + if !visited[current] { + for &neighbour in self.adjacent[current].iter() { + queue.push(neighbour); + } + } + visited[current] = true; + comp[current] = self.count; + } + self.count += 1; + } + comp + } + + fn is_smaller(&mut self, u: usize, v: usize) -> bool { + let mut u_queue: Vec = vec![u]; + let u_id = self.dfs_id; + self.visited[v] = u_id; + self.dfs_id += 1; + + let mut v_queue: Vec = vec![v]; + let v_id = self.dfs_id; + self.visited[u] = v_id; + self.dfs_id += 1; + + // parallel depth first search + while !u_queue.is_empty() && !v_queue.is_empty() { + self.dfs_step(&mut u_queue, u_id); + self.dfs_step(&mut v_queue, v_id); + } + u_queue.is_empty() + } + + fn dfs_step(&mut self, queue: &mut Vec, dfs_id: usize) { + let u = queue.pop().unwrap(); + self.visited[u] = dfs_id; + for &v in self.adjacent[u].iter() { + if self.visited[v] == dfs_id { + continue; + } + queue.push(v); + } + } +} + +// checks whether the given graph is a forest +// also checks for all adjacent vertices a,b if adjacent[a].contains(b) && adjacent[b].contains(a) +fn is_forest(adjacent: &Vec>) -> bool { + let mut visited = vec![false; adjacent.len()]; + for node in 0..adjacent.len() { + if visited[node] { + continue; + } + if has_cycle(adjacent, &mut visited, node, node) { + return false; + } + } + true +} + +fn has_cycle( + adjacent: &Vec>, + visited: &mut Vec, + node: usize, + parent: usize, +) -> bool { + visited[node] = true; + for &neighbour in adjacent[node].iter() { + if !adjacent[neighbour].contains(&node) { + panic!("the given graph does not strictly contain bidirectional edges\n {node} -> {neighbour} exists, but the other direction does not"); + } + if !visited[neighbour] { + if has_cycle(adjacent, visited, neighbour, node) { + return true; + } + } else if neighbour != parent { + return true; + } + } + false +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + // test forest (remember the assumptoin that roots are adjacent to themselves) + // _ _ + // \ / \ / + // 0 7 + // / | \ | + // 1 2 3 8 + // / / \ + // 4 5 6 + #[test] + fn construction_test() { + let mut adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + let dec_con = super::DecrementalConnectivity::new(adjacent.clone()).unwrap(); + assert_eq!(dec_con.component, vec![0, 0, 0, 0, 0, 0, 0, 1, 1]); + + // add a cycle to the tree + adjacent[2].insert(4); + adjacent[4].insert(2); + assert!(super::DecrementalConnectivity::new(adjacent.clone()).is_err()); + } + #[test] + #[should_panic(expected = "2 -> 4 exists")] + fn non_bidirectional_test() { + let adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6, 4]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + + // should panic now since our graph is not bidirectional + super::DecrementalConnectivity::new(adjacent).unwrap(); + } + + #[test] + #[should_panic(expected = "delete called on the edge (2, 4)")] + fn delete_panic_test() { + let adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + let mut dec_con = super::DecrementalConnectivity::new(adjacent).unwrap(); + dec_con.delete(2, 4); + } + + #[test] + fn query_test() { + let adjacent = vec![ + HashSet::from([0, 1, 2, 3]), + HashSet::from([0, 4]), + HashSet::from([0, 5, 6]), + HashSet::from([0]), + HashSet::from([1]), + HashSet::from([2]), + HashSet::from([2]), + HashSet::from([7, 8]), + HashSet::from([7]), + ]; + let mut dec_con1 = super::DecrementalConnectivity::new(adjacent.clone()).unwrap(); + assert!(dec_con1.connected(3, 4).unwrap()); + assert!(dec_con1.connected(5, 0).unwrap()); + assert!(!dec_con1.connected(2, 7).unwrap()); + assert!(dec_con1.connected(0, 9).is_none()); + dec_con1.delete(0, 2); + assert!(dec_con1.connected(3, 4).unwrap()); + assert!(!dec_con1.connected(5, 0).unwrap()); + assert!(dec_con1.connected(5, 6).unwrap()); + assert!(dec_con1.connected(8, 7).unwrap()); + dec_con1.delete(7, 8); + assert!(!dec_con1.connected(8, 7).unwrap()); + dec_con1.delete(1, 4); + assert!(!dec_con1.connected(1, 4).unwrap()); + + let mut dec_con2 = super::DecrementalConnectivity::new(adjacent.clone()).unwrap(); + dec_con2.delete(4, 1); + assert!(!dec_con2.connected(1, 4).unwrap()); + + let mut dec_con3 = super::DecrementalConnectivity::new(adjacent).unwrap(); + dec_con3.delete(1, 4); + assert!(!dec_con3.connected(4, 1).unwrap()); + } +} diff --git a/src/graph/depth_first_search_tic_tac_toe.rs b/src/graph/depth_first_search_tic_tac_toe.rs index cff9477d2ef..788991c3823 100644 --- a/src/graph/depth_first_search_tic_tac_toe.rs +++ b/src/graph/depth_first_search_tic_tac_toe.rs @@ -95,14 +95,13 @@ fn main() { if result.is_none() { println!("Not a valid empty coordinate."); continue; - } else { - board[move_pos.y as usize][move_pos.x as usize] = Players::PlayerX; + } + board[move_pos.y as usize][move_pos.x as usize] = Players::PlayerX; - if win_check(Players::PlayerX, &board) { - display_board(&board); - println!("Player X Wins!"); - return; - } + if win_check(Players::PlayerX, &board) { + display_board(&board); + println!("Player X Wins!"); + return; } //Find the best game plays from the current board state @@ -111,7 +110,7 @@ fn main() { Some(x) => { //Interactive Tic-Tac-Toe play needs the "rand = "0.8.3" crate. //#[cfg(not(test))] - //let random_selection = rand::thread_rng().gen_range(0..x.positions.len()); + //let random_selection = rand::rng().gen_range(0..x.positions.len()); let random_selection = 0; let response_pos = x.positions[random_selection]; @@ -274,21 +273,17 @@ fn append_playaction( return; } - let mut play_actions = opt_play_actions.as_mut().unwrap(); + let play_actions = opt_play_actions.as_mut().unwrap(); //New game action is scored from the current side and the current saved best score against the new game action. match (current_side, play_actions.side, appendee.side) { (Players::Blank, _, _) => panic!("Unreachable state."), //Winning scores - (Players::PlayerX, Players::PlayerX, Players::PlayerX) => { - play_actions.positions.push(appendee.position); - } - (Players::PlayerX, Players::PlayerX, _) => {} - (Players::PlayerO, Players::PlayerO, Players::PlayerO) => { + (Players::PlayerX, Players::PlayerX, Players::PlayerX) + | (Players::PlayerO, Players::PlayerO, Players::PlayerO) => { play_actions.positions.push(appendee.position); } - (Players::PlayerO, Players::PlayerO, _) => {} //Non-winning to Winning scores (Players::PlayerX, _, Players::PlayerX) => { @@ -303,21 +298,18 @@ fn append_playaction( } //Losing to Neutral scores - (Players::PlayerX, Players::PlayerO, Players::Blank) => { - play_actions.side = Players::Blank; - play_actions.positions.clear(); - play_actions.positions.push(appendee.position); - } - - (Players::PlayerO, Players::PlayerX, Players::Blank) => { + (Players::PlayerX, Players::PlayerO, Players::Blank) + | (Players::PlayerO, Players::PlayerX, Players::Blank) => { play_actions.side = Players::Blank; play_actions.positions.clear(); play_actions.positions.push(appendee.position); } //Ignoring lower scored plays - (Players::PlayerX, Players::Blank, Players::PlayerO) => {} - (Players::PlayerO, Players::Blank, Players::PlayerX) => {} + (Players::PlayerX, Players::PlayerX, _) + | (Players::PlayerO, Players::PlayerO, _) + | (Players::PlayerX, Players::Blank, Players::PlayerO) + | (Players::PlayerO, Players::Blank, Players::PlayerX) => {} //No change hence append only (_, _, _) => { diff --git a/src/graph/detect_cycle.rs b/src/graph/detect_cycle.rs new file mode 100644 index 00000000000..0243b44eede --- /dev/null +++ b/src/graph/detect_cycle.rs @@ -0,0 +1,294 @@ +use std::collections::{HashMap, HashSet, VecDeque}; + +use crate::data_structures::{graph::Graph, DirectedGraph, UndirectedGraph}; + +pub trait DetectCycle { + fn detect_cycle_dfs(&self) -> bool; + fn detect_cycle_bfs(&self) -> bool; +} + +// Helper function to detect cycle in an undirected graph using DFS graph traversal +fn undirected_graph_detect_cycle_dfs<'a>( + graph: &'a UndirectedGraph, + visited_node: &mut HashSet<&'a String>, + parent: Option<&'a String>, + u: &'a String, +) -> bool { + visited_node.insert(u); + for (v, _) in graph.adjacency_table().get(u).unwrap() { + if matches!(parent, Some(parent) if v == parent) { + continue; + } + if visited_node.contains(v) + || undirected_graph_detect_cycle_dfs(graph, visited_node, Some(u), v) + { + return true; + } + } + false +} + +// Helper function to detect cycle in an undirected graph using BFS graph traversal +fn undirected_graph_detect_cycle_bfs<'a>( + graph: &'a UndirectedGraph, + visited_node: &mut HashSet<&'a String>, + u: &'a String, +) -> bool { + visited_node.insert(u); + + // Initialize the queue for BFS, storing (current node, parent node) tuples + let mut queue = VecDeque::<(&String, Option<&String>)>::new(); + queue.push_back((u, None)); + + while let Some((u, parent)) = queue.pop_front() { + for (v, _) in graph.adjacency_table().get(u).unwrap() { + if matches!(parent, Some(parent) if v == parent) { + continue; + } + if visited_node.contains(v) { + return true; + } + visited_node.insert(v); + queue.push_back((v, Some(u))); + } + } + false +} + +impl DetectCycle for UndirectedGraph { + fn detect_cycle_dfs(&self) -> bool { + let mut visited_node = HashSet::<&String>::new(); + let adj = self.adjacency_table(); + for u in adj.keys() { + if !visited_node.contains(u) + && undirected_graph_detect_cycle_dfs(self, &mut visited_node, None, u) + { + return true; + } + } + false + } + + fn detect_cycle_bfs(&self) -> bool { + let mut visited_node = HashSet::<&String>::new(); + let adj = self.adjacency_table(); + for u in adj.keys() { + if !visited_node.contains(u) + && undirected_graph_detect_cycle_bfs(self, &mut visited_node, u) + { + return true; + } + } + false + } +} + +// Helper function to detect cycle in a directed graph using DFS graph traversal +fn directed_graph_detect_cycle_dfs<'a>( + graph: &'a DirectedGraph, + visited_node: &mut HashSet<&'a String>, + in_stack_visited_node: &mut HashSet<&'a String>, + u: &'a String, +) -> bool { + visited_node.insert(u); + in_stack_visited_node.insert(u); + for (v, _) in graph.adjacency_table().get(u).unwrap() { + if visited_node.contains(v) && in_stack_visited_node.contains(v) { + return true; + } + if !visited_node.contains(v) + && directed_graph_detect_cycle_dfs(graph, visited_node, in_stack_visited_node, v) + { + return true; + } + } + in_stack_visited_node.remove(u); + false +} + +impl DetectCycle for DirectedGraph { + fn detect_cycle_dfs(&self) -> bool { + let mut visited_node = HashSet::<&String>::new(); + let mut in_stack_visited_node = HashSet::<&String>::new(); + let adj = self.adjacency_table(); + for u in adj.keys() { + if !visited_node.contains(u) + && directed_graph_detect_cycle_dfs( + self, + &mut visited_node, + &mut in_stack_visited_node, + u, + ) + { + return true; + } + } + false + } + + // detect cycle in a the graph using Kahn's algorithm + // https://www.geeksforgeeks.org/detect-cycle-in-a-directed-graph-using-bfs/ + fn detect_cycle_bfs(&self) -> bool { + // Set 0 in-degree for each vertex + let mut in_degree: HashMap<&String, usize> = + self.adjacency_table().keys().map(|k| (k, 0)).collect(); + + // Calculate in-degree for each vertex + for u in self.adjacency_table().keys() { + for (v, _) in self.adjacency_table().get(u).unwrap() { + *in_degree.get_mut(v).unwrap() += 1; + } + } + // Initialize queue with vertex having 0 in-degree + let mut queue: VecDeque<&String> = in_degree + .iter() + .filter(|(_, °ree)| degree == 0) + .map(|(&k, _)| k) + .collect(); + + let mut count = 0; + while let Some(u) = queue.pop_front() { + count += 1; + for (v, _) in self.adjacency_table().get(u).unwrap() { + in_degree.entry(v).and_modify(|d| { + *d -= 1; + if *d == 0 { + queue.push_back(v); + } + }); + } + } + + // If count of processed vertices is not equal to the number of vertices, + // the graph has a cycle + count != self.adjacency_table().len() + } +} + +#[cfg(test)] +mod test { + use super::DetectCycle; + use crate::data_structures::{graph::Graph, DirectedGraph, UndirectedGraph}; + fn get_undirected_single_node_with_loop() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "a", 1)); + res + } + fn get_directed_single_node_with_loop() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "a", 1)); + res + } + fn get_undirected_two_nodes_connected() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res + } + fn get_directed_two_nodes_connected() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("b", "a", 1)); + res + } + fn get_directed_two_nodes() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res + } + fn get_undirected_triangle() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "a", 1)); + res + } + fn get_directed_triangle() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "a", 1)); + res + } + fn get_undirected_triangle_with_tail() -> UndirectedGraph { + let mut res = get_undirected_triangle(); + res.add_edge(("c", "d", 1)); + res.add_edge(("d", "e", 1)); + res.add_edge(("e", "f", 1)); + res.add_edge(("g", "h", 1)); + res + } + fn get_directed_triangle_with_tail() -> DirectedGraph { + let mut res = get_directed_triangle(); + res.add_edge(("c", "d", 1)); + res.add_edge(("d", "e", 1)); + res.add_edge(("e", "f", 1)); + res.add_edge(("g", "h", 1)); + res + } + fn get_undirected_graph_with_cycle() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("a", "c", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("b", "d", 1)); + res.add_edge(("c", "d", 1)); + res + } + fn get_undirected_graph_without_cycle() -> UndirectedGraph { + let mut res = UndirectedGraph::new(); + res.add_edge(("a", "b", 1)); + res.add_edge(("a", "c", 1)); + res.add_edge(("b", "d", 1)); + res.add_edge(("c", "e", 1)); + res + } + fn get_directed_graph_with_cycle() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("b", "a", 1)); + res.add_edge(("c", "a", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "d", 1)); + res.add_edge(("d", "b", 1)); + res + } + fn get_directed_graph_without_cycle() -> DirectedGraph { + let mut res = DirectedGraph::new(); + res.add_edge(("b", "a", 1)); + res.add_edge(("c", "a", 1)); + res.add_edge(("b", "c", 1)); + res.add_edge(("c", "d", 1)); + res.add_edge(("b", "d", 1)); + res + } + macro_rules! test_detect_cycle { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (graph, has_cycle) = $test_case; + println!("detect_cycle_dfs: {}", graph.detect_cycle_dfs()); + println!("detect_cycle_bfs: {}", graph.detect_cycle_bfs()); + assert_eq!(graph.detect_cycle_dfs(), has_cycle); + assert_eq!(graph.detect_cycle_bfs(), has_cycle); + } + )* + }; + } + test_detect_cycle! { + undirected_empty: (UndirectedGraph::new(), false), + directed_empty: (DirectedGraph::new(), false), + undirected_single_node_with_loop: (get_undirected_single_node_with_loop(), true), + directed_single_node_with_loop: (get_directed_single_node_with_loop(), true), + undirected_two_nodes_connected: (get_undirected_two_nodes_connected(), false), + directed_two_nodes_connected: (get_directed_two_nodes_connected(), true), + directed_two_nodes: (get_directed_two_nodes(), false), + undirected_triangle: (get_undirected_triangle(), true), + undirected_triangle_with_tail: (get_undirected_triangle_with_tail(), true), + directed_triangle: (get_directed_triangle(), true), + directed_triangle_with_tail: (get_directed_triangle_with_tail(), true), + undirected_graph_with_cycle: (get_undirected_graph_with_cycle(), true), + undirected_graph_without_cycle: (get_undirected_graph_without_cycle(), false), + directed_graph_with_cycle: (get_directed_graph_with_cycle(), true), + directed_graph_without_cycle: (get_directed_graph_without_cycle(), false), + } +} diff --git a/src/graph/dijkstra.rs b/src/graph/dijkstra.rs index 7b28a6bf365..8cef293abe7 100644 --- a/src/graph/dijkstra.rs +++ b/src/graph/dijkstra.rs @@ -1,47 +1,48 @@ -use std::cmp::Reverse; -use std::collections::{BTreeMap, BinaryHeap}; +use std::collections::{BTreeMap, BTreeSet}; use std::ops::Add; type Graph = BTreeMap>; // performs Dijsktra's algorithm on the given graph from the given start -// the graph is a positively-weighted undirected graph +// the graph is a positively-weighted directed graph // // returns a map that for each reachable vertex associates the distance and the predecessor // since the start has no predecessor but is reachable, map[start] will be None +// +// Time: O(E * logV). For each vertex, we traverse each edge, resulting in O(E). For each edge, we +// insert a new shortest path for a vertex into the tree, resulting in O(E * logV). +// Space: O(V). The tree holds up to V vertices. pub fn dijkstra>( graph: &Graph, - start: &V, + start: V, ) -> BTreeMap> { let mut ans = BTreeMap::new(); - let mut prio = BinaryHeap::new(); + let mut prio = BTreeSet::new(); // start is the special case that doesn't have a predecessor - ans.insert(*start, None); + ans.insert(start, None); - for (new, weight) in &graph[start] { - ans.insert(*new, Some((*start, *weight))); - prio.push(Reverse((*weight, new, start))); + for (new, weight) in &graph[&start] { + ans.insert(*new, Some((start, *weight))); + prio.insert((*weight, *new)); } - while let Some(Reverse((dist_new, new, prev))) = prio.pop() { - match ans[new] { - // what we popped is what is in ans, we'll compute it - Some((p, d)) if p == *prev && d == dist_new => {} - // otherwise it's not interesting - _ => continue, - } - - for (next, weight) in &graph[new] { + while let Some((path_weight, vertex)) = prio.pop_first() { + for (next, weight) in &graph[&vertex] { + let new_weight = path_weight + *weight; match ans.get(next) { // if ans[next] is a lower dist than the alternative one, we do nothing - Some(Some((_, dist_next))) if dist_new + *weight >= *dist_next => {} + Some(Some((_, dist_next))) if new_weight >= *dist_next => {} // if ans[next] is None then next is start and so the distance won't be changed, it won't be added again in prio Some(None) => {} // the new path is shorter, either new was not in ans or it was farther _ => { - ans.insert(*next, Some((*new, *weight + dist_new))); - prio.push(Reverse((*weight + dist_new, next, new))); + if let Some(Some((_, prev_weight))) = + ans.insert(*next, Some((vertex, new_weight))) + { + prio.remove(&(prev_weight, *next)); + } + prio.insert((new_weight, *next)); } } } @@ -56,8 +57,8 @@ mod tests { use std::collections::BTreeMap; fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { - graph.entry(v1).or_insert_with(BTreeMap::new).insert(v2, c); - graph.entry(v2).or_insert_with(BTreeMap::new); + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default(); } #[test] @@ -68,7 +69,7 @@ mod tests { let mut dists = BTreeMap::new(); dists.insert(0, None); - assert_eq!(dijkstra(&graph, &0), dists); + assert_eq!(dijkstra(&graph, 0), dists); } #[test] @@ -80,12 +81,12 @@ mod tests { dists_0.insert(0, None); dists_0.insert(1, Some((0, 2))); - assert_eq!(dijkstra(&graph, &0), dists_0); + assert_eq!(dijkstra(&graph, 0), dists_0); let mut dists_1 = BTreeMap::new(); dists_1.insert(1, None); - assert_eq!(dijkstra(&graph, &1), dists_1); + assert_eq!(dijkstra(&graph, 1), dists_1); } #[test] @@ -109,7 +110,7 @@ mod tests { } } - assert_eq!(dijkstra(&graph, &1), dists); + assert_eq!(dijkstra(&graph, 1), dists); } #[test] @@ -127,25 +128,25 @@ mod tests { dists_a.insert('c', Some(('a', 12))); dists_a.insert('d', Some(('c', 44))); dists_a.insert('b', Some(('c', 32))); - assert_eq!(dijkstra(&graph, &'a'), dists_a); + assert_eq!(dijkstra(&graph, 'a'), dists_a); let mut dists_b = BTreeMap::new(); dists_b.insert('b', None); dists_b.insert('a', Some(('b', 10))); dists_b.insert('c', Some(('a', 22))); dists_b.insert('d', Some(('c', 54))); - assert_eq!(dijkstra(&graph, &'b'), dists_b); + assert_eq!(dijkstra(&graph, 'b'), dists_b); let mut dists_c = BTreeMap::new(); dists_c.insert('c', None); dists_c.insert('b', Some(('c', 20))); dists_c.insert('d', Some(('c', 32))); dists_c.insert('a', Some(('b', 30))); - assert_eq!(dijkstra(&graph, &'c'), dists_c); + assert_eq!(dijkstra(&graph, 'c'), dists_c); let mut dists_d = BTreeMap::new(); dists_d.insert('d', None); - assert_eq!(dijkstra(&graph, &'d'), dists_d); + assert_eq!(dijkstra(&graph, 'd'), dists_d); let mut dists_e = BTreeMap::new(); dists_e.insert('e', None); @@ -153,6 +154,6 @@ mod tests { dists_e.insert('c', Some(('a', 19))); dists_e.insert('d', Some(('c', 51))); dists_e.insert('b', Some(('c', 39))); - assert_eq!(dijkstra(&graph, &'e'), dists_e); + assert_eq!(dijkstra(&graph, 'e'), dists_e); } } diff --git a/src/graph/dinic_maxflow.rs b/src/graph/dinic_maxflow.rs index fb0f3121893..87ff7a7953a 100644 --- a/src/graph/dinic_maxflow.rs +++ b/src/graph/dinic_maxflow.rs @@ -194,8 +194,8 @@ mod tests { let max_flow = flow.find_maxflow(i32::MAX); assert_eq!(max_flow, 23); - let mut sm_out = vec![0; 7]; - let mut sm_in = vec![0; 7]; + let mut sm_out = [0; 7]; + let mut sm_in = [0; 7]; let flow_edges = flow.get_flow_edges(i32::MAX); for e in flow_edges { diff --git a/src/graph/disjoint_set_union.rs b/src/graph/disjoint_set_union.rs index 5566de5e2f1..d20701c8c00 100644 --- a/src/graph/disjoint_set_union.rs +++ b/src/graph/disjoint_set_union.rs @@ -1,95 +1,148 @@ +//! This module implements the Disjoint Set Union (DSU), also known as Union-Find, +//! which is an efficient data structure for keeping track of a set of elements +//! partitioned into disjoint (non-overlapping) subsets. + +/// Represents a node in the Disjoint Set Union (DSU) structure which +/// keep track of the parent-child relationships in the disjoint sets. pub struct DSUNode { + /// The index of the node's parent, or itself if it's the root. parent: usize, + /// The size of the set rooted at this node, used for union by size. size: usize, } +/// Disjoint Set Union (Union-Find) data structure, particularly useful for +/// managing dynamic connectivity problems such as determining +/// if two elements are in the same subset or merging two subsets. pub struct DisjointSetUnion { + /// List of DSU nodes where each element's parent and size are tracked. nodes: Vec, } -// We are using both path compression and union by size impl DisjointSetUnion { - // Create n+1 sets [0, n] - pub fn new(n: usize) -> DisjointSetUnion { - let mut nodes = Vec::new(); - nodes.reserve_exact(n + 1); - for i in 0..=n { - nodes.push(DSUNode { parent: i, size: 1 }); + /// Initializes `n + 1` disjoint sets, each element is its own parent. + /// + /// # Parameters + /// + /// - `n`: The number of elements to manage (`0` to `n` inclusive). + /// + /// # Returns + /// + /// A new instance of `DisjointSetUnion` with `n + 1` independent sets. + pub fn new(num_elements: usize) -> DisjointSetUnion { + let mut nodes = Vec::with_capacity(num_elements + 1); + for idx in 0..=num_elements { + nodes.push(DSUNode { + parent: idx, + size: 1, + }); } - DisjointSetUnion { nodes } + + Self { nodes } } - pub fn find_set(&mut self, v: usize) -> usize { - if v == self.nodes[v].parent { - return v; + + /// Finds the representative (root) of the set containing `element` with path compression. + /// + /// Path compression ensures that future queries are faster by directly linking + /// all nodes in the path to the root. + /// + /// # Parameters + /// + /// - `element`: The element whose set representative is being found. + /// + /// # Returns + /// + /// The root representative of the set containing `element`. + pub fn find_set(&mut self, element: usize) -> usize { + if element != self.nodes[element].parent { + self.nodes[element].parent = self.find_set(self.nodes[element].parent); } - self.nodes[v].parent = self.find_set(self.nodes[v].parent); - self.nodes[v].parent + self.nodes[element].parent } - // Returns the new component of the merged sets, - // or std::usize::MAX if they were the same. - pub fn merge(&mut self, u: usize, v: usize) -> usize { - let mut a = self.find_set(u); - let mut b = self.find_set(v); - if a == b { - return std::usize::MAX; + + /// Merges the sets containing `first_elem` and `sec_elem` using union by size. + /// + /// The smaller set is always attached to the root of the larger set to ensure balanced trees. + /// + /// # Parameters + /// + /// - `first_elem`: The first element whose set is to be merged. + /// - `sec_elem`: The second element whose set is to be merged. + /// + /// # Returns + /// + /// The root of the merged set, or `usize::MAX` if both elements are already in the same set. + pub fn merge(&mut self, first_elem: usize, sec_elem: usize) -> usize { + let mut first_root = self.find_set(first_elem); + let mut sec_root = self.find_set(sec_elem); + + if first_root == sec_root { + // Already in the same set, no merge required + return usize::MAX; } - if self.nodes[a].size < self.nodes[b].size { - std::mem::swap(&mut a, &mut b); + + // Union by size: attach the smaller tree under the larger tree + if self.nodes[first_root].size < self.nodes[sec_root].size { + std::mem::swap(&mut first_root, &mut sec_root); } - self.nodes[b].parent = a; - self.nodes[a].size += self.nodes[b].size; - a + + self.nodes[sec_root].parent = first_root; + self.nodes[first_root].size += self.nodes[sec_root].size; + + first_root } } #[cfg(test)] mod tests { use super::*; + #[test] - fn create_acyclic_graph() { + fn test_disjoint_set_union() { let mut dsu = DisjointSetUnion::new(10); - // Add edges such that vertices 1..=9 are connected - // and vertex 10 is not connected to the other ones - let edges: Vec<(usize, usize)> = vec![ - (1, 2), // + - (2, 1), - (2, 3), // + - (1, 3), - (4, 5), // + - (7, 8), // + - (4, 8), // + - (3, 8), // + - (1, 9), // + - (2, 9), - (3, 9), - (4, 9), - (5, 9), - (6, 9), // + - (7, 9), - ]; - let expected_edges: Vec<(usize, usize)> = vec![ - (1, 2), - (2, 3), - (4, 5), - (7, 8), - (4, 8), - (3, 8), - (1, 9), - (6, 9), - ]; - let mut added_edges: Vec<(usize, usize)> = Vec::new(); - for (u, v) in edges { - if dsu.merge(u, v) < std::usize::MAX { - added_edges.push((u, v)); - } - // Now they should be the same - assert!(dsu.merge(u, v) == std::usize::MAX); - } - assert_eq!(added_edges, expected_edges); - let comp_1 = dsu.find_set(1); - for i in 2..=9 { - assert_eq!(comp_1, dsu.find_set(i)); - } - assert_ne!(comp_1, dsu.find_set(10)); + + dsu.merge(1, 2); + dsu.merge(2, 3); + dsu.merge(1, 9); + dsu.merge(4, 5); + dsu.merge(7, 8); + dsu.merge(4, 8); + dsu.merge(6, 9); + + assert_eq!(dsu.find_set(1), dsu.find_set(2)); + assert_eq!(dsu.find_set(1), dsu.find_set(3)); + assert_eq!(dsu.find_set(1), dsu.find_set(6)); + assert_eq!(dsu.find_set(1), dsu.find_set(9)); + + assert_eq!(dsu.find_set(4), dsu.find_set(5)); + assert_eq!(dsu.find_set(4), dsu.find_set(7)); + assert_eq!(dsu.find_set(4), dsu.find_set(8)); + + assert_ne!(dsu.find_set(1), dsu.find_set(10)); + assert_ne!(dsu.find_set(4), dsu.find_set(10)); + + dsu.merge(3, 4); + + assert_eq!(dsu.find_set(1), dsu.find_set(2)); + assert_eq!(dsu.find_set(1), dsu.find_set(3)); + assert_eq!(dsu.find_set(1), dsu.find_set(6)); + assert_eq!(dsu.find_set(1), dsu.find_set(9)); + assert_eq!(dsu.find_set(1), dsu.find_set(4)); + assert_eq!(dsu.find_set(1), dsu.find_set(5)); + assert_eq!(dsu.find_set(1), dsu.find_set(7)); + assert_eq!(dsu.find_set(1), dsu.find_set(8)); + + assert_ne!(dsu.find_set(1), dsu.find_set(10)); + + dsu.merge(10, 1); + assert_eq!(dsu.find_set(10), dsu.find_set(1)); + assert_eq!(dsu.find_set(10), dsu.find_set(2)); + assert_eq!(dsu.find_set(10), dsu.find_set(3)); + assert_eq!(dsu.find_set(10), dsu.find_set(4)); + assert_eq!(dsu.find_set(10), dsu.find_set(5)); + assert_eq!(dsu.find_set(10), dsu.find_set(6)); + assert_eq!(dsu.find_set(10), dsu.find_set(7)); + assert_eq!(dsu.find_set(10), dsu.find_set(8)); + assert_eq!(dsu.find_set(10), dsu.find_set(9)); } } diff --git a/src/graph/eulerian_path.rs b/src/graph/eulerian_path.rs new file mode 100644 index 00000000000..d37ee053f43 --- /dev/null +++ b/src/graph/eulerian_path.rs @@ -0,0 +1,380 @@ +//! This module provides functionality to find an Eulerian path in a directed graph. +//! An Eulerian path visits every edge exactly once. The algorithm checks if an Eulerian +//! path exists and, if so, constructs and returns the path. + +use std::collections::LinkedList; + +/// Finds an Eulerian path in a directed graph. +/// +/// # Arguments +/// +/// * `node_count` - The number of nodes in the graph. +/// * `edge_list` - A vector of tuples representing directed edges, where each tuple is of the form `(start, end)`. +/// +/// # Returns +/// +/// An `Option>` containing the Eulerian path if it exists; otherwise, `None`. +pub fn find_eulerian_path(node_count: usize, edge_list: Vec<(usize, usize)>) -> Option> { + let mut adjacency_list = vec![Vec::new(); node_count]; + for (start, end) in edge_list { + adjacency_list[start].push(end); + } + + let mut eulerian_solver = EulerianPathSolver::new(adjacency_list); + eulerian_solver.find_path() +} + +/// Struct to represent the solver for finding an Eulerian path in a directed graph. +pub struct EulerianPathSolver { + node_count: usize, + edge_count: usize, + in_degrees: Vec, + out_degrees: Vec, + eulerian_path: LinkedList, + adjacency_list: Vec>, +} + +impl EulerianPathSolver { + /// Creates a new instance of `EulerianPathSolver`. + /// + /// # Arguments + /// + /// * `adjacency_list` - The graph represented as an adjacency list. + /// + /// # Returns + /// + /// A new instance of `EulerianPathSolver`. + pub fn new(adjacency_list: Vec>) -> Self { + Self { + node_count: adjacency_list.len(), + edge_count: 0, + in_degrees: vec![0; adjacency_list.len()], + out_degrees: vec![0; adjacency_list.len()], + eulerian_path: LinkedList::new(), + adjacency_list, + } + } + + /// Find the Eulerian path if it exists. + /// + /// # Returns + /// + /// An `Option>` containing the Eulerian path if found; otherwise, `None`. + /// + /// If multiple Eulerian paths exist, the one found will be returned, but it may not be unique. + fn find_path(&mut self) -> Option> { + self.initialize_degrees(); + + if !self.has_eulerian_path() { + return None; + } + + let start_node = self.get_start_node(); + self.depth_first_search(start_node); + + if self.eulerian_path.len() != self.edge_count + 1 { + return None; + } + + let mut path = Vec::with_capacity(self.edge_count + 1); + while let Some(node) = self.eulerian_path.pop_front() { + path.push(node); + } + + Some(path) + } + + /// Initializes in-degrees and out-degrees for each node and counts total edges. + fn initialize_degrees(&mut self) { + for (start_node, neighbors) in self.adjacency_list.iter().enumerate() { + for &end_node in neighbors { + self.in_degrees[end_node] += 1; + self.out_degrees[start_node] += 1; + self.edge_count += 1; + } + } + } + + /// Checks if an Eulerian path exists in the graph. + /// + /// # Returns + /// + /// `true` if an Eulerian path exists; otherwise, `false`. + fn has_eulerian_path(&self) -> bool { + if self.edge_count == 0 { + return false; + } + + let (mut start_nodes, mut end_nodes) = (0, 0); + for i in 0..self.node_count { + let (in_degree, out_degree) = + (self.in_degrees[i] as isize, self.out_degrees[i] as isize); + match out_degree - in_degree { + 1 => start_nodes += 1, + -1 => end_nodes += 1, + degree_diff if degree_diff.abs() > 1 => return false, + _ => (), + } + } + + (start_nodes == 0 && end_nodes == 0) || (start_nodes == 1 && end_nodes == 1) + } + + /// Finds the starting node for the Eulerian path. + /// + /// # Returns + /// + /// The index of the starting node. + fn get_start_node(&self) -> usize { + for i in 0..self.node_count { + if self.out_degrees[i] > self.in_degrees[i] { + return i; + } + } + (0..self.node_count) + .find(|&i| self.out_degrees[i] > 0) + .unwrap_or(0) + } + + /// Performs depth-first search to construct the Eulerian path. + /// + /// # Arguments + /// + /// * `curr_node` - The current node being visited in the DFS traversal. + fn depth_first_search(&mut self, curr_node: usize) { + while self.out_degrees[curr_node] > 0 { + let next_node = self.adjacency_list[curr_node][self.out_degrees[curr_node] - 1]; + self.out_degrees[curr_node] -= 1; + self.depth_first_search(next_node); + } + self.eulerian_path.push_front(curr_node); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (n, edges, expected) = $test_case; + assert_eq!(find_eulerian_path(n, edges), expected); + } + )* + } + } + + test_cases! { + test_eulerian_cycle: ( + 7, + vec![ + (1, 2), + (1, 3), + (2, 2), + (2, 4), + (2, 4), + (3, 1), + (3, 2), + (3, 5), + (4, 3), + (4, 6), + (5, 6), + (6, 3) + ], + Some(vec![1, 3, 5, 6, 3, 2, 4, 3, 1, 2, 2, 4, 6]) + ), + test_simple_path: ( + 5, + vec![ + (0, 1), + (1, 2), + (1, 4), + (1, 3), + (2, 1), + (4, 1) + ], + Some(vec![0, 1, 4, 1, 2, 1, 3]) + ), + test_disconnected_graph: ( + 4, + vec![ + (0, 1), + (2, 3) + ], + None::> + ), + test_single_cycle: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 0) + ], + Some(vec![0, 1, 2, 3, 0]) + ), + test_empty_graph: ( + 3, + vec![], + None::> + ), + test_unbalanced_path: ( + 3, + vec![ + (0, 1), + (1, 2), + (2, 0), + (0, 2) + ], + Some(vec![0, 2, 0, 1, 2]) + ), + test_no_eulerian_path: ( + 3, + vec![ + (0, 1), + (0, 2) + ], + None::> + ), + test_complex_eulerian_path: ( + 6, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 0), + (0, 5), + (5, 0), + (2, 0) + ], + Some(vec![2, 0, 5, 0, 1, 2, 3, 4, 0]) + ), + test_single_node_self_loop: ( + 1, + vec![(0, 0)], + Some(vec![0, 0]) + ), + test_complete_graph: ( + 3, + vec![ + (0, 1), + (0, 2), + (1, 0), + (1, 2), + (2, 0), + (2, 1) + ], + Some(vec![0, 2, 1, 2, 0, 1, 0]) + ), + test_multiple_disconnected_components: ( + 6, + vec![ + (0, 1), + (2, 3), + (4, 5) + ], + None::> + ), + test_unbalanced_graph_with_path: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 1) + ], + Some(vec![0, 1, 2, 3, 1]) + ), + test_node_with_no_edges: ( + 4, + vec![ + (0, 1), + (1, 2) + ], + Some(vec![0, 1, 2]) + ), + test_multiple_edges_between_same_nodes: ( + 3, + vec![ + (0, 1), + (1, 2), + (1, 2), + (2, 0) + ], + Some(vec![1, 2, 0, 1, 2]) + ), + test_larger_graph_with_eulerian_path: ( + 10, + vec![ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 6), + (6, 7), + (7, 8), + (8, 9), + (9, 0), + (1, 6), + (6, 3), + (3, 8) + ], + Some(vec![1, 6, 3, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8]) + ), + test_no_edges_multiple_nodes: ( + 5, + vec![], + None::> + ), + test_multiple_start_and_end_nodes: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 0), + (0, 2), + (1, 3) + ], + None::> + ), + test_single_edge: ( + 2, + vec![(0, 1)], + Some(vec![0, 1]) + ), + test_multiple_eulerian_paths: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 0), + (0, 3), + (3, 0) + ], + Some(vec![0, 3, 0, 1, 2, 0]) + ), + test_dag_path: ( + 4, + vec![ + (0, 1), + (1, 2), + (2, 3) + ], + Some(vec![0, 1, 2, 3]) + ), + test_parallel_edges_case: ( + 2, + vec![ + (0, 1), + (0, 1), + (1, 0) + ], + Some(vec![0, 1, 0, 1]) + ), + } +} diff --git a/src/graph/floyd_warshall.rs b/src/graph/floyd_warshall.rs new file mode 100644 index 00000000000..0a78b992e5d --- /dev/null +++ b/src/graph/floyd_warshall.rs @@ -0,0 +1,191 @@ +use num_traits::Zero; +use std::collections::BTreeMap; +use std::ops::Add; + +type Graph = BTreeMap>; + +/// Performs the Floyd-Warshall algorithm on the input graph.\ +/// The graph is a weighted, directed graph with no negative cycles. +/// +/// Returns a map storing the distance from each node to all the others.\ +/// i.e. For each vertex `u`, `map[u][v] == Some(distance)` means +/// distance is the sum of the weights of the edges on the shortest path +/// from `u` to `v`. +/// +/// For a key `v`, if `map[v].len() == 0`, then `v` cannot reach any other vertex, but is in the graph +/// (island node, or sink in the case of a directed graph) +pub fn floyd_warshall + num_traits::Zero>( + graph: &Graph, +) -> BTreeMap> { + let mut map: BTreeMap> = BTreeMap::new(); + for (u, edges) in graph.iter() { + if !map.contains_key(u) { + map.insert(*u, BTreeMap::new()); + } + map.entry(*u).or_default().insert(*u, Zero::zero()); + for (v, weight) in edges.iter() { + if !map.contains_key(v) { + map.insert(*v, BTreeMap::new()); + } + map.entry(*v).or_default().insert(*v, Zero::zero()); + map.entry(*u).and_modify(|mp| { + mp.insert(*v, *weight); + }); + } + } + let keys = map.keys().copied().collect::>(); + for &k in &keys { + for &i in &keys { + if !map[&i].contains_key(&k) { + continue; + } + for &j in &keys { + if i == j { + continue; + } + if !map[&k].contains_key(&j) { + continue; + } + let entry_i_j = map[&i].get(&j); + let entry_i_k = map[&i][&k]; + let entry_k_j = map[&k][&j]; + match entry_i_j { + Some(&e) => { + if e > entry_i_k + entry_k_j { + map.entry(i).or_default().insert(j, entry_i_k + entry_k_j); + } + } + None => { + map.entry(i).or_default().insert(j, entry_i_k + entry_k_j); + } + }; + } + } + } + map +} + +#[cfg(test)] +mod tests { + use super::{floyd_warshall, Graph}; + use std::collections::BTreeMap; + + fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + graph.entry(v1).or_default().insert(v2, c); + } + + fn bi_add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { + add_edge(graph, v1, v2, c); + add_edge(graph, v2, v1, c); + } + + #[test] + fn single_vertex() { + let mut graph: Graph = BTreeMap::new(); + graph.insert(0, BTreeMap::new()); + + let mut dists = BTreeMap::new(); + dists.insert(0, BTreeMap::new()); + dists.get_mut(&0).unwrap().insert(0, 0); + assert_eq!(floyd_warshall(&graph), dists); + } + + #[test] + fn single_edge() { + let mut graph = BTreeMap::new(); + bi_add_edge(&mut graph, 0, 1, 2); + bi_add_edge(&mut graph, 1, 2, 3); + + let mut dists_0 = BTreeMap::new(); + dists_0.insert(0, BTreeMap::new()); + dists_0.insert(1, BTreeMap::new()); + dists_0.insert(2, BTreeMap::new()); + dists_0.get_mut(&0).unwrap().insert(0, 0); + dists_0.get_mut(&1).unwrap().insert(1, 0); + dists_0.get_mut(&2).unwrap().insert(2, 0); + dists_0.get_mut(&1).unwrap().insert(0, 2); + dists_0.get_mut(&0).unwrap().insert(1, 2); + dists_0.get_mut(&1).unwrap().insert(2, 3); + dists_0.get_mut(&2).unwrap().insert(1, 3); + dists_0.get_mut(&2).unwrap().insert(0, 5); + dists_0.get_mut(&0).unwrap().insert(2, 5); + + assert_eq!(floyd_warshall(&graph), dists_0); + } + + #[test] + fn graph_1() { + let mut graph = BTreeMap::new(); + add_edge(&mut graph, 'a', 'c', 12); + add_edge(&mut graph, 'a', 'd', 60); + add_edge(&mut graph, 'b', 'a', 10); + add_edge(&mut graph, 'c', 'b', 20); + add_edge(&mut graph, 'c', 'd', 32); + add_edge(&mut graph, 'e', 'a', 7); + + let mut dists_a = BTreeMap::new(); + dists_a.insert('d', BTreeMap::new()); + + dists_a.entry('a').or_insert(BTreeMap::new()).insert('a', 0); + dists_a.entry('b').or_insert(BTreeMap::new()).insert('b', 0); + dists_a.entry('c').or_insert(BTreeMap::new()).insert('c', 0); + dists_a.entry('d').or_insert(BTreeMap::new()).insert('d', 0); + dists_a.entry('e').or_insert(BTreeMap::new()).insert('e', 0); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('c', 12); + dists_a + .entry('c') + .or_insert(BTreeMap::new()) + .insert('a', 30); + dists_a + .entry('c') + .or_insert(BTreeMap::new()) + .insert('b', 20); + dists_a + .entry('c') + .or_insert(BTreeMap::new()) + .insert('d', 32); + dists_a.entry('e').or_insert(BTreeMap::new()).insert('a', 7); + dists_a + .entry('b') + .or_insert(BTreeMap::new()) + .insert('a', 10); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('d', 44); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('b', 32); + dists_a + .entry('a') + .or_insert(BTreeMap::new()) + .insert('b', 32); + dists_a + .entry('b') + .or_insert(BTreeMap::new()) + .insert('c', 22); + + dists_a + .entry('b') + .or_insert(BTreeMap::new()) + .insert('d', 54); + dists_a + .entry('e') + .or_insert(BTreeMap::new()) + .insert('c', 19); + dists_a + .entry('e') + .or_insert(BTreeMap::new()) + .insert('d', 51); + dists_a + .entry('e') + .or_insert(BTreeMap::new()) + .insert('b', 39); + + assert_eq!(floyd_warshall(&graph), dists_a); + } +} diff --git a/src/graph/ford_fulkerson.rs b/src/graph/ford_fulkerson.rs new file mode 100644 index 00000000000..c6a2f310ebe --- /dev/null +++ b/src/graph/ford_fulkerson.rs @@ -0,0 +1,314 @@ +//! The Ford-Fulkerson algorithm is a widely used algorithm to solve the maximum flow problem in a flow network. +//! +//! The maximum flow problem involves determining the maximum amount of flow that can be sent from a source vertex to a sink vertex +//! in a directed weighted graph, subject to capacity constraints on the edges. + +use std::collections::VecDeque; + +/// Enum representing the possible errors that can occur when running the Ford-Fulkerson algorithm. +#[derive(Debug, PartialEq)] +pub enum FordFulkersonError { + EmptyGraph, + ImproperGraph, + SourceOutOfBounds, + SinkOutOfBounds, +} + +/// Performs a Breadth-First Search (BFS) on the residual graph to find an augmenting path +/// from the source vertex `source` to the sink vertex `sink`. +/// +/// # Arguments +/// +/// * `graph` - A reference to the residual graph represented as an adjacency matrix. +/// * `source` - The source vertex. +/// * `sink` - The sink vertex. +/// * `parent` - A mutable reference to the parent array used to store the augmenting path. +/// +/// # Returns +/// +/// Returns `true` if an augmenting path is found from `source` to `sink`, `false` otherwise. +fn bfs(graph: &[Vec], source: usize, sink: usize, parent: &mut [usize]) -> bool { + let mut visited = vec![false; graph.len()]; + visited[source] = true; + parent[source] = usize::MAX; + + let mut queue = VecDeque::new(); + queue.push_back(source); + + while let Some(current_vertex) = queue.pop_front() { + for (previous_vertex, &capacity) in graph[current_vertex].iter().enumerate() { + if !visited[previous_vertex] && capacity > 0 { + visited[previous_vertex] = true; + parent[previous_vertex] = current_vertex; + if previous_vertex == sink { + return true; + } + queue.push_back(previous_vertex); + } + } + } + + false +} + +/// Validates the input parameters for the Ford-Fulkerson algorithm. +/// +/// This function checks if the provided graph, source vertex, and sink vertex +/// meet the requirements for the Ford-Fulkerson algorithm. It ensures the graph +/// is non-empty, square (each row has the same length as the number of rows), and +/// that the source and sink vertices are within the valid range of vertex indices. +/// +/// # Arguments +/// +/// * `graph` - A reference to the flow network represented as an adjacency matrix. +/// * `source` - The source vertex. +/// * `sink` - The sink vertex. +/// +/// # Returns +/// +/// Returns `Ok(())` if the input parameters are valid, otherwise returns an appropriate +/// `FordFulkersonError`. +fn validate_ford_fulkerson_input( + graph: &[Vec], + source: usize, + sink: usize, +) -> Result<(), FordFulkersonError> { + if graph.is_empty() { + return Err(FordFulkersonError::EmptyGraph); + } + + if graph.iter().any(|row| row.len() != graph.len()) { + return Err(FordFulkersonError::ImproperGraph); + } + + if source >= graph.len() { + return Err(FordFulkersonError::SourceOutOfBounds); + } + + if sink >= graph.len() { + return Err(FordFulkersonError::SinkOutOfBounds); + } + + Ok(()) +} + +/// Applies the Ford-Fulkerson algorithm to find the maximum flow in a flow network +/// represented by a weighted directed graph. +/// +/// # Arguments +/// +/// * `graph` - A mutable reference to the flow network represented as an adjacency matrix. +/// * `source` - The source vertex. +/// * `sink` - The sink vertex. +/// +/// # Returns +/// +/// Returns the maximum flow and the residual graph +pub fn ford_fulkerson( + graph: &[Vec], + source: usize, + sink: usize, +) -> Result { + validate_ford_fulkerson_input(graph, source, sink)?; + + let mut residual_graph = graph.to_owned(); + let mut parent = vec![usize::MAX; graph.len()]; + let mut max_flow = 0; + + while bfs(&residual_graph, source, sink, &mut parent) { + let mut path_flow = usize::MAX; + let mut previous_vertex = sink; + + while previous_vertex != source { + let current_vertex = parent[previous_vertex]; + path_flow = path_flow.min(residual_graph[current_vertex][previous_vertex]); + previous_vertex = current_vertex; + } + + previous_vertex = sink; + while previous_vertex != source { + let current_vertex = parent[previous_vertex]; + residual_graph[current_vertex][previous_vertex] -= path_flow; + residual_graph[previous_vertex][current_vertex] += path_flow; + previous_vertex = current_vertex; + } + + max_flow += path_flow; + } + + Ok(max_flow) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_max_flow { + ($($name:ident: $tc:expr,)* ) => { + $( + #[test] + fn $name() { + let (graph, source, sink, expected_result) = $tc; + assert_eq!(ford_fulkerson(&graph, source, sink), expected_result); + } + )* + }; + } + + test_max_flow! { + test_empty_graph: ( + vec![], + 0, + 0, + Err(FordFulkersonError::EmptyGraph), + ), + test_source_out_of_bound: ( + vec![ + vec![0, 8, 0, 0, 3, 0], + vec![0, 0, 9, 0, 0, 0], + vec![0, 0, 0, 0, 7, 2], + vec![0, 0, 0, 0, 0, 5], + vec![0, 0, 7, 4, 0, 0], + vec![0, 0, 0, 0, 0, 0], + ], + 6, + 5, + Err(FordFulkersonError::SourceOutOfBounds), + ), + test_sink_out_of_bound: ( + vec![ + vec![0, 8, 0, 0, 3, 0], + vec![0, 0, 9, 0, 0, 0], + vec![0, 0, 0, 0, 7, 2], + vec![0, 0, 0, 0, 0, 5], + vec![0, 0, 7, 4, 0, 0], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 6, + Err(FordFulkersonError::SinkOutOfBounds), + ), + test_improper_graph: ( + vec![ + vec![0, 8], + vec![0], + ], + 0, + 1, + Err(FordFulkersonError::ImproperGraph), + ), + test_graph_with_small_flow: ( + vec![ + vec![0, 8, 0, 0, 3, 0], + vec![0, 0, 9, 0, 0, 0], + vec![0, 0, 0, 0, 7, 2], + vec![0, 0, 0, 0, 0, 5], + vec![0, 0, 7, 4, 0, 0], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(6), + ), + test_graph_with_medium_flow: ( + vec![ + vec![0, 10, 0, 10, 0, 0], + vec![0, 0, 4, 2, 8, 0], + vec![0, 0, 0, 0, 0, 10], + vec![0, 0, 0, 0, 9, 0], + vec![0, 0, 6, 0, 0, 10], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(19), + ), + test_graph_with_large_flow: ( + vec![ + vec![0, 12, 0, 13, 0, 0], + vec![0, 0, 10, 0, 0, 0], + vec![0, 0, 0, 13, 3, 15], + vec![0, 0, 7, 0, 15, 0], + vec![0, 0, 6, 0, 0, 17], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(23), + ), + test_complex_graph: ( + vec![ + vec![0, 16, 13, 0, 0, 0], + vec![0, 0, 10, 12, 0, 0], + vec![0, 4, 0, 0, 14, 0], + vec![0, 0, 9, 0, 0, 20], + vec![0, 0, 0, 7, 0, 4], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(23), + ), + test_disconnected_graph: ( + vec![ + vec![0, 0, 0, 0], + vec![0, 0, 0, 1], + vec![0, 0, 0, 1], + vec![0, 0, 0, 0], + ], + 0, + 3, + Ok(0), + ), + test_unconnected_sink: ( + vec![ + vec![0, 4, 0, 3, 0, 0], + vec![0, 0, 4, 0, 8, 0], + vec![0, 0, 0, 3, 0, 2], + vec![0, 0, 0, 0, 6, 0], + vec![0, 0, 6, 0, 0, 6], + vec![0, 0, 0, 0, 0, 0], + ], + 0, + 5, + Ok(7), + ), + test_no_edges: ( + vec![ + vec![0, 0, 0], + vec![0, 0, 0], + vec![0, 0, 0], + ], + 0, + 2, + Ok(0), + ), + test_single_vertex: ( + vec![ + vec![0], + ], + 0, + 0, + Ok(0), + ), + test_self_loop: ( + vec![ + vec![10, 0], + vec![0, 0], + ], + 0, + 1, + Ok(0), + ), + test_same_source_sink: ( + vec![ + vec![0, 10, 10], + vec![0, 0, 10], + vec![0, 0, 0], + ], + 0, + 0, + Ok(0), + ), + } +} diff --git a/src/graph/graph_enumeration.rs b/src/graph/graph_enumeration.rs index 0218dee4ffb..24326c84aa7 100644 --- a/src/graph/graph_enumeration.rs +++ b/src/graph/graph_enumeration.rs @@ -27,14 +27,8 @@ pub fn enumerate_graph(adj: &Graph) -> Vec> { mod tests { use super::*; fn add_edge(graph: &mut Graph, a: V, b: V) { - graph - .entry(a.clone()) - .or_insert_with(Vec::new) - .push(b.clone()); - graph - .entry(b.clone()) - .or_insert_with(Vec::new) - .push(a.clone()); + graph.entry(a.clone()).or_default().push(b.clone()); + graph.entry(b).or_default().push(a); } #[test] diff --git a/src/graph/heavy_light_decomposition.rs b/src/graph/heavy_light_decomposition.rs index b2767d986ed..e96c8152a54 100644 --- a/src/graph/heavy_light_decomposition.rs +++ b/src/graph/heavy_light_decomposition.rs @@ -166,6 +166,7 @@ mod tests { let mut lcg = LinearCongruenceGenerator::new(1103515245, 12345, 314); parent[2] = 1; adj[1].push(2); + #[allow(clippy::needless_range_loop)] for i in 3..=n { // randomly determine the parent of each vertex. // There will be modulus bias, but it isn't important diff --git a/src/graph/kosaraju.rs b/src/graph/kosaraju.rs new file mode 100644 index 00000000000..842ddd5ffc1 --- /dev/null +++ b/src/graph/kosaraju.rs @@ -0,0 +1,157 @@ +// Kosaraju algorithm, a linear-time algorithm to find the strongly connected components (SCCs) of a directed graph, in Rust. +pub struct Graph { + vertices: usize, + adj_list: Vec>, + transpose_adj_list: Vec>, +} + +impl Graph { + pub fn new(vertices: usize) -> Self { + Graph { + vertices, + adj_list: vec![vec![]; vertices], + transpose_adj_list: vec![vec![]; vertices], + } + } + + pub fn add_edge(&mut self, u: usize, v: usize) { + self.adj_list[u].push(v); + self.transpose_adj_list[v].push(u); + } + + pub fn dfs(&self, node: usize, visited: &mut Vec, stack: &mut Vec) { + visited[node] = true; + for &neighbor in &self.adj_list[node] { + if !visited[neighbor] { + self.dfs(neighbor, visited, stack); + } + } + stack.push(node); + } + + pub fn dfs_scc(&self, node: usize, visited: &mut Vec, scc: &mut Vec) { + visited[node] = true; + scc.push(node); + for &neighbor in &self.transpose_adj_list[node] { + if !visited[neighbor] { + self.dfs_scc(neighbor, visited, scc); + } + } + } +} + +pub fn kosaraju(graph: &Graph) -> Vec> { + let mut visited = vec![false; graph.vertices]; + let mut stack = Vec::new(); + + for i in 0..graph.vertices { + if !visited[i] { + graph.dfs(i, &mut visited, &mut stack); + } + } + + let mut sccs = Vec::new(); + visited = vec![false; graph.vertices]; + + while let Some(node) = stack.pop() { + if !visited[node] { + let mut scc = Vec::new(); + graph.dfs_scc(node, &mut visited, &mut scc); + sccs.push(scc); + } + } + + sccs +} +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kosaraju_single_sccs() { + let vertices = 5; + let mut graph = Graph::new(vertices); + + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 3); + graph.add_edge(2, 4); + graph.add_edge(3, 0); + graph.add_edge(4, 2); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 1); + assert!(sccs.contains(&vec![0, 3, 2, 1, 4])); + } + + #[test] + fn test_kosaraju_multiple_sccs() { + let vertices = 8; + let mut graph = Graph::new(vertices); + + graph.add_edge(1, 0); + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 0); + graph.add_edge(2, 3); + graph.add_edge(3, 4); + graph.add_edge(4, 5); + graph.add_edge(5, 6); + graph.add_edge(6, 7); + graph.add_edge(4, 7); + graph.add_edge(6, 4); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 4); + assert!(sccs.contains(&vec![0, 1, 2])); + assert!(sccs.contains(&vec![3])); + assert!(sccs.contains(&vec![4, 6, 5])); + assert!(sccs.contains(&vec![7])); + } + + #[test] + fn test_kosaraju_multiple_sccs1() { + let vertices = 8; + let mut graph = Graph::new(vertices); + graph.add_edge(0, 2); + graph.add_edge(1, 0); + graph.add_edge(2, 3); + graph.add_edge(3, 4); + graph.add_edge(4, 7); + graph.add_edge(5, 2); + graph.add_edge(5, 6); + graph.add_edge(6, 5); + graph.add_edge(7, 6); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 3); + assert!(sccs.contains(&vec![0])); + assert!(sccs.contains(&vec![1])); + assert!(sccs.contains(&vec![2, 5, 6, 7, 4, 3])); + } + + #[test] + fn test_kosaraju_no_scc() { + let vertices = 4; + let mut graph = Graph::new(vertices); + + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 3); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 4); + for (i, _) in sccs.iter().enumerate().take(vertices) { + assert_eq!(sccs[i], vec![i]); + } + } + + #[test] + fn test_kosaraju_empty_graph() { + let vertices = 0; + let graph = Graph::new(vertices); + + let sccs = kosaraju(&graph); + assert_eq!(sccs.len(), 0); + } +} diff --git a/src/graph/lee_breadth_first_search.rs b/src/graph/lee_breadth_first_search.rs new file mode 100644 index 00000000000..e0c25fd7906 --- /dev/null +++ b/src/graph/lee_breadth_first_search.rs @@ -0,0 +1,117 @@ +use std::collections::VecDeque; + +// All four potential movements from a cell are listed here. + +fn validate(matrix: &[Vec], visited: &[Vec], row: isize, col: isize) -> bool { + // Check if it is possible to move to the position (row, col) from the current cell. + let (row, col) = (row as usize, col as usize); + row < matrix.len() && col < matrix[0].len() && matrix[row][col] == 1 && !visited[row][col] +} + +pub fn lee(matrix: Vec>, source: (usize, usize), destination: (usize, usize)) -> isize { + const ROW: [isize; 4] = [-1, 0, 0, 1]; + const COL: [isize; 4] = [0, -1, 1, 0]; + let (i, j) = source; + let (x, y) = destination; + + // Base case: invalid input + if matrix.is_empty() || matrix[i][j] == 0 || matrix[x][y] == 0 { + return -1; + } + + let (m, n) = (matrix.len(), matrix[0].len()); + let mut visited = vec![vec![false; n]; m]; + let mut q = VecDeque::new(); + visited[i][j] = true; + q.push_back((i, j, 0)); + let mut min_dist = isize::MAX; + + // Loop until the queue is empty + while let Some((i, j, dist)) = q.pop_front() { + if i == x && j == y { + // If the destination is found, update `min_dist` and stop + min_dist = dist; + break; + } + + // Check for all four possible movements from the current cell + for k in 0..ROW.len() { + let row = i as isize + ROW[k]; + let col = j as isize + COL[k]; + if validate(&matrix, &visited, row, col) { + // Mark the next cell as visited and enqueue it + let (row, col) = (row as usize, col as usize); + visited[row][col] = true; + q.push_back((row, col, dist + 1)); + } + } + } + + if min_dist != isize::MAX { + min_dist + } else { + -1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_lee_exists() { + let mat: Vec> = vec![ + vec![1, 0, 1, 1, 1], + vec![1, 0, 1, 0, 1], + vec![1, 1, 1, 0, 1], + vec![0, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + ]; + let source = (0, 0); + let dest = (2, 1); + assert_eq!(lee(mat, source, dest), 3); + } + + #[test] + fn test_lee_does_not_exist() { + let mat: Vec> = vec![ + vec![1, 0, 1, 1, 1], + vec![1, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + vec![0, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + ]; + let source = (0, 0); + let dest = (3, 4); + assert_eq!(lee(mat, source, dest), -1); + } + + #[test] + fn test_source_equals_destination() { + let mat: Vec> = vec![ + vec![1, 0, 1, 1, 1], + vec![1, 0, 1, 0, 1], + vec![1, 1, 1, 0, 1], + vec![0, 0, 0, 0, 1], + vec![1, 1, 1, 0, 1], + ]; + let source = (2, 1); + let dest = (2, 1); + assert_eq!(lee(mat, source, dest), 0); + } + + #[test] + fn test_lee_exists_2() { + let mat: Vec> = vec![ + vec![1, 1, 1, 1, 1, 0, 0], + vec![1, 1, 1, 1, 1, 1, 0], + vec![1, 0, 1, 0, 1, 1, 1], + vec![1, 1, 1, 1, 1, 0, 1], + vec![0, 0, 0, 1, 0, 0, 0], + vec![1, 0, 1, 1, 1, 0, 0], + vec![0, 0, 0, 0, 1, 0, 0], + ]; + let source = (0, 0); + let dest = (3, 2); + assert_eq!(lee(mat, source, dest), 5); + } +} diff --git a/src/graph/minimum_spanning_tree.rs b/src/graph/minimum_spanning_tree.rs index d6c2e4ddf2d..9d36cafb303 100644 --- a/src/graph/minimum_spanning_tree.rs +++ b/src/graph/minimum_spanning_tree.rs @@ -1,24 +1,22 @@ -use super::DisjointSetUnion; +//! This module implements Kruskal's algorithm to find the Minimum Spanning Tree (MST) +//! of an undirected, weighted graph using a Disjoint Set Union (DSU) for cycle detection. -#[derive(Debug)] -pub struct Edge { - source: i64, - destination: i64, - cost: i64, -} +use crate::graph::DisjointSetUnion; -impl PartialEq for Edge { - fn eq(&self, other: &Self) -> bool { - self.source == other.source - && self.destination == other.destination - && self.cost == other.cost - } +/// Represents an edge in the graph with a source, destination, and associated cost. +#[derive(Debug, PartialEq, Eq)] +pub struct Edge { + /// The starting vertex of the edge. + source: usize, + /// The ending vertex of the edge. + destination: usize, + /// The cost associated with the edge. + cost: usize, } -impl Eq for Edge {} - impl Edge { - fn new(source: i64, destination: i64, cost: i64) -> Self { + /// Creates a new edge with the specified source, destination, and cost. + pub fn new(source: usize, destination: usize, cost: usize) -> Self { Self { source, destination, @@ -27,108 +25,135 @@ impl Edge { } } -pub fn kruskal(mut edges: Vec, number_of_vertices: i64) -> (i64, Vec) { - let mut dsu = DisjointSetUnion::new(number_of_vertices as usize); - - edges.sort_unstable_by(|a, b| a.cost.cmp(&b.cost)); - let mut total_cost: i64 = 0; - let mut final_edges: Vec = Vec::new(); - let mut merge_count: i64 = 0; - for edge in edges.iter() { - if merge_count >= number_of_vertices - 1 { +/// Executes Kruskal's algorithm to compute the Minimum Spanning Tree (MST) of a graph. +/// +/// # Parameters +/// +/// - `edges`: A vector of `Edge` instances representing all edges in the graph. +/// - `num_vertices`: The total number of vertices in the graph. +/// +/// # Returns +/// +/// An `Option` containing a tuple with: +/// +/// - The total cost of the MST (usize). +/// - A vector of edges that are included in the MST. +/// +/// Returns `None` if the graph is disconnected. +/// +/// # Complexity +/// +/// The time complexity is O(E log E), where E is the number of edges. +pub fn kruskal(mut edges: Vec, num_vertices: usize) -> Option<(usize, Vec)> { + let mut dsu = DisjointSetUnion::new(num_vertices); + let mut mst_cost: usize = 0; + let mut mst_edges: Vec = Vec::with_capacity(num_vertices - 1); + + // Sort edges by cost in ascending order + edges.sort_unstable_by_key(|edge| edge.cost); + + for edge in edges { + if mst_edges.len() == num_vertices - 1 { break; } - let source: i64 = edge.source; - let destination: i64 = edge.destination; - if dsu.merge(source as usize, destination as usize) < std::usize::MAX { - merge_count += 1; - let cost: i64 = edge.cost; - total_cost += cost; - let final_edge: Edge = Edge::new(source, destination, cost); - final_edges.push(final_edge); + // Attempt to merge the sets containing the edge’s vertices + if dsu.merge(edge.source, edge.destination) != usize::MAX { + mst_cost += edge.cost; + mst_edges.push(edge); } } - (total_cost, final_edges) + + // Return MST if it includes exactly num_vertices - 1 edges, otherwise None for disconnected graphs + (mst_edges.len() == num_vertices - 1).then_some((mst_cost, mst_edges)) } #[cfg(test)] mod tests { use super::*; - #[test] - fn test_seven_vertices_eleven_edges() { - let mut edges: Vec = Vec::new(); - edges.push(Edge::new(0, 1, 7)); - edges.push(Edge::new(0, 3, 5)); - edges.push(Edge::new(1, 2, 8)); - edges.push(Edge::new(1, 3, 9)); - edges.push(Edge::new(1, 4, 7)); - edges.push(Edge::new(2, 4, 5)); - edges.push(Edge::new(3, 4, 15)); - edges.push(Edge::new(3, 5, 6)); - edges.push(Edge::new(4, 5, 8)); - edges.push(Edge::new(4, 6, 9)); - edges.push(Edge::new(5, 6, 11)); - - let number_of_vertices: i64 = 7; - - let expected_total_cost = 39; - let mut expected_used_edges: Vec = Vec::new(); - expected_used_edges.push(Edge::new(0, 3, 5)); - expected_used_edges.push(Edge::new(2, 4, 5)); - expected_used_edges.push(Edge::new(3, 5, 6)); - expected_used_edges.push(Edge::new(0, 1, 7)); - expected_used_edges.push(Edge::new(1, 4, 7)); - expected_used_edges.push(Edge::new(4, 6, 9)); - - let (actual_total_cost, actual_final_edges) = kruskal(edges, number_of_vertices); - - assert_eq!(actual_total_cost, expected_total_cost); - assert_eq!(actual_final_edges, expected_used_edges); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (edges, num_vertices, expected_result) = $test_case; + let actual_result = kruskal(edges, num_vertices); + assert_eq!(actual_result, expected_result); + } + )* + }; } - #[test] - fn test_ten_vertices_twenty_edges() { - let mut edges: Vec = Vec::new(); - edges.push(Edge::new(0, 1, 3)); - edges.push(Edge::new(0, 3, 6)); - edges.push(Edge::new(0, 4, 9)); - edges.push(Edge::new(1, 2, 2)); - edges.push(Edge::new(1, 3, 4)); - edges.push(Edge::new(1, 4, 9)); - edges.push(Edge::new(2, 3, 2)); - edges.push(Edge::new(2, 5, 8)); - edges.push(Edge::new(2, 6, 9)); - edges.push(Edge::new(3, 6, 9)); - edges.push(Edge::new(4, 5, 8)); - edges.push(Edge::new(4, 9, 18)); - edges.push(Edge::new(5, 6, 7)); - edges.push(Edge::new(5, 8, 9)); - edges.push(Edge::new(5, 9, 10)); - edges.push(Edge::new(6, 7, 4)); - edges.push(Edge::new(6, 8, 5)); - edges.push(Edge::new(7, 8, 1)); - edges.push(Edge::new(7, 9, 4)); - edges.push(Edge::new(8, 9, 3)); - - let number_of_vertices: i64 = 10; - - let expected_total_cost = 38; - let mut expected_used_edges = Vec::new(); - expected_used_edges.push(Edge::new(7, 8, 1)); - expected_used_edges.push(Edge::new(1, 2, 2)); - expected_used_edges.push(Edge::new(2, 3, 2)); - expected_used_edges.push(Edge::new(0, 1, 3)); - expected_used_edges.push(Edge::new(8, 9, 3)); - expected_used_edges.push(Edge::new(6, 7, 4)); - expected_used_edges.push(Edge::new(5, 6, 7)); - expected_used_edges.push(Edge::new(2, 5, 8)); - expected_used_edges.push(Edge::new(4, 5, 8)); - - let (actual_total_cost, actual_final_edges) = kruskal(edges, number_of_vertices); - - assert_eq!(actual_total_cost, expected_total_cost); - assert_eq!(actual_final_edges, expected_used_edges); + test_cases! { + test_seven_vertices_eleven_edges: ( + vec![ + Edge::new(0, 1, 7), + Edge::new(0, 3, 5), + Edge::new(1, 2, 8), + Edge::new(1, 3, 9), + Edge::new(1, 4, 7), + Edge::new(2, 4, 5), + Edge::new(3, 4, 15), + Edge::new(3, 5, 6), + Edge::new(4, 5, 8), + Edge::new(4, 6, 9), + Edge::new(5, 6, 11), + ], + 7, + Some((39, vec![ + Edge::new(0, 3, 5), + Edge::new(2, 4, 5), + Edge::new(3, 5, 6), + Edge::new(0, 1, 7), + Edge::new(1, 4, 7), + Edge::new(4, 6, 9), + ])) + ), + test_ten_vertices_twenty_edges: ( + vec![ + Edge::new(0, 1, 3), + Edge::new(0, 3, 6), + Edge::new(0, 4, 9), + Edge::new(1, 2, 2), + Edge::new(1, 3, 4), + Edge::new(1, 4, 9), + Edge::new(2, 3, 2), + Edge::new(2, 5, 8), + Edge::new(2, 6, 9), + Edge::new(3, 6, 9), + Edge::new(4, 5, 8), + Edge::new(4, 9, 18), + Edge::new(5, 6, 7), + Edge::new(5, 8, 9), + Edge::new(5, 9, 10), + Edge::new(6, 7, 4), + Edge::new(6, 8, 5), + Edge::new(7, 8, 1), + Edge::new(7, 9, 4), + Edge::new(8, 9, 3), + ], + 10, + Some((38, vec![ + Edge::new(7, 8, 1), + Edge::new(1, 2, 2), + Edge::new(2, 3, 2), + Edge::new(0, 1, 3), + Edge::new(8, 9, 3), + Edge::new(6, 7, 4), + Edge::new(5, 6, 7), + Edge::new(2, 5, 8), + Edge::new(4, 5, 8), + ])) + ), + test_disconnected_graph: ( + vec![ + Edge::new(0, 1, 4), + Edge::new(0, 2, 6), + Edge::new(3, 4, 2), + ], + 5, + None + ), } } diff --git a/src/graph/mod.rs b/src/graph/mod.rs index 62f7ba52dac..d4b0b0d00cb 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -1,35 +1,55 @@ +mod astar; mod bellman_ford; +mod bipartite_matching; mod breadth_first_search; mod centroid_decomposition; +mod decremental_connectivity; mod depth_first_search; mod depth_first_search_tic_tac_toe; +mod detect_cycle; mod dijkstra; mod dinic_maxflow; mod disjoint_set_union; +mod eulerian_path; +mod floyd_warshall; +mod ford_fulkerson; mod graph_enumeration; mod heavy_light_decomposition; +mod kosaraju; +mod lee_breadth_first_search; mod lowest_common_ancestor; mod minimum_spanning_tree; mod prim; mod prufer_code; mod strongly_connected_components; +mod tarjans_ssc; mod topological_sort; mod two_satisfiability; +pub use self::astar::astar; pub use self::bellman_ford::bellman_ford; +pub use self::bipartite_matching::BipartiteMatching; pub use self::breadth_first_search::breadth_first_search; pub use self::centroid_decomposition::CentroidDecomposition; +pub use self::decremental_connectivity::DecrementalConnectivity; pub use self::depth_first_search::depth_first_search; pub use self::depth_first_search_tic_tac_toe::minimax; +pub use self::detect_cycle::DetectCycle; pub use self::dijkstra::dijkstra; pub use self::dinic_maxflow::DinicMaxFlow; pub use self::disjoint_set_union::DisjointSetUnion; +pub use self::eulerian_path::find_eulerian_path; +pub use self::floyd_warshall::floyd_warshall; +pub use self::ford_fulkerson::ford_fulkerson; pub use self::graph_enumeration::enumerate_graph; pub use self::heavy_light_decomposition::HeavyLightDecomposition; +pub use self::kosaraju::kosaraju; +pub use self::lee_breadth_first_search::lee; pub use self::lowest_common_ancestor::{LowestCommonAncestorOffline, LowestCommonAncestorOnline}; pub use self::minimum_spanning_tree::kruskal; pub use self::prim::{prim, prim_with_start}; pub use self::prufer_code::{prufer_decode, prufer_encode}; pub use self::strongly_connected_components::StronglyConnectedComponents; +pub use self::tarjans_ssc::tarjan_scc; pub use self::topological_sort::topological_sort; pub use self::two_satisfiability::solve_two_satisfiability; diff --git a/src/graph/prim.rs b/src/graph/prim.rs index 16497cb0920..2fac8883572 100644 --- a/src/graph/prim.rs +++ b/src/graph/prim.rs @@ -5,8 +5,8 @@ use std::ops::Add; type Graph = BTreeMap>; fn add_edge(graph: &mut Graph, v1: V, v2: V, c: E) { - graph.entry(v1).or_insert_with(BTreeMap::new).insert(v2, c); - graph.entry(v2).or_insert_with(BTreeMap::new).insert(v1, c); + graph.entry(v1).or_default().insert(v2, c); + graph.entry(v2).or_default().insert(v1, c); } // selects a start and run the algorithm from it diff --git a/src/graph/prufer_code.rs b/src/graph/prufer_code.rs index 5478fbaf441..0c965b8cb50 100644 --- a/src/graph/prufer_code.rs +++ b/src/graph/prufer_code.rs @@ -6,8 +6,7 @@ pub fn prufer_encode(tree: &Graph) -> Vec { if tree.len() <= 2 { return vec![]; } - let mut result: Vec = Vec::new(); - result.reserve(tree.len() - 2); + let mut result: Vec = Vec::with_capacity(tree.len() - 2); let mut queue = BinaryHeap::new(); let mut in_tree = BTreeSet::new(); let mut degree = BTreeMap::new(); @@ -33,7 +32,7 @@ pub fn prufer_encode(tree: &Graph) -> Vec { #[inline] fn add_directed_edge(tree: &mut Graph, a: V, b: V) { - tree.entry(a).or_insert(vec![]).push(b); + tree.entry(a).or_default().push(b); } #[inline] @@ -83,7 +82,7 @@ mod tests { for adj in g2.values_mut() { adj.sort(); } - return g1 == g2; + g1 == g2 } #[test] diff --git a/src/graph/tarjans_ssc.rs b/src/graph/tarjans_ssc.rs new file mode 100644 index 00000000000..1f8e258614a --- /dev/null +++ b/src/graph/tarjans_ssc.rs @@ -0,0 +1,173 @@ +pub struct Graph { + n: usize, + adj_list: Vec>, +} + +impl Graph { + pub fn new(n: usize) -> Self { + Self { + n, + adj_list: vec![vec![]; n], + } + } + + pub fn add_edge(&mut self, u: usize, v: usize) { + self.adj_list[u].push(v); + } +} +pub fn tarjan_scc(graph: &Graph) -> Vec> { + struct TarjanState { + index: i32, + stack: Vec, + on_stack: Vec, + index_of: Vec, + lowlink_of: Vec, + components: Vec>, + } + + let mut state = TarjanState { + index: 0, + stack: Vec::new(), + on_stack: vec![false; graph.n], + index_of: vec![-1; graph.n], + lowlink_of: vec![-1; graph.n], + components: Vec::new(), + }; + + fn strong_connect(v: usize, graph: &Graph, state: &mut TarjanState) { + state.index_of[v] = state.index; + state.lowlink_of[v] = state.index; + state.index += 1; + state.stack.push(v); + state.on_stack[v] = true; + + for &w in &graph.adj_list[v] { + if state.index_of[w] == -1 { + strong_connect(w, graph, state); + state.lowlink_of[v] = state.lowlink_of[v].min(state.lowlink_of[w]); + } else if state.on_stack[w] { + state.lowlink_of[v] = state.lowlink_of[v].min(state.index_of[w]); + } + } + + if state.lowlink_of[v] == state.index_of[v] { + let mut component: Vec = Vec::new(); + while let Some(w) = state.stack.pop() { + state.on_stack[w] = false; + component.push(w); + if w == v { + break; + } + } + state.components.push(component); + } + } + + for v in 0..graph.n { + if state.index_of[v] == -1 { + strong_connect(v, graph, &mut state); + } + } + + state.components +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tarjan_scc() { + // Test 1: A graph with multiple strongly connected components + let n_vertices = 11; + let edges = vec![ + (0, 1), + (0, 3), + (1, 2), + (1, 4), + (2, 0), + (2, 6), + (3, 2), + (4, 5), + (4, 6), + (5, 6), + (5, 7), + (5, 8), + (5, 9), + (6, 4), + (7, 9), + (8, 9), + (9, 8), + ]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + assert_eq!( + components, + vec![ + vec![8, 9], + vec![7], + vec![5, 4, 6], + vec![3, 2, 1, 0], + vec![10], + ] + ); + + // Test 2: A graph with no edges + let n_vertices = 5; + let edges: Vec<(usize, usize)> = vec![]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + + // Each node is its own SCC + assert_eq!( + components, + vec![vec![0], vec![1], vec![2], vec![3], vec![4]] + ); + + // Test 3: A graph with single strongly connected component + let n_vertices = 5; + let edges = vec![(0, 1), (1, 2), (2, 3), (2, 4), (3, 0), (4, 2)]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + assert_eq!(components, vec![vec![4, 3, 2, 1, 0]]); + + // Test 4: A graph with multiple strongly connected component + let n_vertices = 7; + let edges = vec![ + (0, 1), + (1, 2), + (2, 0), + (1, 3), + (1, 4), + (1, 6), + (3, 5), + (4, 5), + ]; + let mut graph = Graph::new(n_vertices); + + for &(u, v) in &edges { + graph.add_edge(u, v); + } + + let components = tarjan_scc(&graph); + assert_eq!( + components, + vec![vec![5], vec![3], vec![4], vec![6], vec![2, 1, 0],] + ); + } +} diff --git a/src/graph/topological_sort.rs b/src/graph/topological_sort.rs index f14c5aea802..887758287ea 100644 --- a/src/graph/topological_sort.rs +++ b/src/graph/topological_sort.rs @@ -1,62 +1,123 @@ -use std::collections::{BTreeMap, VecDeque}; - -type Graph = BTreeMap>; - -/// returns topological sort of the graph using Kahn's algorithm -pub fn topological_sort(graph: &Graph) -> Vec { - let mut visited = BTreeMap::new(); - let mut degree = BTreeMap::new(); - for u in graph.keys() { - degree.insert(*u, 0); - for (v, _) in graph.get(u).unwrap() { - let entry = degree.entry(*v).or_insert(0); - *entry += 1; - } +use std::collections::HashMap; +use std::collections::VecDeque; +use std::hash::Hash; + +#[derive(Debug, Eq, PartialEq)] +pub enum TopoligicalSortError { + CycleDetected, +} + +type TopologicalSortResult = Result, TopoligicalSortError>; + +/// Given a directed graph, modeled as a list of edges from source to destination +/// Uses Kahn's algorithm to either: +/// return the topological sort of the graph +/// or detect if there's any cycle +pub fn topological_sort( + edges: &Vec<(Node, Node)>, +) -> TopologicalSortResult { + // Preparation: + // Build a map of edges, organised from source to destinations + // Also, count the number of incoming edges by node + let mut edges_by_source: HashMap> = HashMap::default(); + let mut incoming_edges_count: HashMap = HashMap::default(); + for (source, destination) in edges { + incoming_edges_count.entry(*source).or_insert(0); // if we haven't seen this node yet, mark it as having 0 incoming nodes + edges_by_source // add destination to the list of outgoing edges from source + .entry(*source) + .or_default() + .push(*destination); + // then make destination have one more incoming edge + *incoming_edges_count.entry(*destination).or_insert(0) += 1; } - let mut queue = VecDeque::new(); - for (u, d) in degree.iter() { - if *d == 0 { - queue.push_back(*u); - visited.insert(*u, true); + + // Now Kahn's algorithm: + // Add nodes that have no incoming edges to a queue + let mut no_incoming_edges_q = VecDeque::default(); + for (node, count) in &incoming_edges_count { + if *count == 0 { + no_incoming_edges_q.push_back(*node); } } - let mut ret = Vec::new(); - while let Some(u) = queue.pop_front() { - ret.push(u); - if let Some(from_u) = graph.get(&u) { - for (v, _) in from_u { - *degree.get_mut(v).unwrap() -= 1; - if *degree.get(v).unwrap() == 0 { - queue.push_back(*v); - visited.insert(*v, true); + // For each node in this "O-incoming-edge-queue" + let mut sorted = Vec::default(); + while let Some(no_incoming_edges) = no_incoming_edges_q.pop_back() { + sorted.push(no_incoming_edges); // since the node has no dependency, it can be safely pushed to the sorted result + incoming_edges_count.remove(&no_incoming_edges); + // For each node having this one as dependency + for neighbour in edges_by_source.get(&no_incoming_edges).unwrap_or(&vec![]) { + if let Some(count) = incoming_edges_count.get_mut(neighbour) { + *count -= 1; // decrement the count of incoming edges for the dependent node + if *count == 0 { + // `node` was the last node `neighbour` was dependent on + incoming_edges_count.remove(neighbour); // let's remove it from the map, so that we can know if we covered the whole graph + no_incoming_edges_q.push_front(*neighbour); // it has no incoming edges anymore => push it to the queue } } } } - ret + if incoming_edges_count.is_empty() { + // we have visited every node + Ok(sorted) + } else { + // some nodes haven't been visited, meaning there's a cycle in the graph + Err(TopoligicalSortError::CycleDetected) + } } #[cfg(test)] mod tests { - use std::collections::BTreeMap; + use super::topological_sort; + use crate::graph::topological_sort::TopoligicalSortError; - use super::{topological_sort, Graph}; - fn add_edge(graph: &mut Graph, from: V, to: V, weight: E) { - let edges = graph.entry(from).or_insert(Vec::new()); - edges.push((to, weight)); + fn is_valid_sort(sorted: &[Node], graph: &[(Node, Node)]) -> bool { + for (source, dest) in graph { + let source_pos = sorted.iter().position(|node| node == source); + let dest_pos = sorted.iter().position(|node| node == dest); + match (source_pos, dest_pos) { + (Some(src), Some(dst)) if src < dst => {} + _ => { + return false; + } + }; + } + true } #[test] fn it_works() { - let mut graph = BTreeMap::new(); - add_edge(&mut graph, 1, 2, 1); - add_edge(&mut graph, 1, 3, 1); - add_edge(&mut graph, 2, 3, 1); - add_edge(&mut graph, 3, 4, 1); - add_edge(&mut graph, 4, 5, 1); - add_edge(&mut graph, 5, 6, 1); - add_edge(&mut graph, 6, 7, 1); - - assert_eq!(topological_sort(&graph), vec![1, 2, 3, 4, 5, 6, 7]); + let graph = vec![(1, 2), (1, 3), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)]; + let sort = topological_sort(&graph); + assert!(sort.is_ok()); + let sort = sort.unwrap(); + assert!(is_valid_sort(&sort, &graph)); + assert_eq!(sort, vec![1, 2, 3, 4, 5, 6, 7]); + } + + #[test] + fn test_wikipedia_example() { + let graph = vec![ + (5, 11), + (7, 11), + (7, 8), + (3, 8), + (3, 10), + (11, 2), + (11, 9), + (11, 10), + (8, 9), + ]; + let sort = topological_sort(&graph); + assert!(sort.is_ok()); + let sort = sort.unwrap(); + assert!(is_valid_sort(&sort, &graph)); + } + + #[test] + fn test_cyclic_graph() { + let graph = vec![(1, 2), (2, 3), (3, 4), (4, 5), (4, 2)]; + let sort = topological_sort(&graph); + assert!(sort.is_err()); + assert_eq!(sort.err().unwrap(), TopoligicalSortError::CycleDetected); } } diff --git a/src/graph/two_satisfiability.rs b/src/graph/two_satisfiability.rs index 3d1478df963..a3e727f9323 100644 --- a/src/graph/two_satisfiability.rs +++ b/src/graph/two_satisfiability.rs @@ -12,11 +12,9 @@ fn variable(var: i64) -> usize { } } -/// Returns an assignment that satisfies all the constraints, or a variable -/// that makes such an assignment impossible. Variables should be numbered -/// from 1 to n, and a negative number -m corresponds to the negated variable -/// m. For more information about this problem, please visit: -/// https://en.wikipedia.org/wiki/2-satisfiability +/// Returns an assignment that satisfies all the constraints, or a variable that makes such an assignment impossible.\ +/// Variables should be numbered from 1 to `n`, and a negative number `-m` corresponds to the negated variable `m`.\ +/// For more information about this problem, please visit: pub fn solve_two_satisfiability( expression: &[Condition], num_variables: usize, @@ -26,12 +24,12 @@ pub fn solve_two_satisfiability( let mut sccs = SCCs::new(num_verts); let mut adj = Graph::new(); adj.resize(num_verts, vec![]); - expression.iter().for_each(|cond| { + for cond in expression.iter() { let v1 = variable(cond.0); let v2 = variable(cond.1); adj[v1 ^ 1].push(v2); adj[v2 ^ 1].push(v1); - }); + } sccs.find_components(&adj); result.resize(num_variables + 1, false); for var in (2..num_verts).step_by(2) { diff --git a/src/greedy/mod.rs b/src/greedy/mod.rs new file mode 100644 index 00000000000..e718c149f42 --- /dev/null +++ b/src/greedy/mod.rs @@ -0,0 +1,3 @@ +mod stable_matching; + +pub use self::stable_matching::stable_matching; diff --git a/src/greedy/stable_matching.rs b/src/greedy/stable_matching.rs new file mode 100644 index 00000000000..9b8f603d3d0 --- /dev/null +++ b/src/greedy/stable_matching.rs @@ -0,0 +1,276 @@ +use std::collections::{HashMap, VecDeque}; + +fn initialize_men( + men_preferences: &HashMap>, +) -> (VecDeque, HashMap) { + let mut free_men = VecDeque::new(); + let mut next_proposal = HashMap::new(); + + for man in men_preferences.keys() { + free_men.push_back(man.clone()); + next_proposal.insert(man.clone(), 0); + } + + (free_men, next_proposal) +} + +fn initialize_women( + women_preferences: &HashMap>, +) -> HashMap> { + let mut current_partner = HashMap::new(); + for woman in women_preferences.keys() { + current_partner.insert(woman.clone(), None); + } + current_partner +} + +fn precompute_woman_ranks( + women_preferences: &HashMap>, +) -> HashMap> { + let mut woman_ranks = HashMap::new(); + for (woman, preferences) in women_preferences { + let mut rank_map = HashMap::new(); + for (rank, man) in preferences.iter().enumerate() { + rank_map.insert(man.clone(), rank); + } + woman_ranks.insert(woman.clone(), rank_map); + } + woman_ranks +} + +fn process_proposal( + man: &str, + free_men: &mut VecDeque, + current_partner: &mut HashMap>, + man_engaged: &mut HashMap>, + next_proposal: &mut HashMap, + men_preferences: &HashMap>, + woman_ranks: &HashMap>, +) { + let man_pref_list = &men_preferences[man]; + let next_woman_idx = next_proposal[man]; + let woman = &man_pref_list[next_woman_idx]; + + // Update man's next proposal index + next_proposal.insert(man.to_string(), next_woman_idx + 1); + + if let Some(current_man) = current_partner[woman].clone() { + // Woman is currently engaged, check if she prefers the new man + if woman_prefers_new_man(woman, man, ¤t_man, woman_ranks) { + engage_man( + man, + woman, + free_men, + current_partner, + man_engaged, + Some(current_man), + ); + } else { + // Woman rejects the proposal, so the man remains free + free_men.push_back(man.to_string()); + } + } else { + // Woman is not engaged, so engage her with this man + engage_man(man, woman, free_men, current_partner, man_engaged, None); + } +} + +fn woman_prefers_new_man( + woman: &str, + man1: &str, + man2: &str, + woman_ranks: &HashMap>, +) -> bool { + let ranks = &woman_ranks[woman]; + ranks[man1] < ranks[man2] +} + +fn engage_man( + man: &str, + woman: &str, + free_men: &mut VecDeque, + current_partner: &mut HashMap>, + man_engaged: &mut HashMap>, + current_man: Option, +) { + man_engaged.insert(man.to_string(), Some(woman.to_string())); + current_partner.insert(woman.to_string(), Some(man.to_string())); + + if let Some(current_man) = current_man { + // The current man is now free + free_men.push_back(current_man); + } +} + +fn finalize_matches(man_engaged: HashMap>) -> HashMap { + let mut stable_matches = HashMap::new(); + for (man, woman_option) in man_engaged { + if let Some(woman) = woman_option { + stable_matches.insert(man, woman); + } + } + stable_matches +} + +pub fn stable_matching( + men_preferences: &HashMap>, + women_preferences: &HashMap>, +) -> HashMap { + let (mut free_men, mut next_proposal) = initialize_men(men_preferences); + let mut current_partner = initialize_women(women_preferences); + let mut man_engaged = HashMap::new(); + + let woman_ranks = precompute_woman_ranks(women_preferences); + + while let Some(man) = free_men.pop_front() { + process_proposal( + &man, + &mut free_men, + &mut current_partner, + &mut man_engaged, + &mut next_proposal, + men_preferences, + &woman_ranks, + ); + } + + finalize_matches(man_engaged) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashMap; + + #[test] + fn test_stable_matching_scenario_1() { + let men_preferences = HashMap::from([ + ( + "A".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ( + "B".to_string(), + vec!["Y".to_string(), "X".to_string(), "Z".to_string()], + ), + ( + "C".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ]); + + let women_preferences = HashMap::from([ + ( + "X".to_string(), + vec!["B".to_string(), "A".to_string(), "C".to_string()], + ), + ( + "Y".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ( + "Z".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ]); + + let matches = stable_matching(&men_preferences, &women_preferences); + + let expected_matches1 = HashMap::from([ + ("A".to_string(), "Y".to_string()), + ("B".to_string(), "X".to_string()), + ("C".to_string(), "Z".to_string()), + ]); + + let expected_matches2 = HashMap::from([ + ("A".to_string(), "X".to_string()), + ("B".to_string(), "Y".to_string()), + ("C".to_string(), "Z".to_string()), + ]); + + assert!(matches == expected_matches1 || matches == expected_matches2); + } + + #[test] + fn test_stable_matching_empty() { + let men_preferences = HashMap::new(); + let women_preferences = HashMap::new(); + + let matches = stable_matching(&men_preferences, &women_preferences); + assert!(matches.is_empty()); + } + + #[test] + fn test_stable_matching_duplicate_preferences() { + let men_preferences = HashMap::from([ + ("A".to_string(), vec!["X".to_string(), "X".to_string()]), // Man with duplicate preferences + ("B".to_string(), vec!["Y".to_string()]), + ]); + + let women_preferences = HashMap::from([ + ("X".to_string(), vec!["A".to_string(), "B".to_string()]), + ("Y".to_string(), vec!["B".to_string()]), + ]); + + let matches = stable_matching(&men_preferences, &women_preferences); + let expected_matches = HashMap::from([ + ("A".to_string(), "X".to_string()), + ("B".to_string(), "Y".to_string()), + ]); + + assert_eq!(matches, expected_matches); + } + + #[test] + fn test_stable_matching_single_pair() { + let men_preferences = HashMap::from([("A".to_string(), vec!["X".to_string()])]); + let women_preferences = HashMap::from([("X".to_string(), vec!["A".to_string()])]); + + let matches = stable_matching(&men_preferences, &women_preferences); + let expected_matches = HashMap::from([("A".to_string(), "X".to_string())]); + + assert_eq!(matches, expected_matches); + } + #[test] + fn test_woman_prefers_new_man() { + let men_preferences = HashMap::from([ + ( + "A".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ( + "B".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ( + "C".to_string(), + vec!["X".to_string(), "Y".to_string(), "Z".to_string()], + ), + ]); + + let women_preferences = HashMap::from([ + ( + "X".to_string(), + vec!["B".to_string(), "A".to_string(), "C".to_string()], + ), + ( + "Y".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ( + "Z".to_string(), + vec!["A".to_string(), "B".to_string(), "C".to_string()], + ), + ]); + + let matches = stable_matching(&men_preferences, &women_preferences); + + let expected_matches = HashMap::from([ + ("A".to_string(), "Y".to_string()), + ("B".to_string(), "X".to_string()), + ("C".to_string(), "Z".to_string()), + ]); + + assert_eq!(matches, expected_matches); + } +} diff --git a/src/lib.rs b/src/lib.rs index 71e5bec5faa..910bf05de06 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,30 +1,39 @@ +pub mod backtracking; +pub mod big_integer; +pub mod bit_manipulation; pub mod ciphers; +pub mod compression; +pub mod conversions; pub mod data_structures; pub mod dynamic_programming; +pub mod financial; pub mod general; +pub mod geometry; pub mod graph; +pub mod greedy; +pub mod machine_learning; pub mod math; +pub mod navigation; +pub mod number_theory; pub mod searching; pub mod sorting; pub mod string; #[cfg(test)] mod tests { - use sorting; + use super::sorting; #[test] fn quick_sort() { //descending let mut ve1 = vec![6, 5, 4, 3, 2, 1]; sorting::quick_sort(&mut ve1); - for i in 0..ve1.len() - 1 { - assert!(ve1[i] <= ve1[i + 1]); - } + + assert!(sorting::is_sorted(&ve1)); //pre-sorted let mut ve2 = vec![1, 2, 3, 4, 5, 6]; sorting::quick_sort(&mut ve2); - for i in 0..ve2.len() - 1 { - assert!(ve2[i] <= ve2[i + 1]); - } + + assert!(sorting::is_sorted(&ve2)); } } diff --git a/src/machine_learning/cholesky.rs b/src/machine_learning/cholesky.rs new file mode 100644 index 00000000000..3afcc040245 --- /dev/null +++ b/src/machine_learning/cholesky.rs @@ -0,0 +1,99 @@ +pub fn cholesky(mat: Vec, n: usize) -> Vec { + if (mat.is_empty()) || (n == 0) { + return vec![]; + } + let mut res = vec![0.0; mat.len()]; + for i in 0..n { + for j in 0..=i { + let mut s = 0.0; + for k in 0..j { + s += res[i * n + k] * res[j * n + k]; + } + let value = if i == j { + let diag_value = mat[i * n + i] - s; + if diag_value.is_nan() { + 0.0 + } else { + diag_value.sqrt() + } + } else { + let off_diag_value = 1.0 / res[j * n + j] * (mat[i * n + j] - s); + if off_diag_value.is_nan() { + 0.0 + } else { + off_diag_value + } + }; + res[i * n + j] = value; + } + } + res +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cholesky() { + // Test case 1 + let mat1 = vec![25.0, 15.0, -5.0, 15.0, 18.0, 0.0, -5.0, 0.0, 11.0]; + let res1 = cholesky(mat1, 3); + + // The expected Cholesky decomposition values + #[allow(clippy::useless_vec)] + let expected1 = vec![5.0, 0.0, 0.0, 3.0, 3.0, 0.0, -1.0, 1.0, 3.0]; + + assert!(res1 + .iter() + .zip(expected1.iter()) + .all(|(a, b)| (a - b).abs() < 1e-6)); + } + + fn transpose_matrix(mat: &[f64], n: usize) -> Vec { + (0..n) + .flat_map(|i| (0..n).map(move |j| mat[j * n + i])) + .collect() + } + + fn matrix_multiply(mat1: &[f64], mat2: &[f64], n: usize) -> Vec { + (0..n) + .flat_map(|i| { + (0..n).map(move |j| { + (0..n).fold(0.0, |acc, k| acc + mat1[i * n + k] * mat2[k * n + j]) + }) + }) + .collect() + } + + #[test] + fn test_matrix_operations() { + // Test case 1: Transposition + let mat1 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let transposed_mat1 = transpose_matrix(&mat1, 3); + let expected_transposed_mat1 = vec![1.0, 4.0, 7.0, 2.0, 5.0, 8.0, 3.0, 6.0, 9.0]; + assert_eq!(transposed_mat1, expected_transposed_mat1); + + // Test case 2: Matrix multiplication + let mat2 = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let mat3 = vec![9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]; + let multiplied_mat = matrix_multiply(&mat2, &mat3, 3); + let expected_multiplied_mat = vec![30.0, 24.0, 18.0, 84.0, 69.0, 54.0, 138.0, 114.0, 90.0]; + assert_eq!(multiplied_mat, expected_multiplied_mat); + } + + #[test] + fn empty_matrix() { + let mat = vec![]; + let res = cholesky(mat, 0); + assert_eq!(res, vec![]); + } + + #[test] + fn matrix_with_all_zeros() { + let mat3 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + let res3 = cholesky(mat3, 3); + let expected3 = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]; + assert_eq!(res3, expected3); + } +} diff --git a/src/machine_learning/k_means.rs b/src/machine_learning/k_means.rs new file mode 100644 index 00000000000..cd892d64424 --- /dev/null +++ b/src/machine_learning/k_means.rs @@ -0,0 +1,90 @@ +use rand::random; + +fn get_distance(p1: &(f64, f64), p2: &(f64, f64)) -> f64 { + let dx: f64 = p1.0 - p2.0; + let dy: f64 = p1.1 - p2.1; + + ((dx * dx) + (dy * dy)).sqrt() +} + +fn find_nearest(data_point: &(f64, f64), centroids: &[(f64, f64)]) -> u32 { + let mut cluster: u32 = 0; + + for (i, c) in centroids.iter().enumerate() { + let d1 = get_distance(data_point, c); + let d2 = get_distance(data_point, ¢roids[cluster as usize]); + + if d1 < d2 { + cluster = i as u32; + } + } + + cluster +} + +pub fn k_means(data_points: Vec<(f64, f64)>, n_clusters: usize, max_iter: i32) -> Option> { + if data_points.len() < n_clusters { + return None; + } + + let mut centroids: Vec<(f64, f64)> = Vec::new(); + let mut labels: Vec = vec![0; data_points.len()]; + + for _ in 0..n_clusters { + let x: f64 = random::(); + let y: f64 = random::(); + + centroids.push((x, y)); + } + + let mut count_iter: i32 = 0; + + while count_iter < max_iter { + let mut new_centroids_position: Vec<(f64, f64)> = vec![(0.0, 0.0); n_clusters]; + let mut new_centroids_num: Vec = vec![0; n_clusters]; + + for (i, d) in data_points.iter().enumerate() { + let nearest_cluster = find_nearest(d, ¢roids); + labels[i] = nearest_cluster; + + new_centroids_position[nearest_cluster as usize].0 += d.0; + new_centroids_position[nearest_cluster as usize].1 += d.1; + new_centroids_num[nearest_cluster as usize] += 1; + } + + for i in 0..centroids.len() { + if new_centroids_num[i] == 0 { + continue; + } + + let new_x: f64 = new_centroids_position[i].0 / new_centroids_num[i] as f64; + let new_y: f64 = new_centroids_position[i].1 / new_centroids_num[i] as f64; + + centroids[i] = (new_x, new_y); + } + + count_iter += 1; + } + + Some(labels) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_k_means() { + let mut data_points: Vec<(f64, f64)> = vec![]; + let n_points: usize = 1000; + + for _ in 0..n_points { + let x: f64 = random::() * 100.0; + let y: f64 = random::() * 100.0; + + data_points.push((x, y)); + } + + println!("{:?}", k_means(data_points, 10, 100).unwrap_or_default()); + } +} diff --git a/src/machine_learning/linear_regression.rs b/src/machine_learning/linear_regression.rs new file mode 100644 index 00000000000..22e0840e58b --- /dev/null +++ b/src/machine_learning/linear_regression.rs @@ -0,0 +1,48 @@ +/// Returns the parameters of the line after performing simple linear regression on the input data. +pub fn linear_regression(data_points: Vec<(f64, f64)>) -> Option<(f64, f64)> { + if data_points.is_empty() { + return None; + } + + let count = data_points.len() as f64; + let mean_x = data_points.iter().fold(0.0, |sum, y| sum + y.0) / count; + let mean_y = data_points.iter().fold(0.0, |sum, y| sum + y.1) / count; + + let mut covariance = 0.0; + let mut std_dev_sqr_x = 0.0; + let mut std_dev_sqr_y = 0.0; + + for data_point in data_points { + covariance += (data_point.0 - mean_x) * (data_point.1 - mean_y); + std_dev_sqr_x += (data_point.0 - mean_x).powi(2); + std_dev_sqr_y += (data_point.1 - mean_y).powi(2); + } + + let std_dev_x = std_dev_sqr_x.sqrt(); + let std_dev_y = std_dev_sqr_y.sqrt(); + let std_dev_prod = std_dev_x * std_dev_y; + + let pcc = covariance / std_dev_prod; //Pearson's correlation constant + let b = pcc * (std_dev_y / std_dev_x); //Slope of the line + let a = mean_y - b * mean_x; //Y-Intercept of the line + + Some((a, b)) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_linear_regression() { + assert_eq!( + linear_regression(vec![(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]), + Some((2.220446049250313e-16, 0.9999999999999998)) + ); + } + + #[test] + fn test_empty_list_linear_regression() { + assert_eq!(linear_regression(vec![]), None); + } +} diff --git a/src/machine_learning/logistic_regression.rs b/src/machine_learning/logistic_regression.rs new file mode 100644 index 00000000000..645cd960f83 --- /dev/null +++ b/src/machine_learning/logistic_regression.rs @@ -0,0 +1,92 @@ +use super::optimization::gradient_descent; +use std::f64::consts::E; + +/// Returns the weights after performing Logistic regression on the input data points. +pub fn logistic_regression( + data_points: Vec<(Vec, f64)>, + iterations: usize, + learning_rate: f64, +) -> Option> { + if data_points.is_empty() { + return None; + } + + let num_features = data_points[0].0.len() + 1; + let mut params = vec![0.0; num_features]; + + let derivative_fn = |params: &[f64]| derivative(params, &data_points); + + gradient_descent(derivative_fn, &mut params, learning_rate, iterations as i32); + + Some(params) +} + +fn derivative(params: &[f64], data_points: &[(Vec, f64)]) -> Vec { + let num_features = params.len(); + let mut gradients = vec![0.0; num_features]; + + for (features, y_i) in data_points { + let z = params[0] + + params[1..] + .iter() + .zip(features) + .map(|(p, x)| p * x) + .sum::(); + let prediction = 1.0 / (1.0 + E.powf(-z)); + + gradients[0] += prediction - y_i; + for (i, x_i) in features.iter().enumerate() { + gradients[i + 1] += (prediction - y_i) * x_i; + } + } + + gradients +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_logistic_regression_simple() { + let data = vec![ + (vec![0.0], 0.0), + (vec![1.0], 0.0), + (vec![2.0], 0.0), + (vec![3.0], 1.0), + (vec![4.0], 1.0), + (vec![5.0], 1.0), + ]; + + let result = logistic_regression(data, 10000, 0.05); + assert!(result.is_some()); + + let params = result.unwrap(); + assert!((params[0] + 17.65).abs() < 1.0); + assert!((params[1] - 7.13).abs() < 1.0); + } + + #[test] + fn test_logistic_regression_extreme_data() { + let data = vec![ + (vec![-100.0], 0.0), + (vec![-10.0], 0.0), + (vec![0.0], 0.0), + (vec![10.0], 1.0), + (vec![100.0], 1.0), + ]; + + let result = logistic_regression(data, 10000, 0.05); + assert!(result.is_some()); + + let params = result.unwrap(); + assert!((params[0] + 6.20).abs() < 1.0); + assert!((params[1] - 5.5).abs() < 1.0); + } + + #[test] + fn test_logistic_regression_no_data() { + let result = logistic_regression(vec![], 5000, 0.1); + assert_eq!(result, None); + } +} diff --git a/src/machine_learning/loss_function/average_margin_ranking_loss.rs b/src/machine_learning/loss_function/average_margin_ranking_loss.rs new file mode 100644 index 00000000000..505bf2a94a7 --- /dev/null +++ b/src/machine_learning/loss_function/average_margin_ranking_loss.rs @@ -0,0 +1,113 @@ +/// Marginal Ranking +/// +/// The 'average_margin_ranking_loss' function calculates the Margin Ranking loss, which is a +/// loss function used for ranking problems in machine learning. +/// +/// ## Formula +/// +/// For a pair of values `x_first` and `x_second`, `margin`, and `y_true`, +/// the Margin Ranking loss is calculated as: +/// +/// - loss = `max(0, -y_true * (x_first - x_second) + margin)`. +/// +/// It returns the average loss by dividing the `total_loss` by total no. of +/// elements. +/// +/// Pytorch implementation: +/// https://pytorch.org/docs/stable/generated/torch.nn.MarginRankingLoss.html +/// https://gombru.github.io/2019/04/03/ranking_loss/ +/// https://vinija.ai/concepts/loss/#pairwise-ranking-loss +/// + +pub fn average_margin_ranking_loss( + x_first: &[f64], + x_second: &[f64], + margin: f64, + y_true: f64, +) -> Result { + check_input(x_first, x_second, margin, y_true)?; + + let total_loss: f64 = x_first + .iter() + .zip(x_second.iter()) + .map(|(f, s)| (margin - y_true * (f - s)).max(0.0)) + .sum(); + Ok(total_loss / (x_first.len() as f64)) +} + +fn check_input( + x_first: &[f64], + x_second: &[f64], + margin: f64, + y_true: f64, +) -> Result<(), MarginalRankingLossError> { + if x_first.len() != x_second.len() { + return Err(MarginalRankingLossError::InputsHaveDifferentLength); + } + if x_first.is_empty() { + return Err(MarginalRankingLossError::EmptyInputs); + } + if margin < 0.0 { + return Err(MarginalRankingLossError::NegativeMargin); + } + if y_true != 1.0 && y_true != -1.0 { + return Err(MarginalRankingLossError::InvalidValues); + } + + Ok(()) +} + +#[derive(Debug, PartialEq, Eq)] +pub enum MarginalRankingLossError { + InputsHaveDifferentLength, + EmptyInputs, + InvalidValues, + NegativeMargin, +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_with_wrong_inputs { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (vec_a, vec_b, margin, y_true, expected) = $inputs; + assert_eq!(average_margin_ranking_loss(&vec_a, &vec_b, margin, y_true), expected); + assert_eq!(average_margin_ranking_loss(&vec_b, &vec_a, margin, y_true), expected); + } + )* + } + } + + test_with_wrong_inputs! { + invalid_length0: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)), + invalid_length1: (vec![1.0, 2.0], vec![2.0, 3.0, 4.0], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)), + invalid_length2: (vec![], vec![1.0, 2.0, 3.0], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)), + invalid_length3: (vec![1.0, 2.0, 3.0], vec![], 1.0, 1.0, Err(MarginalRankingLossError::InputsHaveDifferentLength)), + invalid_values: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], -1.0, 1.0, Err(MarginalRankingLossError::NegativeMargin)), + invalid_y_true: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], 1.0, 2.0, Err(MarginalRankingLossError::InvalidValues)), + empty_inputs: (vec![], vec![], 1.0, 1.0, Err(MarginalRankingLossError::EmptyInputs)), + } + + macro_rules! test_average_margin_ranking_loss { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (x_first, x_second, margin, y_true, expected) = $inputs; + assert_eq!(average_margin_ranking_loss(&x_first, &x_second, margin, y_true), Ok(expected)); + } + )* + } + } + + test_average_margin_ranking_loss! { + set_0: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], 1.0, -1.0, 0.0), + set_1: (vec![1.0, 2.0, 3.0], vec![2.0, 3.0, 4.0], 1.0, 1.0, 2.0), + set_2: (vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0], 0.0, 1.0, 0.0), + set_3: (vec![4.0, 5.0, 6.0], vec![1.0, 2.0, 3.0], 1.0, -1.0, 4.0), + } +} diff --git a/src/machine_learning/loss_function/hinge_loss.rs b/src/machine_learning/loss_function/hinge_loss.rs new file mode 100644 index 00000000000..c02f1eca646 --- /dev/null +++ b/src/machine_learning/loss_function/hinge_loss.rs @@ -0,0 +1,38 @@ +//! # Hinge Loss +//! +//! The `hng_loss` function calculates the Hinge loss, which is a +//! loss function used for classification problems in machine learning. +//! +//! ## Formula +//! +//! For a pair of actual and predicted values, represented as vectors `y_true` and +//! `y_pred`, the Hinge loss is calculated as: +//! +//! - loss = `max(0, 1 - y_true * y_pred)`. +//! +//! It returns the average loss by dividing the `total_loss` by total no. of +//! elements. +//! +pub fn hng_loss(y_true: &[f64], y_pred: &[f64]) -> f64 { + let mut total_loss: f64 = 0.0; + for (p, a) in y_pred.iter().zip(y_true.iter()) { + let loss = (1.0 - a * p).max(0.0); + total_loss += loss; + } + total_loss / (y_pred.len() as f64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_hinge_loss() { + let predicted_values: Vec = vec![-1.0, 1.0, 1.0]; + let actual_values: Vec = vec![-1.0, -1.0, 1.0]; + assert_eq!( + hng_loss(&predicted_values, &actual_values), + 0.6666666666666666 + ); + } +} diff --git a/src/machine_learning/loss_function/huber_loss.rs b/src/machine_learning/loss_function/huber_loss.rs new file mode 100644 index 00000000000..f81e8deb85c --- /dev/null +++ b/src/machine_learning/loss_function/huber_loss.rs @@ -0,0 +1,74 @@ +/// Computes the Huber loss between arrays of true and predicted values. +/// +/// # Arguments +/// +/// * `y_true` - An array of true values. +/// * `y_pred` - An array of predicted values. +/// * `delta` - The threshold parameter that controls the linear behavior of the loss function. +/// +/// # Returns +/// +/// The average Huber loss for all pairs of true and predicted values. +pub fn huber_loss(y_true: &[f64], y_pred: &[f64], delta: f64) -> Option { + if y_true.len() != y_pred.len() || y_pred.is_empty() { + return None; + } + + let loss: f64 = y_true + .iter() + .zip(y_pred.iter()) + .map(|(&true_val, &pred_val)| { + let residual = (true_val - pred_val).abs(); + match residual { + r if r <= delta => 0.5 * r.powi(2), + _ => delta * residual - 0.5 * delta.powi(2), + } + }) + .sum(); + + Some(loss / (y_pred.len() as f64)) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! huber_loss_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (y_true, y_pred, delta, expected_loss) = $test_case; + assert_eq!(huber_loss(&y_true, &y_pred, delta), expected_loss); + } + )* + }; + } + + huber_loss_tests! { + test_huber_loss_residual_less_than_delta: ( + vec![10.0, 8.0, 12.0], + vec![9.0, 7.0, 11.0], + 1.0, + Some(0.5) + ), + test_huber_loss_residual_greater_than_delta: ( + vec![3.0, 5.0, 7.0], + vec![2.0, 4.0, 8.0], + 0.5, + Some(0.375) + ), + test_huber_loss_invalid_length: ( + vec![10.0, 8.0, 12.0], + vec![7.0, 6.0], + 1.0, + None + ), + test_huber_loss_empty_prediction: ( + vec![10.0, 8.0, 12.0], + vec![], + 1.0, + None + ), + } +} diff --git a/src/machine_learning/loss_function/kl_divergence_loss.rs b/src/machine_learning/loss_function/kl_divergence_loss.rs new file mode 100644 index 00000000000..f477607b20f --- /dev/null +++ b/src/machine_learning/loss_function/kl_divergence_loss.rs @@ -0,0 +1,37 @@ +//! # KL divergence Loss Function +//! +//! For a pair of actual and predicted probability distributions represented as vectors `actual` and `predicted`, the KL divergence loss is calculated as: +//! +//! `L = -Σ(actual[i] * ln(predicted[i]/actual[i]))` for all `i` in the range of the vectors +//! +//! Where `ln` is the natural logarithm function, and `Σ` denotes the summation over all elements of the vectors. +//! +//! ## KL divergence Loss Function Implementation +//! +//! This implementation takes two references to vectors of f64 values, `actual` and `predicted`, and returns the KL divergence loss between them. +//! +pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 { + // epsilon to handle if any of the elements are zero + let eps = 0.00001f64; + let loss: f64 = actual + .iter() + .zip(predicted.iter()) + .map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln())) + .sum(); + loss +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kld_loss() { + let test_vector_actual = vec![1.346112, 1.337432, 1.246655]; + let test_vector = vec![1.033836, 1.082015, 1.117323]; + assert_eq!( + kld_loss(&test_vector_actual, &test_vector), + 0.7752789394328498 + ); + } +} diff --git a/src/machine_learning/loss_function/mean_absolute_error_loss.rs b/src/machine_learning/loss_function/mean_absolute_error_loss.rs new file mode 100644 index 00000000000..e82cc317624 --- /dev/null +++ b/src/machine_learning/loss_function/mean_absolute_error_loss.rs @@ -0,0 +1,36 @@ +//! # Mean Absolute Error Loss Function +//! +//! The `mae_loss` function calculates the Mean Absolute Error loss, which is a +//! robust loss function used in machine learning. +//! +//! ## Formula +//! +//! For a pair of actual and predicted values, represented as vectors `actual` +//! and `predicted`, the Mean Absolute loss is calculated as: +//! +//! - loss = `(actual - predicted) / n_elements`. +//! +//! It returns the average loss by dividing the `total_loss` by total no. of +//! elements. +//! +pub fn mae_loss(predicted: &[f64], actual: &[f64]) -> f64 { + let mut total_loss: f64 = 0.0; + for (p, a) in predicted.iter().zip(actual.iter()) { + let diff: f64 = p - a; + let absolute_diff = diff.abs(); + total_loss += absolute_diff; + } + total_loss / (predicted.len() as f64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mae_loss() { + let predicted_values: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let actual_values: Vec = vec![1.0, 3.0, 3.5, 4.5]; + assert_eq!(mae_loss(&predicted_values, &actual_values), 0.5); + } +} diff --git a/src/machine_learning/loss_function/mean_squared_error_loss.rs b/src/machine_learning/loss_function/mean_squared_error_loss.rs new file mode 100644 index 00000000000..407142a3092 --- /dev/null +++ b/src/machine_learning/loss_function/mean_squared_error_loss.rs @@ -0,0 +1,35 @@ +//! # Mean Square Loss Function +//! +//! The `mse_loss` function calculates the Mean Square Error loss, which is a +//! robust loss function used in machine learning. +//! +//! ## Formula +//! +//! For a pair of actual and predicted values, represented as vectors `actual` +//! and `predicted`, the Mean Square loss is calculated as: +//! +//! - loss = `(actual - predicted)^2 / n_elements`. +//! +//! It returns the average loss by dividing the `total_loss` by total no. of +//! elements. +//! +pub fn mse_loss(predicted: &[f64], actual: &[f64]) -> f64 { + let mut total_loss: f64 = 0.0; + for (p, a) in predicted.iter().zip(actual.iter()) { + let diff: f64 = p - a; + total_loss += diff * diff; + } + total_loss / (predicted.len() as f64) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mse_loss() { + let predicted_values: Vec = vec![1.0, 2.0, 3.0, 4.0]; + let actual_values: Vec = vec![1.0, 3.0, 3.5, 4.5]; + assert_eq!(mse_loss(&predicted_values, &actual_values), 0.375); + } +} diff --git a/src/machine_learning/loss_function/mod.rs b/src/machine_learning/loss_function/mod.rs new file mode 100644 index 00000000000..95686eb8c20 --- /dev/null +++ b/src/machine_learning/loss_function/mod.rs @@ -0,0 +1,15 @@ +mod average_margin_ranking_loss; +mod hinge_loss; +mod huber_loss; +mod kl_divergence_loss; +mod mean_absolute_error_loss; +mod mean_squared_error_loss; +mod negative_log_likelihood; + +pub use self::average_margin_ranking_loss::average_margin_ranking_loss; +pub use self::hinge_loss::hng_loss; +pub use self::huber_loss::huber_loss; +pub use self::kl_divergence_loss::kld_loss; +pub use self::mean_absolute_error_loss::mae_loss; +pub use self::mean_squared_error_loss::mse_loss; +pub use self::negative_log_likelihood::neg_log_likelihood; diff --git a/src/machine_learning/loss_function/negative_log_likelihood.rs b/src/machine_learning/loss_function/negative_log_likelihood.rs new file mode 100644 index 00000000000..4fa633091cf --- /dev/null +++ b/src/machine_learning/loss_function/negative_log_likelihood.rs @@ -0,0 +1,100 @@ +// Negative Log Likelihood Loss Function +// +// The `neg_log_likelihood` function calculates the Negative Log Likelyhood loss, +// which is a loss function used for classification problems in machine learning. +// +// ## Formula +// +// For a pair of actual and predicted values, represented as vectors `y_true` and +// `y_pred`, the Negative Log Likelihood loss is calculated as: +// +// - loss = `-y_true * log(y_pred) - (1 - y_true) * log(1 - y_pred)`. +// +// It returns the average loss by dividing the `total_loss` by total no. of +// elements. +// +// https://towardsdatascience.com/cross-entropy-negative-log-likelihood-and-all-that-jazz-47a95bd2e81 +// http://neuralnetworksanddeeplearning.com/chap3.html +// Derivation of the formula: +// https://medium.com/@bhardwajprakarsh/negative-log-likelihood-loss-why-do-we-use-it-for-binary-classification-7625f9e3c944 + +pub fn neg_log_likelihood( + y_true: &[f64], + y_pred: &[f64], +) -> Result { + // Checks if the inputs are empty + if y_true.len() != y_pred.len() { + return Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength); + } + // Checks if the length of the actual and predicted values are equal + if y_pred.is_empty() { + return Err(NegativeLogLikelihoodLossError::EmptyInputs); + } + // Checks values are between 0 and 1 + if !are_all_values_in_range(y_true) || !are_all_values_in_range(y_pred) { + return Err(NegativeLogLikelihoodLossError::InvalidValues); + } + + let mut total_loss: f64 = 0.0; + for (p, a) in y_pred.iter().zip(y_true.iter()) { + let loss: f64 = -a * p.ln() - (1.0 - a) * (1.0 - p).ln(); + total_loss += loss; + } + Ok(total_loss / (y_pred.len() as f64)) +} + +#[derive(Debug, PartialEq, Eq)] +pub enum NegativeLogLikelihoodLossError { + InputsHaveDifferentLength, + EmptyInputs, + InvalidValues, +} + +fn are_all_values_in_range(values: &[f64]) -> bool { + values.iter().all(|&x| (0.0..=1.0).contains(&x)) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_with_wrong_inputs { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (values_a, values_b, expected_error) = $inputs; + assert_eq!(neg_log_likelihood(&values_a, &values_b), expected_error); + assert_eq!(neg_log_likelihood(&values_b, &values_a), expected_error); + } + )* + } + } + + test_with_wrong_inputs! { + different_length: (vec![0.9, 0.0, 0.8], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)), + different_length_one_empty: (vec![], vec![0.9, 0.1], Err(NegativeLogLikelihoodLossError::InputsHaveDifferentLength)), + value_greater_than_1: (vec![1.1, 0.0, 0.8], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)), + value_greater_smaller_than_0: (vec![0.9, 0.0, -0.1], vec![0.1, 0.2, 0.3], Err(NegativeLogLikelihoodLossError::InvalidValues)), + empty_input: (vec![], vec![], Err(NegativeLogLikelihoodLossError::EmptyInputs)), + } + + macro_rules! test_neg_log_likelihood { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (actual_values, predicted_values, expected) = $inputs; + assert_eq!(neg_log_likelihood(&actual_values, &predicted_values).unwrap(), expected); + } + )* + } + } + + test_neg_log_likelihood! { + set_0: (vec![1.0, 0.0, 1.0], vec![0.9, 0.1, 0.8], 0.14462152754328741), + set_1: (vec![1.0, 0.0, 1.0], vec![0.1, 0.2, 0.3], 1.2432338162113972), + set_2: (vec![0.0, 1.0, 0.0], vec![0.1, 0.2, 0.3], 0.6904911240102196), + set_3: (vec![1.0, 0.0, 1.0, 0.0], vec![0.9, 0.1, 0.8, 0.2], 0.164252033486018), + } +} diff --git a/src/machine_learning/mod.rs b/src/machine_learning/mod.rs new file mode 100644 index 00000000000..534326d2121 --- /dev/null +++ b/src/machine_learning/mod.rs @@ -0,0 +1,20 @@ +mod cholesky; +mod k_means; +mod linear_regression; +mod logistic_regression; +mod loss_function; +mod optimization; + +pub use self::cholesky::cholesky; +pub use self::k_means::k_means; +pub use self::linear_regression::linear_regression; +pub use self::logistic_regression::logistic_regression; +pub use self::loss_function::average_margin_ranking_loss; +pub use self::loss_function::hng_loss; +pub use self::loss_function::huber_loss; +pub use self::loss_function::kld_loss; +pub use self::loss_function::mae_loss; +pub use self::loss_function::mse_loss; +pub use self::loss_function::neg_log_likelihood; +pub use self::optimization::gradient_descent; +pub use self::optimization::Adam; diff --git a/src/machine_learning/optimization/adam.rs b/src/machine_learning/optimization/adam.rs new file mode 100644 index 00000000000..6fbebc6d39d --- /dev/null +++ b/src/machine_learning/optimization/adam.rs @@ -0,0 +1,288 @@ +//! # Adam (Adaptive Moment Estimation) optimizer +//! +//! The `Adam (Adaptive Moment Estimation)` optimizer is an adaptive learning rate algorithm used +//! in gradient descent and machine learning, such as for training neural networks to solve deep +//! learning problems. Boasting memory-efficient fast convergence rates, it sets and iteratively +//! updates learning rates individually for each model parameter based on the gradient history. +//! +//! ## Algorithm: +//! +//! Given: +//! - α is the learning rate +//! - (β_1, β_2) are the exponential decay rates for moment estimates +//! - ϵ is any small value to prevent division by zero +//! - g_t are the gradients at time step t +//! - m_t are the biased first moment estimates of the gradient at time step t +//! - v_t are the biased second raw moment estimates of the gradient at time step t +//! - θ_t are the model parameters at time step t +//! - t is the time step +//! +//! Required: +//! θ_0 +//! +//! Initialize: +//! m_0 <- 0 +//! v_0 <- 0 +//! t <- 0 +//! +//! while θ_t not converged do +//! m_t = β_1 * m_{t−1} + (1 − β_1) * g_t +//! v_t = β_2 * v_{t−1} + (1 − β_2) * g_t^2 +//! m_hat_t = m_t / 1 - β_1^t +//! v_hat_t = v_t / 1 - β_2^t +//! θ_t = θ_{t-1} − α * m_hat_t / (sqrt(v_hat_t) + ϵ) +//! +//! ## Resources: +//! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba): +//! - [https://arxiv.org/abs/1412.6980] +//! - PyTorch Adam optimizer: +//! - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam] +//! +pub struct Adam { + learning_rate: f64, // alpha: initial step size for iterative optimization + betas: (f64, f64), // betas: exponential decay rates for moment estimates + epsilon: f64, // epsilon: prevent division by zero + m: Vec, // m: biased first moment estimate of the gradient vector + v: Vec, // v: biased second raw moment estimate of the gradient vector + t: usize, // t: time step +} + +impl Adam { + pub fn new( + learning_rate: Option, + betas: Option<(f64, f64)>, + epsilon: Option, + params_len: usize, + ) -> Self { + Adam { + learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr + betas: betas.unwrap_or((0.9, 0.999)), // typical good default decay rates + epsilon: epsilon.unwrap_or(1e-8), // typical good default epsilon + m: vec![0.0; params_len], // first moment vector elements all initialized to zero + v: vec![0.0; params_len], // second moment vector elements all initialized to zero + t: 0, // time step initialized to zero + } + } + + pub fn step(&mut self, gradients: &[f64]) -> Vec { + let mut model_params = vec![0.0; gradients.len()]; + self.t += 1; + + for i in 0..gradients.len() { + // update biased first moment estimate and second raw moment estimate + self.m[i] = self.betas.0 * self.m[i] + (1.0 - self.betas.0) * gradients[i]; + self.v[i] = self.betas.1 * self.v[i] + (1.0 - self.betas.1) * gradients[i].powf(2f64); + + // compute bias-corrected first moment estimate and second raw moment estimate + let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32)); + let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32)); + + // update model parameters + model_params[i] -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon); + } + model_params // return updated model parameters + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_adam_init_default_values() { + let optimizer = Adam::new(None, None, None, 1); + + assert_eq!(optimizer.learning_rate, 0.001); + assert_eq!(optimizer.betas, (0.9, 0.999)); + assert_eq!(optimizer.epsilon, 1e-8); + assert_eq!(optimizer.m, vec![0.0; 1]); + assert_eq!(optimizer.v, vec![0.0; 1]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_custom_lr_value() { + let optimizer = Adam::new(Some(0.9), None, None, 2); + + assert_eq!(optimizer.learning_rate, 0.9); + assert_eq!(optimizer.betas, (0.9, 0.999)); + assert_eq!(optimizer.epsilon, 1e-8); + assert_eq!(optimizer.m, vec![0.0; 2]); + assert_eq!(optimizer.v, vec![0.0; 2]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_custom_betas_value() { + let optimizer = Adam::new(None, Some((0.8, 0.899)), None, 3); + + assert_eq!(optimizer.learning_rate, 0.001); + assert_eq!(optimizer.betas, (0.8, 0.899)); + assert_eq!(optimizer.epsilon, 1e-8); + assert_eq!(optimizer.m, vec![0.0; 3]); + assert_eq!(optimizer.v, vec![0.0; 3]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_custom_epsilon_value() { + let optimizer = Adam::new(None, None, Some(1e-10), 4); + + assert_eq!(optimizer.learning_rate, 0.001); + assert_eq!(optimizer.betas, (0.9, 0.999)); + assert_eq!(optimizer.epsilon, 1e-10); + assert_eq!(optimizer.m, vec![0.0; 4]); + assert_eq!(optimizer.v, vec![0.0; 4]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_init_all_custom_values() { + let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), 5); + + assert_eq!(optimizer.learning_rate, 1.0); + assert_eq!(optimizer.betas, (0.001, 0.099)); + assert_eq!(optimizer.epsilon, 1e-1); + assert_eq!(optimizer.m, vec![0.0; 5]); + assert_eq!(optimizer.v, vec![0.0; 5]); + assert_eq!(optimizer.t, 0); + } + + #[test] + fn test_adam_step_default_params() { + let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0]; + + let mut optimizer = Adam::new(None, None, None, 8); + let updated_params = optimizer.step(&gradients); + + assert_eq!( + updated_params, + vec![ + 0.0009999999900000003, + -0.000999999995, + 0.0009999999966666666, + -0.0009999999975, + 0.000999999998, + -0.0009999999983333334, + 0.0009999999985714286, + -0.00099999999875 + ] + ); + } + + #[test] + fn test_adam_step_custom_params() { + let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0]; + + let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), 9); + let updated_params = optimizer.step(&gradients); + + assert_eq!( + updated_params, + vec![ + -0.004999994444450618, + 0.004999993750007813, + -0.004999992857153062, + 0.004999991666680556, + -0.004999990000020001, + 0.004999987500031251, + -0.004999983333388888, + 0.004999975000124999, + -0.0049999500004999945 + ] + ); + } + + #[test] + fn test_adam_step_empty_gradients_array() { + let gradients = vec![]; + + let mut optimizer = Adam::new(None, None, None, 0); + let updated_params = optimizer.step(&gradients); + + assert_eq!(updated_params, vec![]); + } + + #[ignore] + #[test] + fn test_adam_step_iteratively_until_convergence_with_default_params() { + const CONVERGENCE_THRESHOLD: f64 = 1e-5; + let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + + let mut optimizer = Adam::new(None, None, None, 6); + + let mut model_params = vec![0.0; 6]; + let mut updated_params = optimizer.step(&gradients); + + while (updated_params + .iter() + .zip(model_params.iter()) + .map(|(x, y)| x - y) + .collect::>()) + .iter() + .map(|&x| x.powi(2)) + .sum::() + .sqrt() + > CONVERGENCE_THRESHOLD + { + model_params = updated_params; + updated_params = optimizer.step(&gradients); + } + + assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]); + assert_ne!(updated_params, model_params); + assert_eq!( + updated_params, + vec![ + -0.0009999999899999931, + -0.0009999999949999929, + -0.0009999999966666597, + -0.0009999999974999929, + -0.0009999999979999927, + -0.0009999999983333263 + ] + ); + } + + #[ignore] + #[test] + fn test_adam_step_iteratively_until_convergence_with_custom_params() { + const CONVERGENCE_THRESHOLD: f64 = 1e-7; + let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0]; + + let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), 7); + + let mut model_params = vec![0.0; 7]; + let mut updated_params = optimizer.step(&gradients); + + while (updated_params + .iter() + .zip(model_params.iter()) + .map(|(x, y)| x - y) + .collect::>()) + .iter() + .map(|&x| x.powi(2)) + .sum::() + .sqrt() + > CONVERGENCE_THRESHOLD + { + model_params = updated_params; + updated_params = optimizer.step(&gradients); + } + + assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]); + assert_ne!(updated_params, model_params); + assert_eq!( + updated_params, + vec![ + -0.004999992857153061, + 0.004999993750007814, + -0.0049999944444506185, + 0.004999995000005001, + -0.004999995454549587, + 0.004999995833336807, + -0.004999996153849113 + ] + ); + } +} diff --git a/src/machine_learning/optimization/gradient_descent.rs b/src/machine_learning/optimization/gradient_descent.rs new file mode 100644 index 00000000000..fd322a23ff3 --- /dev/null +++ b/src/machine_learning/optimization/gradient_descent.rs @@ -0,0 +1,86 @@ +/// Gradient Descent Optimization +/// +/// Gradient descent is an iterative optimization algorithm used to find the minimum of a function. +/// It works by updating the parameters (in this case, elements of the vector `x`) in the direction of +/// the steepest decrease in the function's value. This is achieved by subtracting the gradient of +/// the function at the current point from the current point. The learning rate controls the step size. +/// +/// The equation for a single parameter (univariate) is: +/// x_{k+1} = x_k - learning_rate * derivative_of_function(x_k) +/// +/// For multivariate functions, it extends to each parameter: +/// x_{k+1} = x_k - learning_rate * gradient_of_function(x_k) +/// +/// # Arguments +/// +/// * `derivative_fn` - The function that calculates the gradient of the objective function at a given point. +/// * `x` - The initial parameter vector to be optimized. +/// * `learning_rate` - Step size for each iteration. +/// * `num_iterations` - The number of iterations to run the optimization. +/// +/// # Returns +/// +/// A reference to the optimized parameter vector `x`. + +pub fn gradient_descent( + derivative_fn: impl Fn(&[f64]) -> Vec, + x: &mut Vec, + learning_rate: f64, + num_iterations: i32, +) -> &mut Vec { + for _ in 0..num_iterations { + let gradient = derivative_fn(x); + for (x_k, grad) in x.iter_mut().zip(gradient.iter()) { + *x_k -= learning_rate * grad; + } + } + + x +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_gradient_descent_optimized() { + fn derivative_of_square(params: &[f64]) -> Vec { + params.iter().map(|x| 2. * x).collect() + } + + let mut x: Vec = vec![5.0, 6.0]; + let learning_rate: f64 = 0.03; + let num_iterations: i32 = 1000; + + let minimized_vector = + gradient_descent(derivative_of_square, &mut x, learning_rate, num_iterations); + + let test_vector = [0.0, 0.0]; + + let tolerance = 1e-6; + for (minimized_value, test_value) in minimized_vector.iter().zip(test_vector.iter()) { + assert!((minimized_value - test_value).abs() < tolerance); + } + } + + #[test] + fn test_gradient_descent_unoptimized() { + fn derivative_of_square(params: &[f64]) -> Vec { + params.iter().map(|x| 2. * x).collect() + } + + let mut x: Vec = vec![5.0, 6.0]; + let learning_rate: f64 = 0.03; + let num_iterations: i32 = 10; + + let minimized_vector = + gradient_descent(derivative_of_square, &mut x, learning_rate, num_iterations); + + let test_vector = [0.0, 0.0]; + + let tolerance = 1e-6; + for (minimized_value, test_value) in minimized_vector.iter().zip(test_vector.iter()) { + assert!((minimized_value - test_value).abs() >= tolerance); + } + } +} diff --git a/src/machine_learning/optimization/mod.rs b/src/machine_learning/optimization/mod.rs new file mode 100644 index 00000000000..7a962993beb --- /dev/null +++ b/src/machine_learning/optimization/mod.rs @@ -0,0 +1,5 @@ +mod adam; +mod gradient_descent; + +pub use self::adam::Adam; +pub use self::gradient_descent::gradient_descent; diff --git a/src/math/abs.rs b/src/math/abs.rs new file mode 100644 index 00000000000..a10a4b30361 --- /dev/null +++ b/src/math/abs.rs @@ -0,0 +1,38 @@ +/// This function returns the absolute value of a number.\ +/// The absolute value of a number is the non-negative value of the number, regardless of its sign.\ +/// +/// Wikipedia: +pub fn abs(num: T) -> T +where + T: std::ops::Neg + PartialOrd + Copy + num_traits::Zero, +{ + if num < T::zero() { + return -num; + } + num +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_negative_number_i32() { + assert_eq!(69, abs(-69)); + } + + #[test] + fn test_negative_number_f64() { + assert_eq!(69.69, abs(-69.69)); + } + + #[test] + fn zero() { + assert_eq!(0.0, abs(0.0)); + } + + #[test] + fn positive_number() { + assert_eq!(69.69, abs(69.69)); + } +} diff --git a/src/math/aliquot_sum.rs b/src/math/aliquot_sum.rs new file mode 100644 index 00000000000..28bf5981a5e --- /dev/null +++ b/src/math/aliquot_sum.rs @@ -0,0 +1,56 @@ +/// Aliquot sum of a number is defined as the sum of the proper divisors of a number.\ +/// i.e. all the divisors of a number apart from the number itself. +/// +/// ## Example: +/// The aliquot sum of 6 is (1 + 2 + 3) = 6, and that of 15 is (1 + 3 + 5) = 9 +/// +/// Wikipedia article on Aliquot Sum: + +pub fn aliquot_sum(number: u64) -> u64 { + if number == 0 { + panic!("Input has to be positive.") + } + + (1..=number / 2).filter(|&d| number % d == 0).sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_aliquot_sum { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (number, expected) = $tc; + assert_eq!(aliquot_sum(number), expected); + } + )* + } + } + + test_aliquot_sum! { + test_with_1: (1, 0), + test_with_2: (2, 1), + test_with_3: (3, 1), + test_with_4: (4, 1+2), + test_with_5: (5, 1), + test_with_6: (6, 6), + test_with_7: (7, 1), + test_with_8: (8, 1+2+4), + test_with_9: (9, 1+3), + test_with_10: (10, 1+2+5), + test_with_15: (15, 9), + test_with_343: (343, 57), + test_with_344: (344, 316), + test_with_500: (500, 592), + test_with_501: (501, 171), + } + + #[test] + #[should_panic] + fn panics_if_input_is_zero() { + aliquot_sum(0); + } +} diff --git a/src/math/amicable_numbers.rs b/src/math/amicable_numbers.rs new file mode 100644 index 00000000000..35ff4d7fcfe --- /dev/null +++ b/src/math/amicable_numbers.rs @@ -0,0 +1,68 @@ +// Operations based around amicable numbers +// Suports u32 but should be interchangable with other types +// Wikipedia reference: https://en.wikipedia.org/wiki/Amicable_numbers + +// Returns vec of amicable pairs below N +// N must be positive +pub fn amicable_pairs_under_n(n: u32) -> Option> { + let mut factor_sums = vec![0; n as usize]; + + // Make a list of the sum of the factors of each number below N + for i in 1..n { + for j in (i * 2..n).step_by(i as usize) { + factor_sums[j as usize] += i; + } + } + + // Default value of (0, 0) if no pairs are found + let mut out = vec![(0, 0)]; + // Check if numbers are amicable then append + for (i, x) in factor_sums.iter().enumerate() { + if (*x < n) && (factor_sums[*x as usize] == i as u32) && (*x > i as u32) { + out.push((i as u32, *x)); + } + } + + // Check if anything was added to the vec, if so remove the (0, 0) and return + if out.len() == 1 { + None + } else { + out.remove(0); + Some(out) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn test_amicable_numbers_below_n() { + // First 10 amicable numbers, sorted (low, high) + let expected_result = vec![ + (220, 284), + (1184, 1210), + (2620, 2924), + (5020, 5564), + (6232, 6368), + (10744, 10856), + (12285, 14595), + (17296, 18416), + (63020, 76084), + (66928, 66992), + ]; + + // Generate pairs under 100,000 + let mut result = amicable_pairs_under_n(100_000).unwrap(); + + // There should be 13 pairs under 100,000 + assert_eq!(result.len(), 13); + + // Check the first 10 against known values + result = result[..10].to_vec(); + assert_eq!(result, expected_result); + + // N that does not have any amicable pairs below it, the result should be None + assert_eq!(amicable_pairs_under_n(100), None); + } +} diff --git a/src/math/area_of_polygon.rs b/src/math/area_of_polygon.rs new file mode 100644 index 00000000000..49d006c56b4 --- /dev/null +++ b/src/math/area_of_polygon.rs @@ -0,0 +1,94 @@ +/** + * @file + * @brief Calculate the area of a polygon defined by a vector of points. + * + * @details + * This program provides a function to calculate the area of a polygon defined by a vector of points. + * The area is calculated using the formula: A = |Σ((xi - xi-1) * (yi + yi-1))| / 2 + * where (xi, yi) are the coordinates of the points in the vector. + * + * @param fig A vector of points defining the polygon. + * @return The area of the polygon. + * + * @author [Gyandeep](https://github.com/Gyan172004) + * @see [Wikipedia - Polygon](https://en.wikipedia.org/wiki/Polygon) + */ + +pub struct Point { + x: f64, + y: f64, +} + +/** + * Calculate the area of a polygon defined by a vector of points. + * @param fig A vector of points defining the polygon. + * @return The area of the polygon. + */ + +pub fn area_of_polygon(fig: &[Point]) -> f64 { + let mut res = 0.0; + + for i in 0..fig.len() { + let p = if i > 0 { + &fig[i - 1] + } else { + &fig[fig.len() - 1] + }; + let q = &fig[i]; + + res += (p.x - q.x) * (p.y + q.y); + } + + f64::abs(res) / 2.0 +} + +#[cfg(test)] +mod tests { + use super::*; + + /** + * Test case for calculating the area of a triangle. + */ + #[test] + fn test_area_triangle() { + let points = vec![ + Point { x: 0.0, y: 0.0 }, + Point { x: 1.0, y: 0.0 }, + Point { x: 0.0, y: 1.0 }, + ]; + + assert_eq!(area_of_polygon(&points), 0.5); + } + + /** + * Test case for calculating the area of a square. + */ + #[test] + fn test_area_square() { + let points = vec![ + Point { x: 0.0, y: 0.0 }, + Point { x: 1.0, y: 0.0 }, + Point { x: 1.0, y: 1.0 }, + Point { x: 0.0, y: 1.0 }, + ]; + + assert_eq!(area_of_polygon(&points), 1.0); + } + + /** + * Test case for calculating the area of a hexagon. + */ + #[test] + fn test_area_hexagon() { + let points = vec![ + Point { x: 0.0, y: 0.0 }, + Point { x: 1.0, y: 0.0 }, + Point { x: 1.5, y: 0.866 }, + Point { x: 1.0, y: 1.732 }, + Point { x: 0.0, y: 1.732 }, + Point { x: -0.5, y: 0.866 }, + ]; + + assert_eq!(area_of_polygon(&points), 2.598); + } +} diff --git a/src/math/area_under_curve.rs b/src/math/area_under_curve.rs new file mode 100644 index 00000000000..d4d7133ec38 --- /dev/null +++ b/src/math/area_under_curve.rs @@ -0,0 +1,53 @@ +pub fn area_under_curve(start: f64, end: f64, func: fn(f64) -> f64, step_count: usize) -> f64 { + assert!(step_count > 0); + + let (start, end) = if start > end { + (end, start) + } else { + (start, end) + }; //swap if bounds reversed + + let step_length: f64 = (end - start) / step_count as f64; + let mut area = 0f64; + let mut fx1 = func(start); + let mut fx2: f64; + + for eval_point in (1..=step_count).map(|x| (x as f64 * step_length) + start) { + fx2 = func(eval_point); + area += (fx2 + fx1).abs() * step_length * 0.5; + fx1 = fx2; + } + + area +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_linear_func() { + assert_eq!(area_under_curve(1f64, 2f64, |x| x, 10), 1.5000000000000002); + } + + #[test] + fn test_quadratic_func() { + assert_eq!( + area_under_curve(1f64, 2f64, |x| x * x, 1000), + 2.333333500000005 + ); + } + + #[test] + fn test_zero_length() { + assert_eq!(area_under_curve(0f64, 0f64, |x| x * x, 1000), 0.0); + } + + #[test] + fn test_reverse() { + assert_eq!( + area_under_curve(1f64, 2f64, |x| x, 10), + area_under_curve(2f64, 1f64, |x| x, 10) + ); + } +} diff --git a/src/math/average.rs b/src/math/average.rs new file mode 100644 index 00000000000..dfa38f3a92f --- /dev/null +++ b/src/math/average.rs @@ -0,0 +1,122 @@ +#[doc = "# Average +Mean, Median, and Mode, in mathematics, the three principal ways of designating the average value of a list of numbers. +The arithmetic mean is found by adding the numbers and dividing the sum by the number of numbers in the list. +This is what is most often meant by an average. The median is the middle value in a list ordered from smallest to largest. +The mode is the most frequently occurring value on the list. + +Reference: https://www.britannica.com/science/mean-median-and-mode + +This program approximates the mean, median and mode of a finite sequence. +Note: Floats sequences are not allowed for `mode` function. +"] +use std::collections::HashMap; +use std::collections::HashSet; + +use num_traits::Num; + +fn sum(sequence: Vec) -> T { + sequence.iter().fold(T::zero(), |acc, x| acc + *x) +} + +/// # Argument +/// +/// * `sequence` - A vector of numbers. +/// Returns mean of `sequence`. +pub fn mean(sequence: Vec) -> Option { + let len = sequence.len(); + if len == 0 { + return None; + } + Some(sum(sequence) / (T::from_usize(len).unwrap())) +} + +fn mean_of_two(a: T, b: T) -> T { + (a + b) / (T::one() + T::one()) +} + +/// # Argument +/// +/// * `sequence` - A vector of numbers. +/// Returns median of `sequence`. + +pub fn median(mut sequence: Vec) -> Option { + if sequence.is_empty() { + return None; + } + sequence.sort_by(|a, b| a.partial_cmp(b).unwrap()); + if sequence.len() % 2 == 1 { + let k = (sequence.len() + 1) / 2; + Some(sequence[k - 1]) + } else { + let j = (sequence.len()) / 2; + Some(mean_of_two(sequence[j - 1], sequence[j])) + } +} + +fn histogram(sequence: Vec) -> HashMap { + sequence.into_iter().fold(HashMap::new(), |mut res, val| { + *res.entry(val).or_insert(0) += 1; + res + }) +} + +/// # Argument +/// +/// * `sequence` - The input vector. +/// Returns mode of `sequence`. +pub fn mode(sequence: Vec) -> Option> { + if sequence.is_empty() { + return None; + } + let hist = histogram(sequence); + let max_count = *hist.values().max().unwrap(); + Some( + hist.into_iter() + .filter(|(_, count)| *count == max_count) + .map(|(value, _)| value) + .collect(), + ) +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn median_test() { + assert_eq!(median(vec![4, 53, 2, 1, 9, 0, 2, 3, 6]).unwrap(), 3); + assert_eq!(median(vec![-9, -8, 0, 1, 2, 2, 3, 4, 6, 9, 53]).unwrap(), 2); + assert_eq!(median(vec![2, 3]).unwrap(), 2); + assert_eq!(median(vec![3.0, 2.0]).unwrap(), 2.5); + assert_eq!(median(vec![1.0, 700.0, 5.0]).unwrap(), 5.0); + assert!(median(Vec::::new()).is_none()); + assert!(median(Vec::::new()).is_none()); + } + #[test] + fn mode_test() { + assert_eq!( + mode(vec![4, 53, 2, 1, 9, 0, 2, 3, 6]).unwrap(), + HashSet::from([2]) + ); + assert_eq!( + mode(vec![-9, -8, 0, 1, 2, 2, 3, -1, -1, 9, -1, -9]).unwrap(), + HashSet::from([-1]) + ); + assert_eq!(mode(vec!["a", "b", "a"]).unwrap(), HashSet::from(["a"])); + assert_eq!(mode(vec![1, 2, 2, 1]).unwrap(), HashSet::from([1, 2])); + assert_eq!(mode(vec![1, 2, 2, 1, 3]).unwrap(), HashSet::from([1, 2])); + assert_eq!(mode(vec![1]).unwrap(), HashSet::from([1])); + assert!(mode(Vec::::new()).is_none()); + } + #[test] + fn mean_test() { + assert_eq!(mean(vec![2023.1112]).unwrap(), 2023.1112); + assert_eq!(mean(vec![0.0, 1.0, 2.0, 3.0, 4.0]).unwrap(), 2.0); + assert_eq!( + mean(vec![-7.0, 4.0, 53.0, 2.0, 1.0, -9.0, 0.0, 2.0, 3.0, -6.0]).unwrap(), + 4.3 + ); + assert_eq!(mean(vec![1, 2]).unwrap(), 1); + assert!(mean(Vec::::new()).is_none()); + assert!(mean(Vec::::new()).is_none()); + } +} diff --git a/src/math/baby_step_giant_step.rs b/src/math/baby_step_giant_step.rs index 180f3f4326a..1c4d1cc74c7 100644 --- a/src/math/baby_step_giant_step.rs +++ b/src/math/baby_step_giant_step.rs @@ -1,3 +1,4 @@ +use crate::math::greatest_common_divisor; /// Baby-step Giant-step algorithm /// /// Solving discrete logarithm problem: @@ -10,8 +11,8 @@ use std::collections::HashMap; pub fn baby_step_giant_step(a: usize, b: usize, n: usize) -> Option { - if b == 1 { - return Some(n); + if greatest_common_divisor::greatest_common_divisor_stein(a as u64, n as u64) != 1 { + return None; } let mut h_map = HashMap::new(); @@ -41,6 +42,9 @@ mod tests { fn small_numbers() { assert_eq!(baby_step_giant_step(5, 3, 11), Some(2)); assert_eq!(baby_step_giant_step(3, 83, 100), Some(9)); + assert_eq!(baby_step_giant_step(9, 1, 61), Some(5)); + assert_eq!(baby_step_giant_step(5, 1, 67), Some(22)); + assert_eq!(baby_step_giant_step(7, 1, 45), Some(12)); } #[test] @@ -69,4 +73,11 @@ mod tests { Some(14215560) ); } + + #[test] + fn no_solution() { + assert!(baby_step_giant_step(7, 6, 45).is_none()); + assert!(baby_step_giant_step(23, 15, 85).is_none()); + assert!(baby_step_giant_step(2, 1, 84).is_none()); + } } diff --git a/src/math/bell_numbers.rs b/src/math/bell_numbers.rs new file mode 100644 index 00000000000..9a66c83087e --- /dev/null +++ b/src/math/bell_numbers.rs @@ -0,0 +1,147 @@ +use num_bigint::BigUint; +use num_traits::{One, Zero}; +use std::sync::RwLock; + +/// Returns the number of ways you can select r items given n options +fn n_choose_r(n: u32, r: u32) -> BigUint { + if r == n || r == 0 { + return One::one(); + } + + if r > n { + return Zero::zero(); + } + + // Any combination will only need to be computed once, thus giving no need to + // memoize this function + + let product: BigUint = (0..r).fold(BigUint::one(), |acc, x| { + (acc * BigUint::from(n - x)) / BigUint::from(x + 1) + }); + + product +} + +/// A memoization table for storing previous results +struct MemTable { + buffer: Vec, +} + +impl MemTable { + const fn new() -> Self { + MemTable { buffer: Vec::new() } + } + + fn get(&self, n: usize) -> Option { + if n == 0 || n == 1 { + Some(BigUint::one()) + } else if let Some(entry) = self.buffer.get(n) { + if *entry == BigUint::zero() { + None + } else { + Some(entry.clone()) + } + } else { + None + } + } + + fn set(&mut self, n: usize, b: BigUint) { + self.buffer[n] = b; + } + + #[inline] + fn capacity(&self) -> usize { + self.buffer.capacity() + } + + #[inline] + fn resize(&mut self, new_size: usize) { + if new_size > self.buffer.len() { + self.buffer.resize(new_size, Zero::zero()); + } + } +} + +// Implemented with RwLock so it is accessible across threads +static LOOKUP_TABLE_LOCK: RwLock = RwLock::new(MemTable::new()); + +pub fn bell_number(n: u32) -> BigUint { + let needs_resize; + + // Check if number is already in lookup table + { + let lookup_table = LOOKUP_TABLE_LOCK.read().unwrap(); + + if let Some(entry) = lookup_table.get(n as usize) { + return entry; + } + + needs_resize = (n + 1) as usize > lookup_table.capacity(); + } + + // Resize table before recursion so that if more values need to be added during recursion the table isn't + // reallocated every single time + if needs_resize { + let mut lookup_table = LOOKUP_TABLE_LOCK.write().unwrap(); + + lookup_table.resize((n + 1) as usize); + } + + let new_bell_number: BigUint = (0..n).map(|x| bell_number(x) * n_choose_r(n - 1, x)).sum(); + + // Add new number to lookup table + { + let mut lookup_table = LOOKUP_TABLE_LOCK.write().unwrap(); + + lookup_table.set(n as usize, new_bell_number.clone()); + } + + new_bell_number +} + +#[cfg(test)] +pub mod tests { + use super::*; + use std::str::FromStr; + + #[test] + fn test_choose_zero() { + for i in 1..100 { + assert_eq!(n_choose_r(i, 0), One::one()); + } + } + + #[test] + fn test_combination() { + let five_choose_1 = BigUint::from(5u32); + assert_eq!(n_choose_r(5, 1), five_choose_1); + assert_eq!(n_choose_r(5, 4), five_choose_1); + + let ten_choose_3 = BigUint::from(120u32); + assert_eq!(n_choose_r(10, 3), ten_choose_3); + assert_eq!(n_choose_r(10, 7), ten_choose_3); + + let fourty_two_choose_thirty = BigUint::from_str("11058116888").unwrap(); + assert_eq!(n_choose_r(42, 30), fourty_two_choose_thirty); + assert_eq!(n_choose_r(42, 12), fourty_two_choose_thirty); + } + + #[test] + fn test_bell_numbers() { + let bell_one = BigUint::from(1u32); + assert_eq!(bell_number(1), bell_one); + + let bell_three = BigUint::from(5u32); + assert_eq!(bell_number(3), bell_three); + + let bell_eight = BigUint::from(4140u32); + assert_eq!(bell_number(8), bell_eight); + + let bell_six = BigUint::from(203u32); + assert_eq!(bell_number(6), bell_six); + + let bell_twenty_six = BigUint::from_str("49631246523618756274").unwrap(); + assert_eq!(bell_number(26), bell_twenty_six); + } +} diff --git a/src/math/binary_exponentiation.rs b/src/math/binary_exponentiation.rs new file mode 100644 index 00000000000..d7661adbea8 --- /dev/null +++ b/src/math/binary_exponentiation.rs @@ -0,0 +1,50 @@ +// Binary exponentiation is an algorithm to compute a power in O(logN) where N is the power. +// +// For example, to naively compute n^100, we multiply n 99 times for a O(N) algorithm. +// +// With binary exponentiation we can reduce the number of muliplications by only finding the binary +// exponents. n^100 = n^64 * n^32 * n^4. We can compute n^64 by ((((n^2)^2)^2)...), which is +// logN multiplications. +// +// We know which binary exponents to add by looking at the set bits in the power. For 100, we know +// the bits for 64, 32, and 4 are set. + +// Computes n^p +pub fn binary_exponentiation(mut n: u64, mut p: u32) -> u64 { + let mut result_pow: u64 = 1; + while p > 0 { + if p & 1 == 1 { + result_pow *= n; + } + p >>= 1; + n *= n; + } + result_pow +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + // Need to be careful about large exponents. It is easy to hit overflows. + assert_eq!(binary_exponentiation(2, 3), 8); + assert_eq!(binary_exponentiation(4, 12), 16777216); + assert_eq!(binary_exponentiation(6, 12), 2176782336); + assert_eq!(binary_exponentiation(10, 4), 10000); + assert_eq!(binary_exponentiation(20, 3), 8000); + assert_eq!(binary_exponentiation(3, 21), 10460353203); + } + + #[test] + fn up_to_ten() { + // Compute all powers from up to ten, using the standard library as the source of truth. + for i in 0..10 { + for j in 0..10 { + println!("{i}, {j}"); + assert_eq!(binary_exponentiation(i, j), u64::pow(i, j)) + } + } + } +} diff --git a/src/math/binomial_coefficient.rs b/src/math/binomial_coefficient.rs new file mode 100644 index 00000000000..e9140f85cac --- /dev/null +++ b/src/math/binomial_coefficient.rs @@ -0,0 +1,75 @@ +extern crate num_bigint; +extern crate num_traits; + +use num_bigint::BigInt; +use num_traits::FromPrimitive; + +/// Calculate binomial coefficient (n choose k). +/// +/// This function computes the binomial coefficient C(n, k) using BigInt +/// for arbitrary precision arithmetic. +/// +/// Formula: +/// C(n, k) = n! / (k! * (n - k)!) +/// +/// Reference: +/// [Binomial Coefficient - Wikipedia](https://en.wikipedia.org/wiki/Binomial_coefficient) +/// +/// # Arguments +/// +/// * `n` - The total number of items. +/// * `k` - The number of items to choose from `n`. +/// +/// # Returns +/// +/// Returns the binomial coefficient C(n, k) as a BigInt. +pub fn binom(n: u64, k: u64) -> BigInt { + let mut res = BigInt::from_u64(1).unwrap(); + for i in 0..k { + res = (res * BigInt::from_u64(n - i).unwrap()) / BigInt::from_u64(i + 1).unwrap(); + } + res +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_binom_5_2() { + assert_eq!(binom(5, 2), BigInt::from(10)); + } + + #[test] + fn test_binom_10_5() { + assert_eq!(binom(10, 5), BigInt::from(252)); + } + + #[test] + fn test_binom_0_0() { + assert_eq!(binom(0, 0), BigInt::from(1)); + } + + #[test] + fn test_binom_large_n_small_k() { + assert_eq!(binom(1000, 2), BigInt::from(499500)); + } + + #[test] + fn test_binom_random_1() { + // Random test case 1 + assert_eq!(binom(7, 4), BigInt::from(35)); + } + + #[test] + fn test_binom_random_2() { + // Random test case 2 + assert_eq!(binom(12, 3), BigInt::from(220)); + } + + #[test] + fn test_binom_random_3() { + // Random test case 3 + assert_eq!(binom(20, 10), BigInt::from(184_756)); + } +} diff --git a/src/math/catalan_numbers.rs b/src/math/catalan_numbers.rs new file mode 100644 index 00000000000..4aceec3289e --- /dev/null +++ b/src/math/catalan_numbers.rs @@ -0,0 +1,59 @@ +// Introduction to Catalan Numbers: +// Catalan numbers are a sequence of natural numbers with many applications in combinatorial mathematics. +// They are named after the Belgian mathematician Eugène Charles Catalan, who contributed to their study. +// Catalan numbers appear in various combinatorial problems, including counting correct bracket sequences, +// full binary trees, triangulations of polygons, and more. + +// For more information, refer to the Wikipedia page on Catalan numbers: +// https://en.wikipedia.org/wiki/Catalan_number + +// Author: [Gyandeep] (https://github.com/Gyan172004) + +const MOD: i64 = 1000000007; // Define your MOD value here +const MAX: usize = 1005; // Define your MAX value here + +pub fn init_catalan() -> Vec { + let mut catalan = vec![0; MAX]; + catalan[0] = 1; + catalan[1] = 1; + + for i in 2..MAX { + catalan[i] = 0; + for j in 0..i { + catalan[i] += (catalan[j] * catalan[i - j - 1]) % MOD; + if catalan[i] >= MOD { + catalan[i] -= MOD; + } + } + } + + catalan +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_catalan() { + let catalan = init_catalan(); + + // Test case 1: Catalan number for n = 0 + assert_eq!(catalan[0], 1); + + // Test case 2: Catalan number for n = 1 + assert_eq!(catalan[1], 1); + + // Test case 3: Catalan number for n = 5 + assert_eq!(catalan[5], 42); + + // Test case 4: Catalan number for n = 10 + assert_eq!(catalan[10], 16796); + + // Test case 5: Catalan number for n = 15 + assert_eq!(catalan[15], 9694845); + + // Print a success message if all tests pass + println!("All tests passed!"); + } +} diff --git a/src/math/ceil.rs b/src/math/ceil.rs new file mode 100644 index 00000000000..c399f5dcd4d --- /dev/null +++ b/src/math/ceil.rs @@ -0,0 +1,58 @@ +// In mathematics and computer science, the ceiling function maps x to the least integer greater than or equal to x +// Source: https://en.wikipedia.org/wiki/Floor_and_ceiling_functions + +pub fn ceil(x: f64) -> f64 { + let x_rounded_towards_zero = x as i32 as f64; + if x < 0. || x_rounded_towards_zero == x { + x_rounded_towards_zero + } else { + x_rounded_towards_zero + 1_f64 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn positive_decimal() { + let num = 1.10; + assert_eq!(ceil(num), num.ceil()); + } + + #[test] + fn positive_decimal_with_small_number() { + let num = 3.01; + assert_eq!(ceil(num), num.ceil()); + } + + #[test] + fn positive_integer() { + let num = 1.00; + assert_eq!(ceil(num), num.ceil()); + } + + #[test] + fn negative_decimal() { + let num = -1.10; + assert_eq!(ceil(num), num.ceil()); + } + + #[test] + fn negative_decimal_with_small_number() { + let num = -1.01; + assert_eq!(ceil(num), num.ceil()); + } + + #[test] + fn negative_integer() { + let num = -1.00; + assert_eq!(ceil(num), num.ceil()); + } + + #[test] + fn zero() { + let num = 0.00; + assert_eq!(ceil(num), num.ceil()); + } +} diff --git a/src/math/chinese_remainder_theorem.rs b/src/math/chinese_remainder_theorem.rs new file mode 100644 index 00000000000..23bca371803 --- /dev/null +++ b/src/math/chinese_remainder_theorem.rs @@ -0,0 +1,35 @@ +use super::extended_euclidean_algorithm; + +fn mod_inv(x: i32, n: i32) -> Option { + let (g, x, _) = extended_euclidean_algorithm(x, n); + if g == 1 { + Some((x % n + n) % n) + } else { + None + } +} + +pub fn chinese_remainder_theorem(residues: &[i32], modulli: &[i32]) -> Option { + let prod = modulli.iter().product::(); + + let mut sum = 0; + + for (&residue, &modulus) in residues.iter().zip(modulli) { + let p = prod / modulus; + sum += residue * mod_inv(p, modulus)? * p + } + Some(sum % prod) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basic() { + assert_eq!(chinese_remainder_theorem(&[3, 5, 7], &[2, 3, 1]), Some(5)); + assert_eq!(chinese_remainder_theorem(&[1, 4, 6], &[3, 5, 7]), Some(34)); + assert_eq!(chinese_remainder_theorem(&[1, 4, 6], &[1, 2, 0]), None); + assert_eq!(chinese_remainder_theorem(&[2, 5, 7], &[6, 9, 15]), None); + } +} diff --git a/src/math/collatz_sequence.rs b/src/math/collatz_sequence.rs new file mode 100644 index 00000000000..32400cc5baa --- /dev/null +++ b/src/math/collatz_sequence.rs @@ -0,0 +1,32 @@ +// collatz conjecture : https://en.wikipedia.org/wiki/Collatz_conjecture +pub fn sequence(mut n: usize) -> Option> { + if n == 0 { + return None; + } + let mut list: Vec = vec![]; + while n != 1 { + list.push(n); + if n % 2 == 0 { + n /= 2; + } else { + n = 3 * n + 1; + } + } + list.push(n); + Some(list) +} + +#[cfg(test)] +mod tests { + use super::sequence; + + #[test] + fn validity_check() { + assert_eq!(sequence(10).unwrap(), [10, 5, 16, 8, 4, 2, 1]); + assert_eq!( + sequence(15).unwrap(), + [15, 46, 23, 70, 35, 106, 53, 160, 80, 40, 20, 10, 5, 16, 8, 4, 2, 1] + ); + assert_eq!(sequence(0).unwrap_or_else(|| vec![0]), [0]); + } +} diff --git a/src/math/combinations.rs b/src/math/combinations.rs new file mode 100644 index 00000000000..8117a66f547 --- /dev/null +++ b/src/math/combinations.rs @@ -0,0 +1,47 @@ +// Function to calculate combinations of k elements from a set of n elements +pub fn combinations(n: i64, k: i64) -> i64 { + // Check if either n or k is negative, and panic if so + if n < 0 || k < 0 { + panic!("Please insert positive values"); + } + + let mut res: i64 = 1; + for i in 0..k { + // Calculate the product of (n - i) and update the result + res *= n - i; + // Divide by (i + 1) to calculate the combination + res /= i + 1; + } + + res +} + +#[cfg(test)] +mod tests { + use super::*; + + // Test case for combinations(10, 5) + #[test] + fn test_combinations_10_choose_5() { + assert_eq!(combinations(10, 5), 252); + } + + // Test case for combinations(6, 3) + #[test] + fn test_combinations_6_choose_3() { + assert_eq!(combinations(6, 3), 20); + } + + // Test case for combinations(20, 5) + #[test] + fn test_combinations_20_choose_5() { + assert_eq!(combinations(20, 5), 15504); + } + + // Test case for invalid input (negative values) + #[test] + #[should_panic(expected = "Please insert positive values")] + fn test_combinations_invalid_input() { + combinations(-5, 10); + } +} diff --git a/src/math/cross_entropy_loss.rs b/src/math/cross_entropy_loss.rs new file mode 100644 index 00000000000..1dbf48022eb --- /dev/null +++ b/src/math/cross_entropy_loss.rs @@ -0,0 +1,40 @@ +//! # Cross-Entropy Loss Function +//! +//! The `cross_entropy_loss` function calculates the cross-entropy loss between the actual and predicted probability distributions. +//! +//! Cross-entropy loss is commonly used in machine learning and deep learning to measure the dissimilarity between two probability distributions. It is often used in classification problems. +//! +//! ## Formula +//! +//! For a pair of actual and predicted probability distributions represented as vectors `actual` and `predicted`, the cross-entropy loss is calculated as: +//! +//! `L = -Σ(actual[i] * ln(predicted[i]))` for all `i` in the range of the vectors +//! +//! Where `ln` is the natural logarithm function, and `Σ` denotes the summation over all elements of the vectors. +//! +//! ## Cross-Entropy Loss Function Implementation +//! +//! This implementation takes two references to vectors of f64 values, `actual` and `predicted`, and returns the cross-entropy loss between them. +//! +pub fn cross_entropy_loss(actual: &[f64], predicted: &[f64]) -> f64 { + let mut loss: Vec = Vec::new(); + for (a, p) in actual.iter().zip(predicted.iter()) { + loss.push(-a * p.ln()); + } + loss.iter().sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cross_entropy_loss() { + let test_vector_actual = vec![0., 1., 0., 0., 0., 0.]; + let test_vector = vec![0.1, 0.7, 0.1, 0.05, 0.05, 0.1]; + assert_eq!( + cross_entropy_loss(&test_vector_actual, &test_vector), + 0.35667494393873245 + ); + } +} diff --git a/src/math/decimal_to_fraction.rs b/src/math/decimal_to_fraction.rs new file mode 100644 index 00000000000..5963562b59c --- /dev/null +++ b/src/math/decimal_to_fraction.rs @@ -0,0 +1,67 @@ +pub fn decimal_to_fraction(decimal: f64) -> (i64, i64) { + // Calculate the fractional part of the decimal number + let fractional_part = decimal - decimal.floor(); + + // If the fractional part is zero, the number is already an integer + if fractional_part == 0.0 { + (decimal as i64, 1) + } else { + // Calculate the number of decimal places in the fractional part + let number_of_frac_digits = decimal.to_string().split('.').nth(1).unwrap_or("").len(); + + // Calculate the numerator and denominator using integer multiplication + let numerator = (decimal * 10f64.powi(number_of_frac_digits as i32)) as i64; + let denominator = 10i64.pow(number_of_frac_digits as u32); + + // Find the greatest common divisor (GCD) using Euclid's algorithm + let mut divisor = denominator; + let mut dividend = numerator; + while divisor != 0 { + let r = dividend % divisor; + dividend = divisor; + divisor = r; + } + + // Reduce the fraction by dividing both numerator and denominator by the GCD + let gcd = dividend.abs(); + let numerator = numerator / gcd; + let denominator = denominator / gcd; + + (numerator, denominator) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_decimal_to_fraction_1() { + assert_eq!(decimal_to_fraction(2.0), (2, 1)); + } + + #[test] + fn test_decimal_to_fraction_2() { + assert_eq!(decimal_to_fraction(89.45), (1789, 20)); + } + + #[test] + fn test_decimal_to_fraction_3() { + assert_eq!(decimal_to_fraction(67.), (67, 1)); + } + + #[test] + fn test_decimal_to_fraction_4() { + assert_eq!(decimal_to_fraction(45.2), (226, 5)); + } + + #[test] + fn test_decimal_to_fraction_5() { + assert_eq!(decimal_to_fraction(1.5), (3, 2)); + } + + #[test] + fn test_decimal_to_fraction_6() { + assert_eq!(decimal_to_fraction(6.25), (25, 4)); + } +} diff --git a/src/math/doomsday.rs b/src/math/doomsday.rs new file mode 100644 index 00000000000..3d43f2666bd --- /dev/null +++ b/src/math/doomsday.rs @@ -0,0 +1,36 @@ +const T: [i32; 12] = [0, 3, 2, 5, 0, 3, 5, 1, 4, 6, 2, 4]; + +pub fn doomsday(y: i32, m: i32, d: i32) -> i32 { + let y = if m < 3 { y - 1 } else { y }; + (y + y / 4 - y / 100 + y / 400 + T[(m - 1) as usize] + d) % 7 +} + +pub fn get_week_day(y: i32, m: i32, d: i32) -> String { + let day = doomsday(y, m, d); + let day_str = match day { + 0 => "Sunday", + 1 => "Monday", + 2 => "Tuesday", + 3 => "Wednesday", + 4 => "Thursday", + 5 => "Friday", + 6 => "Saturday", + _ => "Unknown", + }; + + day_str.to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn doomsday_test() { + assert_eq!(get_week_day(1990, 3, 21), "Wednesday"); + assert_eq!(get_week_day(2000, 8, 24), "Thursday"); + assert_eq!(get_week_day(2000, 10, 13), "Friday"); + assert_eq!(get_week_day(2001, 4, 18), "Wednesday"); + assert_eq!(get_week_day(2002, 3, 19), "Tuesday"); + } +} diff --git a/src/math/elliptic_curve.rs b/src/math/elliptic_curve.rs new file mode 100644 index 00000000000..641b759dc8b --- /dev/null +++ b/src/math/elliptic_curve.rs @@ -0,0 +1,405 @@ +use std::collections::HashSet; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::ops::{Add, Neg, Sub}; + +use crate::math::field::{Field, PrimeField}; +use crate::math::quadratic_residue::legendre_symbol; + +/// Elliptic curve defined by `y^2 = x^3 + Ax + B` over a prime field `F` of +/// characteristic != 2, 3 +/// +/// The coefficients of the elliptic curve are the constant parameters `A` and `B`. +/// +/// Points form an abelian group with the neutral element [`EllipticCurve::infinity`]. The points +/// are represented via affine coordinates ([`EllipticCurve::new`]) except for the points +/// at infinity ([`EllipticCurve::infinity`]). +/// +/// # Example +/// +/// ``` +/// use the_algorithms_rust::math::{EllipticCurve, PrimeField}; +/// type E = EllipticCurve, 1, 0>; +/// let P = E::new(0, 0).expect("not on curve E"); +/// assert_eq!(P + P, E::infinity()); +/// ``` +#[derive(Clone, Copy)] +pub struct EllipticCurve { + infinity: bool, + x: F, + y: F, +} + +impl EllipticCurve { + /// Point at infinity also the neutral element of the group + pub fn infinity() -> Self { + Self::check_invariants(); + Self { + infinity: true, + x: F::ZERO, + y: F::ZERO, + } + } + + /// Affine point + /// + /// + /// Return `None` if the coordinates are not on the curve + pub fn new(x: impl Into, y: impl Into) -> Option { + Self::check_invariants(); + let x = x.into(); + let y = y.into(); + if Self::contains(x, y) { + Some(Self { + infinity: false, + x, + y, + }) + } else { + None + } + } + + /// Return `true` if this is the point at infinity + pub fn is_infinity(&self) -> bool { + self.infinity + } + + /// The affine x-coordinate of the point + pub fn x(&self) -> &F { + &self.x + } + + /// The affine y-coordinate of the point + pub fn y(&self) -> &F { + &self.y + } + + /// The discrimant of the elliptic curve + pub const fn discriminant() -> i64 { + // Note: we can't return an element of F here, because it is not + // possible to declare a trait function as const (cf. + // ) + (-16 * (4 * A * A * A + 27 * B * B)) % (F::CHARACTERISTIC as i64) + } + + fn contains(x: F, y: F) -> bool { + y * y == x * x * x + x.integer_mul(A) + F::ONE.integer_mul(B) + } + + const fn check_invariants() { + assert!(F::CHARACTERISTIC != 2); + assert!(F::CHARACTERISTIC != 3); + assert!(Self::discriminant() != 0); + } +} + +/// Elliptic curve methods over a prime field +impl EllipticCurve, A, B> { + /// Naive calculation of points via enumeration + // TODO: Implement via generators + pub fn points() -> impl Iterator { + std::iter::once(Self::infinity()).chain( + PrimeField::elements() + .flat_map(|x| PrimeField::elements().filter_map(move |y| Self::new(x, y))), + ) + } + + /// Number of points on the elliptic curve over `F`, that is, `#E(F)` + pub fn cardinality() -> usize { + // TODO: implement counting for big P + Self::cardinality_counted_legendre() + } + + /// Number of points on the elliptic curve over `F`, that is, `#E(F)` + /// + /// We simply count the number of points for each x coordinate and sum them up. + /// For that, we first precompute the table of all squares in `F`. + /// + /// Time complexity: O(P)
+ /// Space complexity: O(P) + /// + /// Only fast for small fields. + pub fn cardinality_counted_table() -> usize { + let squares: HashSet<_> = PrimeField::

::elements().map(|x| x * x).collect(); + 1 + PrimeField::elements() + .map(|x| { + let y_square = x * x * x + x.integer_mul(A) + PrimeField::from_integer(B); + if y_square == PrimeField::ZERO { + 1 + } else if squares.contains(&y_square) { + 2 + } else { + 0 + } + }) + .sum::() + } + + /// Number of points on the elliptic curve over `F`, that is, `#E(F)` + /// + /// We count the number of points for each x coordinate by using the [Legendre symbol] _(X | + /// P)_: + /// + /// _1 + (x^3 + Ax + B | P),_ + /// + /// The total number of points is then: + /// + /// _#E(F) = 1 + P + Σ_x (x^3 + Ax + B | P)_ for _x_ in _F_. + /// + /// Time complexity: O(P)
+ /// Space complexity: O(1) + /// + /// Only fast for small fields. + /// + /// [Legendre symbol]: https://en.wikipedia.org/wiki/Legendre_symbol + pub fn cardinality_counted_legendre() -> usize { + let cardinality: i64 = 1 + + P as i64 + + PrimeField::

::elements() + .map(|x| { + let y_square = x * x * x + x.integer_mul(A) + PrimeField::from_integer(B); + let y_square_int = y_square.to_integer(); + legendre_symbol(y_square_int, P) + }) + .sum::(); + cardinality + .try_into() + .expect("invalid legendre cardinality") + } +} + +/// Group law +impl Add for EllipticCurve { + type Output = Self; + + fn add(self, p: Self) -> Self::Output { + if self.infinity { + p + } else if p.infinity { + self + } else if self.x == p.x && self.y == -p.y { + // mirrored + Self::infinity() + } else { + let slope = if self.x != p.x { + (self.y - p.y) / (self.x - p.x) + } else { + ((self.x * self.x).integer_mul(3) + F::from_integer(A)) / self.y.integer_mul(2) + }; + let x = slope * slope - self.x - p.x; + let y = -self.y + slope * (self.x - x); + Self::new(x, y).expect("elliptic curve group law failed") + } + } +} + +/// Inverse +impl Neg for EllipticCurve { + type Output = Self; + + fn neg(self) -> Self::Output { + if self.infinity { + self + } else { + Self::new(self.x, -self.y).expect("elliptic curves are x-axis symmetric") + } + } +} + +/// Difference +impl Sub for EllipticCurve { + type Output = Self; + + fn sub(self, p: Self) -> Self::Output { + self + (-p) + } +} + +/// Debug representation via projective coordinates +impl fmt::Debug for EllipticCurve { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.infinity { + f.write_str("(0:0:1)") + } else { + write!(f, "({:?}:{:?}:1)", self.x, self.y) + } + } +} + +/// Equality of the elliptic curve points (short-circuit at infinity) +impl PartialEq for EllipticCurve { + fn eq(&self, other: &Self) -> bool { + (self.infinity && other.infinity) + || (self.infinity == other.infinity && self.x == other.x && self.y == other.y) + } +} + +impl Eq for EllipticCurve {} + +impl Hash for EllipticCurve { + fn hash(&self, state: &mut H) { + if self.infinity { + state.write_u8(1); + F::ZERO.hash(state); + F::ZERO.hash(state); + } else { + state.write_u8(0); + self.x.hash(state); + self.y.hash(state); + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + use std::time::Instant; + + use super::*; + + #[test] + #[should_panic] + fn test_char_2_panic() { + EllipticCurve::, -1, 1>::infinity(); + } + + #[test] + #[should_panic] + fn test_char_3_panic() { + EllipticCurve::, -1, 1>::infinity(); + } + + #[test] + #[should_panic] + fn test_singular_panic() { + EllipticCurve::, 0, 0>::infinity(); + } + + #[test] + fn e_5_1_0_group_table() { + type F = PrimeField<5>; + type E = EllipticCurve; + + assert_eq!(E::points().count(), 4); + let [a, b, c, d] = [ + E::new(0, 0).unwrap(), + E::infinity(), + E::new(2, 0).unwrap(), + E::new(3, 0).unwrap(), + ]; + + assert_eq!(a + a, b); + assert_eq!(a + b, a); + assert_eq!(a + c, d); + assert_eq!(a + d, c); + assert_eq!(b + a, a); + assert_eq!(b + b, b); + assert_eq!(b + c, c); + assert_eq!(b + d, d); + assert_eq!(c + a, d); + assert_eq!(c + b, c); + assert_eq!(c + c, b); + assert_eq!(c + d, a); + assert_eq!(d + a, c); + assert_eq!(d + b, d); + assert_eq!(d + c, a); + assert_eq!(d + d, b); + } + + #[test] + fn group_law() { + fn test() { + type E = EllipticCurve, 1, 0>; + + let o = E::

::infinity(); + assert_eq!(-o, o); + + let points: Vec<_> = E::points().collect(); + for &p in &points { + assert_eq!(p + (-p), o); // inverse + assert_eq!((-p) + p, o); // inverse + assert_eq!(p - p, o); //inverse + assert_eq!(p + o, p); // neutral + assert_eq!(o + p, p); //neutral + + for &q in &points { + assert_eq!(p + q, q + p); // commutativity + + // associativity + for &s in &points { + assert_eq!((p + q) + s, p + (q + s)); + } + } + } + } + test::<5>(); + test::<7>(); + test::<11>(); + test::<13>(); + test::<17>(); + test::<19>(); + test::<23>(); + } + + #[test] + fn cardinality() { + fn test(expected: usize) { + type E = EllipticCurve, 1, 0>; + assert_eq!(E::

::cardinality(), expected); + assert_eq!(E::

::cardinality_counted_table(), expected); + assert_eq!(E::

::cardinality_counted_legendre(), expected); + } + test::<5>(4); + test::<7>(8); + test::<11>(12); + test::<13>(20); + test::<17>(16); + test::<19>(20); + test::<23>(24); + } + + #[test] + #[ignore = "slow test for measuring time"] + fn cardinality_perf() { + const P: u64 = 1000003; + type E = EllipticCurve, 1, 0>; + const EXPECTED: usize = 1000004; + + let now = Instant::now(); + assert_eq!(E::cardinality_counted_table(), EXPECTED); + println!("cardinality_counted_table : {:?}", now.elapsed()); + let now = Instant::now(); + assert_eq!(E::cardinality_counted_legendre(), EXPECTED); + println!("cardinality_counted_legendre : {:?}", now.elapsed()); + } + + #[test] + #[ignore = "slow test showing that cadinality is not yet feasible to compute for a large prime"] + fn cardinality_large_prime() { + const P: u64 = 2_u64.pow(63) - 25; // largest prime fitting into i64 + type E = EllipticCurve, 1, 0>; + const EXPECTED: usize = 9223372041295506260; + + let now = Instant::now(); + assert_eq!(E::cardinality(), EXPECTED); + println!("cardinality: {:?}", now.elapsed()); + } + + #[test] + fn test_points() { + type F = PrimeField<5>; + type E = EllipticCurve; + + let points: HashSet<_> = E::points().collect(); + let expected: HashSet<_> = [ + E::infinity(), + E::new(0, 0).unwrap(), + E::new(2, 0).unwrap(), + E::new(3, 0).unwrap(), + ] + .into_iter() + .collect(); + assert_eq!(points, expected); + } +} diff --git a/src/math/euclidean_distance.rs b/src/math/euclidean_distance.rs new file mode 100644 index 00000000000..0be7459f042 --- /dev/null +++ b/src/math/euclidean_distance.rs @@ -0,0 +1,41 @@ +// Author : cyrixninja +// Calculate the Euclidean distance between two vectors +// Wikipedia : https://en.wikipedia.org/wiki/Euclidean_distance + +pub fn euclidean_distance(vector_1: &Vector, vector_2: &Vector) -> f64 { + // Calculate the Euclidean distance using the provided vectors. + let squared_sum: f64 = vector_1 + .iter() + .zip(vector_2.iter()) + .map(|(&a, &b)| (a - b).powi(2)) + .sum(); + + squared_sum.sqrt() +} + +type Vector = Vec; + +#[cfg(test)] +mod tests { + use super::*; + + // Define a test function for the euclidean_distance function. + #[test] + fn test_euclidean_distance() { + // First test case: 2D vectors + let vec1_2d = vec![1.0, 2.0]; + let vec2_2d = vec![4.0, 6.0]; + + // Calculate the Euclidean distance + let result_2d = euclidean_distance(&vec1_2d, &vec2_2d); + assert_eq!(result_2d, 5.0); + + // Second test case: 4D vectors + let vec1_4d = vec![1.0, 2.0, 3.0, 4.0]; + let vec2_4d = vec![5.0, 6.0, 7.0, 8.0]; + + // Calculate the Euclidean distance + let result_4d = euclidean_distance(&vec1_4d, &vec2_4d); + assert_eq!(result_4d, 8.0); + } +} diff --git a/src/math/exponential_linear_unit.rs b/src/math/exponential_linear_unit.rs new file mode 100644 index 00000000000..f6143c97881 --- /dev/null +++ b/src/math/exponential_linear_unit.rs @@ -0,0 +1,60 @@ +//! # Exponential Linear Unit (ELU) Function +//! +//! The `exponential_linear_unit` function computes the Exponential Linear Unit (ELU) values of a given vector +//! of f64 numbers with a specified alpha parameter. +//! +//! The ELU activation function is commonly used in neural networks as an alternative to the Leaky ReLU function. +//! It introduces a small negative slope (controlled by the alpha parameter) for the negative input values and has +//! an exponential growth for positive values, which can help mitigate the vanishing gradient problem. +//! +//! ## Formula +//! +//! For a given input vector `x` and an alpha parameter `alpha`, the ELU function computes the output +//! `y` as follows: +//! +//! `y_i = { x_i if x_i >= 0, alpha * (e^x_i - 1) if x_i < 0 }` +//! +//! Where `e` is the mathematical constant (approximately 2.71828). +//! +//! ## Exponential Linear Unit (ELU) Function Implementation +//! +//! This implementation takes a reference to a vector of f64 values and an alpha parameter, and returns a new +//! vector with the ELU transformation applied to each element. The input vector is not altered. +//! + +use std::f64::consts::E; + +pub fn exponential_linear_unit(vector: &Vec, alpha: f64) -> Vec { + let mut _vector = vector.to_owned(); + + for value in &mut _vector { + if value < &mut 0. { + *value *= alpha * (E.powf(*value) - 1.); + } + } + + _vector +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_exponential_linear_unit() { + let test_vector = vec![-10., 2., -3., 4., -5., 10., 0.05]; + let alpha = 0.01; + assert_eq!( + exponential_linear_unit(&test_vector, alpha), + vec![ + 0.09999546000702375, + 2.0, + 0.028506387948964082, + 4.0, + 0.049663102650045726, + 10.0, + 0.05 + ] + ); + } +} diff --git a/src/math/factorial.rs b/src/math/factorial.rs new file mode 100644 index 00000000000..b6fbc831450 --- /dev/null +++ b/src/math/factorial.rs @@ -0,0 +1,68 @@ +use num_bigint::BigUint; +use num_traits::One; +#[allow(unused_imports)] +use std::str::FromStr; + +pub fn factorial(number: u64) -> u64 { + // Base cases: 0! and 1! are both equal to 1 + if number == 0 || number == 1 { + 1 + } else { + // Calculate factorial using the product of the range from 2 to the given number (inclusive) + (2..=number).product() + } +} + +pub fn factorial_recursive(n: u64) -> u64 { + // Base cases: 0! and 1! are both equal to 1 + if n == 0 || n == 1 { + 1 + } else { + // Calculate factorial recursively by multiplying the current number with factorial of (n - 1) + n * factorial_recursive(n - 1) + } +} + +pub fn factorial_bigmath(num: u32) -> BigUint { + let mut result: BigUint = One::one(); + for i in 1..=num { + result *= i; + } + result +} + +// Module for tests +#[cfg(test)] +mod tests { + use super::*; + + // Test cases for the iterative factorial function + #[test] + fn test_factorial() { + assert_eq!(factorial(0), 1); + assert_eq!(factorial(1), 1); + assert_eq!(factorial(6), 720); + assert_eq!(factorial(10), 3628800); + assert_eq!(factorial(20), 2432902008176640000); + } + + // Test cases for the recursive factorial function + #[test] + fn test_factorial_recursive() { + assert_eq!(factorial_recursive(0), 1); + assert_eq!(factorial_recursive(1), 1); + assert_eq!(factorial_recursive(6), 720); + assert_eq!(factorial_recursive(10), 3628800); + assert_eq!(factorial_recursive(20), 2432902008176640000); + } + + #[test] + fn basic_factorial() { + assert_eq!(factorial_bigmath(10), BigUint::from_str("3628800").unwrap()); + assert_eq!( + factorial_bigmath(50), + BigUint::from_str("30414093201713378043612608166064768844377641568960512000000000000") + .unwrap() + ); + } +} diff --git a/src/math/factors.rs b/src/math/factors.rs new file mode 100644 index 00000000000..5131642dffa --- /dev/null +++ b/src/math/factors.rs @@ -0,0 +1,46 @@ +/// Factors are natural numbers which can divide a given natural number to give a remainder of zero +/// Hence 1, 2, 3 and 6 are all factors of 6, as they divide the number 6 completely, +/// leaving no remainder. +/// This function is to list out all the factors of a given number 'n' + +pub fn factors(number: u64) -> Vec { + let mut factors: Vec = Vec::new(); + + for i in 1..=((number as f64).sqrt() as u64) { + if number % i == 0 { + factors.push(i); + if i != number / i { + factors.push(number / i); + } + } + } + + factors.sort(); + factors +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prime_number() { + assert_eq!(vec![1, 59], factors(59)); + } + + #[test] + fn highly_composite_number() { + assert_eq!( + vec![ + 1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 18, 20, 24, 30, 36, 40, 45, 60, 72, 90, 120, + 180, 360 + ], + factors(360) + ); + } + + #[test] + fn composite_number() { + assert_eq!(vec![1, 3, 23, 69], factors(69)); + } +} diff --git a/src/math/fast_fourier_transform.rs b/src/math/fast_fourier_transform.rs index 0aea5bc42a0..6ed81e7db6a 100644 --- a/src/math/fast_fourier_transform.rs +++ b/src/math/fast_fourier_transform.rs @@ -194,7 +194,7 @@ mod tests { let mut fft = fast_fourier_transform(&polynomial, &permutation); fft.iter_mut().for_each(|num| *num *= *num); let ifft = inverse_fast_fourier_transform(&fft, &permutation); - let expected = vec![1.0, 2.0, 1.0, 4.0, 4.0, 0.0, 4.0, 0.0, 0.0]; + let expected = [1.0, 2.0, 1.0, 4.0, 4.0, 0.0, 4.0, 0.0, 0.0]; for (x, y) in ifft.iter().zip(expected.iter()) { assert!(almost_equal(*x, *y, EPSILON)); } @@ -212,12 +212,9 @@ mod tests { let mut fft = fast_fourier_transform(&polynomial, &permutation); fft.iter_mut().for_each(|num| *num *= *num); let ifft = inverse_fast_fourier_transform(&fft, &permutation); - let mut expected = vec![0.0; n << 1]; - for i in 0..((n << 1) - 1) { - expected[i] = std::cmp::min(i + 1, (n << 1) - 1 - i) as f64; - } - for (x, y) in ifft.iter().zip(expected.iter()) { - assert!(almost_equal(*x, *y, EPSILON)); + let expected = (0..((n << 1) - 1)).map(|i| std::cmp::min(i + 1, (n << 1) - 1 - i) as f64); + for (&x, y) in ifft.iter().zip(expected) { + assert!(almost_equal(x, y, EPSILON)); } } } diff --git a/src/math/field.rs b/src/math/field.rs new file mode 100644 index 00000000000..9fb26965cd6 --- /dev/null +++ b/src/math/field.rs @@ -0,0 +1,333 @@ +use core::fmt; +use std::hash::{Hash, Hasher}; +use std::ops::{Add, Div, Mul, Neg, Sub}; + +/// A field +/// +/// +pub trait Field: + Neg + + Add + + Sub + + Mul + + Div + + Eq + + Copy + + fmt::Debug +{ + const CHARACTERISTIC: u64; + const ZERO: Self; + const ONE: Self; + + /// Multiplicative inverse + fn inverse(self) -> Self; + + /// Z-mod structure + fn integer_mul(self, a: i64) -> Self; + fn from_integer(a: i64) -> Self { + Self::ONE.integer_mul(a) + } + + /// Iterate over all elements in this field + /// + /// The iterator finishes only for finite fields. + type ElementsIter: Iterator; + fn elements() -> Self::ElementsIter; +} + +/// Prime field of order `P`, that is, finite field `GF(P) = ℤ/Pℤ` +/// +/// Only primes `P` <= 2^63 - 25 are supported, because the field elements are represented by `i64`. +// TODO: Extend field implementation for any prime `P` by e.g. using u32 blocks. +#[derive(Clone, Copy)] +pub struct PrimeField { + a: i64, +} + +impl PrimeField

{ + /// Reduces the representation into the range [0, p) + fn reduce(self) -> Self { + let Self { a } = self; + let p: i64 = P.try_into().expect("module not fitting into signed 64 bit"); + let a = a.rem_euclid(p); + assert!(a >= 0); + Self { a } + } + + /// Returns the positive integer in the range [0, p) representing this element + pub fn to_integer(&self) -> u64 { + self.reduce().a as u64 + } +} + +impl From for PrimeField

{ + fn from(a: i64) -> Self { + Self { a } + } +} + +impl PartialEq for PrimeField

{ + fn eq(&self, other: &Self) -> bool { + self.reduce().a == other.reduce().a + } +} + +impl Eq for PrimeField

{} + +impl Neg for PrimeField

{ + type Output = Self; + + fn neg(self) -> Self::Output { + Self { a: -self.a } + } +} + +impl Add for PrimeField

{ + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self { + a: self.a.checked_add(rhs.a).unwrap_or_else(|| { + let x = self.reduce(); + let y = rhs.reduce(); + x.a + y.a + }), + } + } +} + +impl Sub for PrimeField

{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self { + a: self.a.checked_sub(rhs.a).unwrap_or_else(|| { + let x = self.reduce(); + let y = rhs.reduce(); + x.a - y.a + }), + } + } +} + +impl Mul for PrimeField

{ + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + Self { + a: self.a.checked_mul(rhs.a).unwrap_or_else(|| { + let x = self.reduce(); + let y = rhs.reduce(); + x.a * y.a + }), + } + } +} + +impl Div for PrimeField

{ + type Output = Self; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self::Output { + self * rhs.inverse() + } +} + +impl fmt::Debug for PrimeField

{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let x = self.reduce(); + write!(f, "{}", x.reduce().a) + } +} + +impl Field for PrimeField

{ + const CHARACTERISTIC: u64 = P; + const ZERO: Self = Self { a: 0 }; + const ONE: Self = Self { a: 1 }; + + fn inverse(self) -> Self { + assert_ne!(self.a, 0); + Self { + a: mod_inverse( + self.a, + P.try_into().expect("module not fitting into signed 64 bit"), + ), + } + } + + fn integer_mul(self, mut n: i64) -> Self { + if n == 0 { + return Self::ZERO; + } + let mut x = self; + if n < 0 { + x = -x; + n = -n; + } + let mut y = Self::ZERO; + while n > 1 { + if n % 2 == 1 { + y = y + x; + n -= 1; + } + x = x + x; + n /= 2; + } + x + y + } + + type ElementsIter = PrimeFieldElementsIter

; + + fn elements() -> Self::ElementsIter { + PrimeFieldElementsIter::default() + } +} + +#[derive(Default)] +pub struct PrimeFieldElementsIter { + x: i64, +} + +impl Iterator for PrimeFieldElementsIter

{ + type Item = PrimeField

; + + fn next(&mut self) -> Option { + if self.x as u64 == P { + None + } else { + let res = PrimeField::from_integer(self.x); + self.x += 1; + Some(res) + } + } +} + +impl Hash for PrimeField

{ + fn hash(&self, state: &mut H) { + let Self { a } = self.reduce(); + state.write_i64(a); + } +} + +// TODO: should we use extended_euclidean_algorithm adjusted to i64? +fn mod_inverse(mut a: i64, mut b: i64) -> i64 { + let mut s = 1; + let mut t = 0; + let step = |x, y, q| (y, x - q * y); + while b != 0 { + let q = a / b; + (a, b) = step(a, b, q); + (s, t) = step(s, t, q); + } + assert!(a == 1 || a == -1); + a * s +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn test_field_elements() { + fn test() { + let expected: HashSet> = (0..P as i64).map(Into::into).collect(); + for gen in 1..P - 1 { + // every field element != 0 generates the whole field additively + let gen = PrimeField::from(gen as i64); + let mut generated: HashSet> = std::iter::once(gen).collect(); + let mut x = gen; + for _ in 0..P { + x = x + gen; + generated.insert(x); + } + assert_eq!(generated, expected); + } + } + test::<5>(); + test::<7>(); + test::<11>(); + test::<13>(); + test::<17>(); + test::<19>(); + test::<23>(); + test::<71>(); + test::<101>(); + } + + #[test] + fn large_prime_field() { + const P: u64 = 2_u64.pow(63) - 25; // largest prime fitting into i64 + type F = PrimeField

; + let x = F::from(P as i64 - 1); + let y = x.inverse(); + assert_eq!(x * y, F::ONE); + } + + #[test] + fn inverse() { + fn test() { + for x in -7..7 { + let x = PrimeField::

::from(x); + if x != PrimeField::ZERO { + // multiplicative + assert_eq!(x.inverse() * x, PrimeField::ONE); + assert_eq!(x * x.inverse(), PrimeField::ONE); + assert_eq!((x.inverse().a * x.a).rem_euclid(P as i64), 1); + assert_eq!(x / x, PrimeField::ONE); + } + // additive + assert_eq!(x + (-x), PrimeField::ZERO); + assert_eq!((-x) + x, PrimeField::ZERO); + assert_eq!(x - x, PrimeField::ZERO); + } + } + test::<5>(); + test::<7>(); + test::<11>(); + test::<13>(); + test::<17>(); + test::<19>(); + test::<23>(); + test::<71>(); + test::<101>(); + } + + #[test] + fn test_mod_inverse() { + assert_eq!(mod_inverse(-6, 7), 1); + assert_eq!(mod_inverse(-5, 7), -3); + assert_eq!(mod_inverse(-4, 7), -2); + assert_eq!(mod_inverse(-3, 7), 2); + assert_eq!(mod_inverse(-2, 7), 3); + assert_eq!(mod_inverse(-1, 7), -1); + assert_eq!(mod_inverse(1, 7), 1); + assert_eq!(mod_inverse(2, 7), -3); + assert_eq!(mod_inverse(3, 7), -2); + assert_eq!(mod_inverse(4, 7), 2); + assert_eq!(mod_inverse(5, 7), 3); + assert_eq!(mod_inverse(6, 7), -1); + } + + #[test] + fn integer_mul() { + type F = PrimeField<23>; + for x in 0..23 { + let x = F { a: x }; + for n in -7..7 { + assert_eq!(x.integer_mul(n), F { a: n * x.a }); + } + } + } + + #[test] + fn from_integer() { + type F = PrimeField<23>; + for x in -100..100 { + assert_eq!(F::from_integer(x), F { a: x }); + } + assert_eq!(F::from(0), F::ZERO); + assert_eq!(F::from(1), F::ONE); + } +} diff --git a/src/math/frizzy_number.rs b/src/math/frizzy_number.rs new file mode 100644 index 00000000000..22f154a412c --- /dev/null +++ b/src/math/frizzy_number.rs @@ -0,0 +1,61 @@ +/// This Rust program calculates the n-th Frizzy number for a given base. +/// A Frizzy number is defined as the n-th number that is a sum of powers +/// of the given base, with the powers corresponding to the binary representation +/// of n. + +/// The `get_nth_frizzy` function takes two arguments: +/// * `base` - The base whose n-th sum of powers is required. +/// * `n` - Index from ascending order of the sum of powers of the base. + +/// It returns the n-th sum of powers of the base. + +/// # Example +/// To find the Frizzy number with a base of 3 and n equal to 4: +/// - Ascending order of sums of powers of 3: 3^0 = 1, 3^1 = 3, 3^1 + 3^0 = 4, 3^2 + 3^0 = 9. +/// - The answer is 9. +/// +/// # Arguments +/// * `base` - The base whose n-th sum of powers is required. +/// * `n` - Index from ascending order of the sum of powers of the base. +/// +/// # Returns +/// The n-th sum of powers of the base. + +pub fn get_nth_frizzy(base: i32, mut n: i32) -> f64 { + let mut final1 = 0.0; + let mut i = 0; + while n > 0 { + final1 += (base.pow(i) as f64) * ((n % 2) as f64); + i += 1; + n /= 2; + } + final1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_nth_frizzy() { + // Test case 1: base = 3, n = 4 + // 3^2 + 3^0 = 9 + assert_eq!(get_nth_frizzy(3, 4), 9.0); + + // Test case 2: base = 2, n = 5 + // 2^2 + 2^0 = 5 + assert_eq!(get_nth_frizzy(2, 5), 5.0); + + // Test case 3: base = 4, n = 3 + // 4^1 + 4^0 = 5 + assert_eq!(get_nth_frizzy(4, 3), 5.0); + + // Test case 4: base = 5, n = 2 + // 5^1 + 5^0 = 5 + assert_eq!(get_nth_frizzy(5, 2), 5.0); + + // Test case 5: base = 6, n = 1 + // 6^0 = 1 + assert_eq!(get_nth_frizzy(6, 1), 1.0); + } +} diff --git a/src/math/gaussian_elimination.rs b/src/math/gaussian_elimination.rs index c2f98cb98f5..1370b15ddd7 100644 --- a/src/math/gaussian_elimination.rs +++ b/src/math/gaussian_elimination.rs @@ -37,8 +37,8 @@ fn echelon(matrix: &mut [Vec], i: usize, j: usize) { let size = matrix.len(); if matrix[i][i] == 0f32 { } else { - let factor = matrix[j + 1][i] as f32 / matrix[i][i] as f32; - (i..size + 1).for_each(|k| { + let factor = matrix[j + 1][i] / matrix[i][i]; + (i..=size).for_each(|k| { matrix[j + 1][k] -= factor * matrix[i][k]; }); } @@ -48,10 +48,10 @@ fn eliminate(matrix: &mut [Vec], i: usize) { let size = matrix.len(); if matrix[i][i] == 0f32 { } else { - for j in (1..i + 1).rev() { - let factor = matrix[j - 1][i] as f32 / matrix[i][i] as f32; - for k in (0..size + 1).rev() { - matrix[j - 1][k] -= factor * matrix[i][k] as f32; + for j in (1..=i).rev() { + let factor = matrix[j - 1][i] / matrix[i][i]; + for k in (0..=size).rev() { + matrix[j - 1][k] -= factor * matrix[i][k]; } } } diff --git a/src/math/gaussian_error_linear_unit.rs b/src/math/gaussian_error_linear_unit.rs new file mode 100644 index 00000000000..c7e52cca6a4 --- /dev/null +++ b/src/math/gaussian_error_linear_unit.rs @@ -0,0 +1,58 @@ +//! # Gaussian Error Linear Unit (GELU) Function +//! +//! The `gaussian_error_linear_unit` function computes the Gaussian Error Linear Unit (GELU) values of a given f64 number or a vector of f64 numbers. +//! +//! GELU is an activation function used in neural networks that introduces a smooth approximation of the rectifier function (ReLU). +//! It is defined using the Gaussian cumulative distribution function and can help mitigate the vanishing gradient problem. +//! +//! ## Formula +//! +//! For a given input value `x`, the GELU function computes the output `y` as follows: +//! +//! `y = 0.5 * (1.0 + tanh(2.0 / sqrt(π) * (x + 0.044715 * x^3)))` +//! +//! Where `tanh` is the hyperbolic tangent function and `π` is the mathematical constant (approximately 3.14159). +//! +//! ## Gaussian Error Linear Unit (GELU) Function Implementation +//! +//! This implementation takes either a single f64 value or a reference to a vector of f64 values and returns the GELU transformation applied to each element. The input values are not altered. +//! +use std::f64::consts::E; +use std::f64::consts::PI; + +fn tanh(vector: f64) -> f64 { + (2. / (1. + E.powf(-2. * vector.to_owned()))) - 1. +} + +pub fn gaussian_error_linear_unit(vector: &Vec) -> Vec { + let mut gelu_vec = vector.to_owned(); + for value in &mut gelu_vec { + *value = *value + * 0.5 + * (1. + tanh(f64::powf(2. / PI, 0.5) * (*value + 0.044715 * value.powf(3.)))); + } + + gelu_vec +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gaussian_error_linear_unit() { + let test_vector = vec![-10., 2., -3., 4., -5., 10., 0.05]; + assert_eq!( + gaussian_error_linear_unit(&test_vector), + vec![ + -0.0, + 1.9545976940877752, + -0.0036373920817729943, + 3.9999297540518075, + -2.2917961972623857e-7, + 10.0, + 0.025996938238622008 + ] + ); + } +} diff --git a/src/math/geometric_series.rs b/src/math/geometric_series.rs new file mode 100644 index 00000000000..e9631f09ff5 --- /dev/null +++ b/src/math/geometric_series.rs @@ -0,0 +1,49 @@ +// Author : cyrixninja +// Wikipedia : https://en.wikipedia.org/wiki/Geometric_series +// Calculate a geometric series. + +pub fn geometric_series(nth_term: f64, start_term_a: f64, common_ratio_r: f64) -> Vec { + let mut series = Vec::new(); + let mut multiple = 1.0; + + for _ in 0..(nth_term as i32) { + series.push(start_term_a * multiple); + multiple *= common_ratio_r; + } + + series +} + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_approx_eq(a: f64, b: f64) { + let epsilon = 1e-10; + assert!((a - b).abs() < epsilon, "Expected {a}, found {b}"); + } + + #[test] + fn test_geometric_series() { + let result = geometric_series(4.0, 2.0, 2.0); + assert_eq!(result.len(), 4); + assert_approx_eq(result[0], 2.0); + assert_approx_eq(result[1], 4.0); + assert_approx_eq(result[2], 8.0); + assert_approx_eq(result[3], 16.0); + + let result = geometric_series(4.1, 2.1, 2.1); + assert_eq!(result.len(), 4); + assert_approx_eq(result[0], 2.1); + assert_approx_eq(result[1], 4.41); + assert_approx_eq(result[2], 9.261); + assert_approx_eq(result[3], 19.4481); + + let result = geometric_series(4.0, -2.0, 2.0); + assert_eq!(result.len(), 4); + assert_approx_eq(result[0], -2.0); + assert_approx_eq(result[1], -4.0); + assert_approx_eq(result[2], -8.0); + assert_approx_eq(result[3], -16.0); + } +} diff --git a/src/math/greatest_common_divisor.rs b/src/math/greatest_common_divisor.rs index af1f44d8009..8c88434bade 100644 --- a/src/math/greatest_common_divisor.rs +++ b/src/math/greatest_common_divisor.rs @@ -4,6 +4,7 @@ /// /// Wikipedia reference: https://en.wikipedia.org/wiki/Greatest_common_divisor /// gcd(a, b) = gcd(a, -b) = gcd(-a, b) = gcd(-a, -b) by definition of divisibility +use std::cmp::{max, min}; pub fn greatest_common_divisor_recursive(a: i64, b: i64) -> i64 { if a == 0 { @@ -22,6 +23,28 @@ pub fn greatest_common_divisor_iterative(mut a: i64, mut b: i64) -> i64 { b.abs() } +pub fn greatest_common_divisor_stein(a: u64, b: u64) -> u64 { + match ((a, b), (a & 1, b & 1)) { + // gcd(x, x) = x + ((x, y), _) if x == y => y, + // gcd(x, 0) = gcd(0, x) = x + ((0, x), _) | ((x, 0), _) => x, + // gcd(x, y) = gcd(x / 2, y) if x is even and y is odd + // gcd(x, y) = gcd(x, y / 2) if y is even and x is odd + ((x, y), (0, 1)) | ((y, x), (1, 0)) => greatest_common_divisor_stein(x >> 1, y), + // gcd(x, y) = 2 * gcd(x / 2, y / 2) if x and y are both even + ((x, y), (0, 0)) => greatest_common_divisor_stein(x >> 1, y >> 1) << 1, + // if x and y are both odd + ((x, y), (1, 1)) => { + // then gcd(x, y) = gcd((x - y) / 2, y) if x >= y + // gcd(x, y) = gcd((y - x) / 2, x) otherwise + let (x, y) = (min(x, y), max(x, y)); + greatest_common_divisor_stein((y - x) >> 1, x) + } + _ => unreachable!(), + } +} + #[cfg(test)] mod tests { use super::*; @@ -44,6 +67,15 @@ mod tests { assert_eq!(greatest_common_divisor_iterative(27, 12), 3); } + #[test] + fn positive_number_stein() { + assert_eq!(greatest_common_divisor_stein(4, 16), 4); + assert_eq!(greatest_common_divisor_stein(16, 4), 4); + assert_eq!(greatest_common_divisor_stein(3, 5), 1); + assert_eq!(greatest_common_divisor_stein(40, 40), 40); + assert_eq!(greatest_common_divisor_stein(27, 12), 3); + } + #[test] fn negative_number_recursive() { assert_eq!(greatest_common_divisor_recursive(-32, -8), 8); diff --git a/src/math/huber_loss.rs b/src/math/huber_loss.rs new file mode 100644 index 00000000000..7bb304b81b0 --- /dev/null +++ b/src/math/huber_loss.rs @@ -0,0 +1,43 @@ +//! # Huber Loss Function +//! +//! The `huber_loss` function calculates the Huber loss, which is a robust loss function used in machine learning, particularly in regression problems. +//! +//! Huber loss combines the benefits of mean squared error (MSE) and mean absolute error (MAE). It behaves like MSE when the difference between actual and predicted values is small (less than a specified `delta`), and like MAE when the difference is large. +//! +//! ## Formula +//! +//! For a pair of actual and predicted values, represented as vectors `actual` and `predicted`, and a specified `delta` value, the Huber loss is calculated as: +//! +//! - If the absolute difference between `actual[i]` and `predicted[i]` is less than or equal to `delta`, the loss is `0.5 * (actual[i] - predicted[i])^2`. +//! - If the absolute difference is greater than `delta`, the loss is `delta * |actual[i] - predicted[i]| - 0.5 * delta`. +//! +//! The total loss is the sum of individual losses over all elements. +//! +//! ## Huber Loss Function Implementation +//! +//! This implementation takes two references to vectors of f64 values, `actual` and `predicted`, and a `delta` value. It returns the Huber loss between them, providing a robust measure of dissimilarity between actual and predicted values. +//! +pub fn huber_loss(actual: &[f64], predicted: &[f64], delta: f64) -> f64 { + let mut loss: Vec = Vec::new(); + for (a, p) in actual.iter().zip(predicted.iter()) { + if (a - p).abs() <= delta { + loss.push(0.5 * (a - p).powf(2.)); + } else { + loss.push(delta * (a - p).abs() - (0.5 * delta)); + } + } + + loss.iter().sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_huber_loss() { + let test_vector_actual = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let test_vector = vec![5.0, 7.0, 9.0, 11.0, 13.0]; + assert_eq!(huber_loss(&test_vector_actual, &test_vector, 1.0), 27.5); + } +} diff --git a/src/math/infix_to_postfix.rs b/src/math/infix_to_postfix.rs new file mode 100644 index 00000000000..123c792779d --- /dev/null +++ b/src/math/infix_to_postfix.rs @@ -0,0 +1,94 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum InfixToPostfixError { + UnknownCharacter(char), + UnmatchedParent, +} + +/// Function to convert [infix expression](https://en.wikipedia.org/wiki/Infix_notation) to [postfix expression](https://en.wikipedia.org/wiki/Reverse_Polish_notation) +pub fn infix_to_postfix(infix: &str) -> Result { + let mut postfix = String::new(); + let mut stack: Vec = Vec::new(); + + // Define the precedence of operators + let precedence = |op: char| -> u8 { + match op { + '+' | '-' => 1, + '*' | '/' => 2, + '^' => 3, + _ => 0, + } + }; + + for token in infix.chars() { + match token { + c if c.is_alphanumeric() => { + postfix.push(c); + } + '(' => { + stack.push('('); + } + ')' => { + while let Some(top) = stack.pop() { + if top == '(' { + break; + } + postfix.push(top); + } + } + '+' | '-' | '*' | '/' | '^' => { + while let Some(top) = stack.last() { + if *top == '(' || precedence(*top) < precedence(token) { + break; + } + postfix.push(stack.pop().unwrap()); + } + stack.push(token); + } + other => return Err(InfixToPostfixError::UnknownCharacter(other)), + } + } + + while let Some(top) = stack.pop() { + if top == '(' { + return Err(InfixToPostfixError::UnmatchedParent); + } + + postfix.push(top); + } + + Ok(postfix) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_infix_to_postfix { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (infix, expected) = $inputs; + assert_eq!(infix_to_postfix(infix), expected) + } + )* + } + } + + test_infix_to_postfix! { + single_symbol: ("x", Ok(String::from("x"))), + simple_sum: ("x+y", Ok(String::from("xy+"))), + multiply_sum_left: ("x*(y+z)", Ok(String::from("xyz+*"))), + multiply_sum_right: ("(x+y)*z", Ok(String::from("xy+z*"))), + multiply_two_sums: ("(a+b)*(c+d)", Ok(String::from("ab+cd+*"))), + product_and_power: ("a*b^c", Ok(String::from("abc^*"))), + power_and_product: ("a^b*c", Ok(String::from("ab^c*"))), + product_of_powers: ("(a*b)^c", Ok(String::from("ab*c^"))), + product_in_exponent: ("a^(b*c)", Ok(String::from("abc*^"))), + regular_0: ("a-b+c-d*e", Ok(String::from("ab-c+de*-"))), + regular_1: ("a*(b+c)+d/(e+f)", Ok(String::from("abc+*def+/+"))), + regular_2: ("(a-b+c)*(d+e*f)", Ok(String::from("ab-c+def*+*"))), + unknown_character: ("(a-b)*#", Err(InfixToPostfixError::UnknownCharacter('#'))), + unmatched_paren: ("((a-b)", Err(InfixToPostfixError::UnmatchedParent)), + } +} diff --git a/src/math/interest.rs b/src/math/interest.rs new file mode 100644 index 00000000000..6347f211abe --- /dev/null +++ b/src/math/interest.rs @@ -0,0 +1,55 @@ +// value of e +use std::f64::consts::E; + +// function to calculate simple interest +pub fn simple_interest(principal: f64, annual_rate: f64, years: f64) -> (f64, f64) { + let interest = principal * annual_rate * years; + let value = principal * (1.0 + (annual_rate * years)); + + println!("Interest earned: {interest}"); + println!("Future value: {value}"); + + (interest, value) +} + +// function to calculate compound interest compounded over periods or continuously +pub fn compound_interest(principal: f64, annual_rate: f64, years: f64, period: Option) -> f64 { + // checks if the period is None type, if so calculates continuous compounding interest + let value = if period.is_none() { + principal * E.powf(annual_rate * years) + } else { + // unwraps the option type or defaults to 0 if None type and assigns it to prim_period + let prim_period: f64 = period.unwrap_or(0.0); + // checks if the period is less than or equal to zero + if prim_period <= 0.0_f64 { + return f64::NAN; + } + principal * (1.0 + (annual_rate / prim_period).powf(prim_period * years)) + }; + println!("Future value: {value}"); + value +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_simple() { + let x = 385.65_f64 * 0.03_f64 * 5.0_f64; + let y = 385.65_f64 * (1.0 + (0.03_f64 * 5.0_f64)); + assert_eq!(simple_interest(385.65_f64, 0.03_f64, 5.0_f64), (x, y)); + } + #[test] + fn test_compounding() { + let x = 385.65_f64 * E.powf(0.03_f64 * 5.0_f64); + assert_eq!(compound_interest(385.65_f64, 0.03_f64, 5.0_f64, None), x); + + let y = 385.65_f64 * (1.0 + (0.03_f64 / 5.0_f64).powf(5.0_f64 * 5.0_f64)); + assert_eq!( + compound_interest(385.65_f64, 0.03_f64, 5.0_f64, Some(5.0_f64)), + y + ); + assert!(compound_interest(385.65_f64, 0.03_f64, 5.0_f64, Some(-5.0_f64)).is_nan()); + assert!(compound_interest(385.65_f64, 0.03_f64, 5.0_f64, Some(0.0_f64)).is_nan()); + } +} diff --git a/src/math/interpolation.rs b/src/math/interpolation.rs new file mode 100644 index 00000000000..b5bb39ce5d6 --- /dev/null +++ b/src/math/interpolation.rs @@ -0,0 +1,90 @@ +/// In mathematics, linear interpolation is a method of curve fitting +/// using linear polynomials to construct new data points within the range of a discrete set of known data points. +/// Formula: y = y0 + (x - x0) * (y1 - y0) / (x1 - x0) +/// Source: https://en.wikipedia.org/wiki/Linear_interpolation +/// point0 and point1 are a tuple containing x and y values we want to interpolate between +pub fn linear_interpolation(x: f64, point0: (f64, f64), point1: (f64, f64)) -> f64 { + point0.1 + (x - point0.0) * (point1.1 - point0.1) / (point1.0 - point0.0) +} + +/// In numerical analysis, the Lagrange interpolating polynomial +/// is the unique polynomial of lowest degree that interpolates a given set of data. +/// +/// Source: https://en.wikipedia.org/wiki/Lagrange_polynomial +/// Source: https://mathworld.wolfram.com/LagrangeInterpolatingPolynomial.html +/// x is the point we wish to interpolate +/// defined points are a vector of tuples containing known x and y values of our function +pub fn lagrange_polynomial_interpolation(x: f64, defined_points: &Vec<(f64, f64)>) -> f64 { + let mut defined_x_values: Vec = Vec::new(); + let mut defined_y_values: Vec = Vec::new(); + + for (x, y) in defined_points { + defined_x_values.push(*x); + defined_y_values.push(*y); + } + + let mut sum = 0.0; + + for y_index in 0..defined_y_values.len() { + let mut numerator = 1.0; + let mut denominator = 1.0; + for x_index in 0..defined_x_values.len() { + if y_index == x_index { + continue; + } + denominator *= defined_x_values[y_index] - defined_x_values[x_index]; + numerator *= x - defined_x_values[x_index]; + } + + sum += numerator / denominator * defined_y_values[y_index]; + } + sum +} + +#[cfg(test)] +mod tests { + + use std::assert_eq; + + use super::*; + #[test] + fn test_linear_intepolation() { + let point1 = (0.0, 0.0); + let point2 = (1.0, 1.0); + let point3 = (2.0, 2.0); + + let x1 = 0.5; + let x2 = 1.5; + + let y1 = linear_interpolation(x1, point1, point2); + let y2 = linear_interpolation(x2, point2, point3); + + assert_eq!(y1, x1); + assert_eq!(y2, x2); + assert_eq!( + linear_interpolation(x1, point1, point2), + linear_interpolation(x1, point2, point1) + ); + } + + #[test] + fn test_lagrange_polynomial_interpolation() { + // defined values for x^2 function + let defined_points = vec![(0.0, 0.0), (1.0, 1.0), (2.0, 4.0), (3.0, 9.0)]; + + // check for equality + assert_eq!(lagrange_polynomial_interpolation(1.0, &defined_points), 1.0); + assert_eq!(lagrange_polynomial_interpolation(2.0, &defined_points), 4.0); + assert_eq!(lagrange_polynomial_interpolation(3.0, &defined_points), 9.0); + + //other + assert_eq!( + lagrange_polynomial_interpolation(0.5, &defined_points), + 0.25 + ); + assert_eq!( + lagrange_polynomial_interpolation(2.5, &defined_points), + 6.25 + ); + } +} diff --git a/src/math/interquartile_range.rs b/src/math/interquartile_range.rs new file mode 100644 index 00000000000..fed92b77709 --- /dev/null +++ b/src/math/interquartile_range.rs @@ -0,0 +1,85 @@ +// Author : cyrixninja +// Interquartile Range : An implementation of interquartile range (IQR) which is a measure of statistical +// dispersion, which is the spread of the data. +// Wikipedia Reference : https://en.wikipedia.org/wiki/Interquartile_range + +use std::cmp::Ordering; + +pub fn find_median(numbers: &[f64]) -> f64 { + let mut numbers = numbers.to_vec(); + numbers.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + + let length = numbers.len(); + let mid = length / 2; + + if length % 2 == 0 { + f64::midpoint(numbers[mid - 1], numbers[mid]) + } else { + numbers[mid] + } +} + +pub fn interquartile_range(numbers: &[f64]) -> f64 { + if numbers.is_empty() { + panic!("Error: The list is empty. Please provide a non-empty list."); + } + + let mut numbers = numbers.to_vec(); + numbers.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + + let length = numbers.len(); + let mid = length / 2; + let (q1, q3) = if length % 2 == 0 { + let first_half = &numbers[0..mid]; + let second_half = &numbers[mid..length]; + (find_median(first_half), find_median(second_half)) + } else { + let first_half = &numbers[0..mid]; + let second_half = &numbers[mid + 1..length]; + (find_median(first_half), find_median(second_half)) + }; + + q3 - q1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_find_median() { + let numbers1 = vec![1.0, 2.0, 2.0, 3.0, 4.0]; + assert_eq!(find_median(&numbers1), 2.0); + + let numbers2 = vec![1.0, 2.0, 2.0, 3.0, 4.0, 4.0]; + assert_eq!(find_median(&numbers2), 2.5); + + let numbers3 = vec![-1.0, 2.0, 0.0, 3.0, 4.0, -4.0]; + assert_eq!(find_median(&numbers3), 1.0); + + let numbers4 = vec![1.1, 2.2, 2.0, 3.3, 4.4, 4.0]; + assert_eq!(find_median(&numbers4), 2.75); + } + + #[test] + fn test_interquartile_range() { + let numbers1 = vec![4.0, 1.0, 2.0, 3.0, 2.0]; + assert_eq!(interquartile_range(&numbers1), 2.0); + + let numbers2 = vec![-2.0, -7.0, -10.0, 9.0, 8.0, 4.0, -67.0, 45.0]; + assert_eq!(interquartile_range(&numbers2), 17.0); + + let numbers3 = vec![-2.1, -7.1, -10.1, 9.1, 8.1, 4.1, -67.1, 45.1]; + assert_eq!(interquartile_range(&numbers3), 17.2); + + let numbers4 = vec![0.0, 0.0, 0.0, 0.0, 0.0]; + assert_eq!(interquartile_range(&numbers4), 0.0); + } + + #[test] + #[should_panic(expected = "Error: The list is empty. Please provide a non-empty list.")] + fn test_interquartile_range_empty_list() { + let numbers: Vec = vec![]; + interquartile_range(&numbers); + } +} diff --git a/src/math/karatsuba_multiplication.rs b/src/math/karatsuba_multiplication.rs index e8d1d3bbfe6..4547faf9119 100644 --- a/src/math/karatsuba_multiplication.rs +++ b/src/math/karatsuba_multiplication.rs @@ -35,9 +35,8 @@ fn _multiply(num1: i128, num2: i128) -> i128 { } fn normalize(mut a: String, n: usize) -> String { - for (counter, _) in (a.len()..n).enumerate() { - a.insert(counter, '0'); - } + let padding = n.saturating_sub(a.len()); + a.insert_str(0, &"0".repeat(padding)); a } #[cfg(test)] diff --git a/src/math/leaky_relu.rs b/src/math/leaky_relu.rs new file mode 100644 index 00000000000..954a1023d6d --- /dev/null +++ b/src/math/leaky_relu.rs @@ -0,0 +1,47 @@ +//! # Leaky ReLU Function +//! +//! The `leaky_relu` function computes the Leaky Rectified Linear Unit (ReLU) values of a given vector +//! of f64 numbers with a specified alpha parameter. +//! +//! The Leaky ReLU activation function is commonly used in neural networks to introduce a small negative +//! slope (controlled by the alpha parameter) for the negative input values, preventing neurons from dying +//! during training. +//! +//! ## Formula +//! +//! For a given input vector `x` and an alpha parameter `alpha`, the Leaky ReLU function computes the output +//! `y` as follows: +//! +//! `y_i = { x_i if x_i >= 0, alpha * x_i if x_i < 0 }` +//! +//! ## Leaky ReLU Function Implementation +//! +//! This implementation takes a reference to a vector of f64 values and an alpha parameter, and returns a new +//! vector with the Leaky ReLU transformation applied to each element. The input vector is not altered. +//! +pub fn leaky_relu(vector: &Vec, alpha: f64) -> Vec { + let mut _vector = vector.to_owned(); + + for value in &mut _vector { + if value < &mut 0. { + *value *= alpha; + } + } + + _vector +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_leaky_relu() { + let test_vector = vec![-10., 2., -3., 4., -5., 10., 0.05]; + let alpha = 0.01; + assert_eq!( + leaky_relu(&test_vector, alpha), + vec![-0.1, 2.0, -0.03, 4.0, -0.05, 10.0, 0.05] + ); + } +} diff --git a/src/math/least_square_approx.rs b/src/math/least_square_approx.rs new file mode 100644 index 00000000000..bc12d8e766e --- /dev/null +++ b/src/math/least_square_approx.rs @@ -0,0 +1,115 @@ +/// Least Square Approximation

+/// Function that returns a polynomial which very closely passes through the given points (in 2D) +/// +/// The result is made of coeficients, in descending order (from x^degree to free term) +/// +/// Parameters: +/// +/// points -> coordinates of given points +/// +/// degree -> degree of the polynomial +/// +pub fn least_square_approx + Copy, U: Into + Copy>( + points: &[(T, U)], + degree: i32, +) -> Option> { + use nalgebra::{DMatrix, DVector}; + + /* Used for rounding floating numbers */ + fn round_to_decimals(value: f64, decimals: i32) -> f64 { + let multiplier = 10f64.powi(decimals); + (value * multiplier).round() / multiplier + } + + /* Casting the data parsed to this function to f64 (as some points can have decimals) */ + let vals: Vec<(f64, f64)> = points + .iter() + .map(|(x, y)| ((*x).into(), (*y).into())) + .collect(); + /* Because of collect we need the Copy Trait for T and U */ + + /* Computes the sums in the system of equations */ + let mut sums = Vec::::new(); + for i in 1..=(2 * degree + 1) { + sums.push(vals.iter().map(|(x, _)| x.powi(i - 1)).sum()); + } + + /* Compute the free terms column vector */ + let mut free_col = Vec::::new(); + for i in 1..=(degree + 1) { + free_col.push(vals.iter().map(|(x, y)| y * (x.powi(i - 1))).sum()); + } + let b = DVector::from_row_slice(&free_col); + + /* Create and fill the system's matrix */ + let size = (degree + 1) as usize; + let a = DMatrix::from_fn(size, size, |i, j| sums[degree as usize + i - j]); + + /* Solve the system of equations: A * x = b */ + match a.qr().solve(&b) { + Some(x) => { + let rez: Vec = x.iter().map(|x| round_to_decimals(*x, 5)).collect(); + Some(rez) + } + None => None, //<-- The system cannot be solved (badly conditioned system's matrix) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ten_points_1st_degree() { + let points = vec![ + (5.3, 7.8), + (4.9, 8.1), + (6.1, 6.9), + (4.7, 8.3), + (6.5, 7.7), + (5.6, 7.0), + (5.8, 8.2), + (4.5, 8.0), + (6.3, 7.2), + (5.1, 8.4), + ]; + + assert_eq!( + least_square_approx(&points, 1), + Some(vec![-0.49069, 10.44898]) + ); + } + + #[test] + fn eight_points_5th_degree() { + let points = vec![ + (4f64, 8f64), + (8f64, 2f64), + (1f64, 7f64), + (10f64, 3f64), + (11.0, 0.0), + (7.0, 3.0), + (10.0, 1.0), + (13.0, 13.0), + ]; + + assert_eq!( + least_square_approx(&points, 5), + Some(vec![ + 0.00603, -0.21304, 2.79929, -16.53468, 40.29473, -19.35771 + ]) + ); + } + + #[test] + fn four_points_2nd_degree() { + let points = vec![ + (2.312, 8.345344), + (-2.312, 8.345344), + (-0.7051, 3.49716601), + (0.7051, 3.49716601), + ]; + + assert_eq!(least_square_approx(&points, 2), Some(vec![1.0, 0.0, 3.0])); + } +} diff --git a/src/math/linear_sieve.rs b/src/math/linear_sieve.rs index 5341df0a120..8fb49d26a75 100644 --- a/src/math/linear_sieve.rs +++ b/src/math/linear_sieve.rs @@ -76,6 +76,12 @@ impl LinearSieve { } } +impl Default for LinearSieve { + fn default() -> Self { + Self::new() + } +} + #[cfg(test)] mod tests { use super::LinearSieve; @@ -109,7 +115,7 @@ mod tests { let factorization = ls.factorize(i).unwrap(); let mut product = 1usize; for (idx, p) in factorization.iter().enumerate() { - assert!(ls.primes.binary_search(&p).is_ok()); + assert!(ls.primes.binary_search(p).is_ok()); product *= *p; if idx > 0 { assert!(*p >= factorization[idx - 1]); diff --git a/src/math/logarithm.rs b/src/math/logarithm.rs new file mode 100644 index 00000000000..c94e8247d11 --- /dev/null +++ b/src/math/logarithm.rs @@ -0,0 +1,77 @@ +use std::f64::consts::E; + +/// Calculates the **logbase(x)** +/// +/// Parameters: +///

-> base: base of log +///

-> x: value for which log shall be evaluated +///

-> tol: tolerance; the precision of the approximation (submultiples of 10-1) +/// +/// Advisable to use **std::f64::consts::*** for specific bases (like 'e') +pub fn log, U: Into>(base: U, x: T, tol: f64) -> f64 { + let mut rez = 0f64; + let mut argument: f64 = x.into(); + let usable_base: f64 = base.into(); + + if argument <= 0f64 || usable_base <= 0f64 { + println!("Log does not support negative argument or negative base."); + f64::NAN + } else if argument < 1f64 && usable_base == E { + argument -= 1f64; + let mut prev_rez = 1f64; + let mut step: i32 = 1; + /* + For x in (0, 1) and base 'e', the function is using MacLaurin Series: + ln(|1 + x|) = Σ "(-1)^n-1 * x^n / n", for n = 1..inf + Substituting x with x-1 yields: + ln(|x|) = Σ "(-1)^n-1 * (x-1)^n / n" + */ + while (prev_rez - rez).abs() > tol { + prev_rez = rez; + rez += (-1f64).powi(step - 1) * argument.powi(step) / step as f64; + step += 1; + } + + rez + } else { + /* Using the basic change of base formula for log */ + let ln_x = argument.ln(); + let ln_base = usable_base.ln(); + + ln_x / ln_base + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn basic() { + assert_eq!(log(E, E, 0.0), 1.0); + assert_eq!(log(E, E.powi(100), 0.0), 100.0); + assert_eq!(log(10, 10000.0, 0.0), 4.0); + assert_eq!(log(234501.0, 1.0, 1.0), 0.0); + } + + #[test] + fn test_log_positive_base() { + assert_eq!(log(10.0, 100.0, 0.00001), 2.0); + assert_eq!(log(2.0, 8.0, 0.00001), 3.0); + } + + #[test] + fn test_log_zero_base() { + assert!(log(0.0, 100.0, 0.00001).is_nan()); + } + + #[test] + fn test_log_negative_base() { + assert!(log(-1.0, 100.0, 0.00001).is_nan()); + } + + #[test] + fn test_log_tolerance() { + assert_eq!(log(10.0, 100.0, 1e-10), 2.0); + } +} diff --git a/src/math/lucas_series.rs b/src/math/lucas_series.rs new file mode 100644 index 00000000000..cbc6f48fc40 --- /dev/null +++ b/src/math/lucas_series.rs @@ -0,0 +1,60 @@ +// Author : cyrixninja +// Lucas Series : Function to get the Nth Lucas Number +// Wikipedia Reference : https://en.wikipedia.org/wiki/Lucas_number +// Other References : https://the-algorithms.com/algorithm/lucas-series?lang=python + +pub fn recursive_lucas_number(n: u32) -> u32 { + match n { + 0 => 2, + 1 => 1, + _ => recursive_lucas_number(n - 1) + recursive_lucas_number(n - 2), + } +} + +pub fn dynamic_lucas_number(n: u32) -> u32 { + let mut a = 2; + let mut b = 1; + + for _ in 0..n { + let temp = a; + a = b; + b += temp; + } + + a +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_lucas_number { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected) = $inputs; + assert_eq!(recursive_lucas_number(n), expected); + assert_eq!(dynamic_lucas_number(n), expected); + } + )* + } + } + + test_lucas_number! { + input_0: (0, 2), + input_1: (1, 1), + input_2: (2, 3), + input_3: (3, 4), + input_4: (4, 7), + input_5: (5, 11), + input_6: (6, 18), + input_7: (7, 29), + input_8: (8, 47), + input_9: (9, 76), + input_10: (10, 123), + input_15: (15, 1364), + input_20: (20, 15127), + input_25: (25, 167761), + } +} diff --git a/src/math/matrix_ops.rs b/src/math/matrix_ops.rs index 21113c3f2cf..29fa722f46f 100644 --- a/src/math/matrix_ops.rs +++ b/src/math/matrix_ops.rs @@ -1,186 +1,569 @@ -// Basic matrix operations using row vectors wrapped in column vectors as matrices. -// Supports i32, should be interchangeable for other types. +// Basic matrix operations using a Matrix type with internally uses +// a vector representation to store matrix elements. +// Generic using the MatrixElement trait, which can be implemented with +// the matrix_element_type_def macro. // Wikipedia reference: https://www.wikiwand.com/en/Matrix_(mathematics) +use std::ops::{Add, AddAssign, Index, IndexMut, Mul, Sub}; -pub fn matrix_add(summand0: &[Vec], summand1: &[Vec]) -> Vec> { - // Add two matrices of identical dimensions - let mut result: Vec> = vec![]; - if summand0.len() != summand1.len() { - panic!("Matrix dimensions do not match"); +// Define macro to build a matrix idiomatically +#[macro_export] +macro_rules! matrix { + [$([$($x:expr),* $(,)*]),* $(,)*] => {{ + Matrix::from(vec![$(vec![$($x,)*],)*]) + }}; +} + +// Define a trait "alias" for suitable matrix elements +pub trait MatrixElement: + Add + Sub + Mul + AddAssign + Copy + From +{ +} + +// Define a macro to implement the MatrixElement trait for desired types +#[macro_export] +macro_rules! matrix_element_type_def { + ($T: ty) => { + // Implement trait for type + impl MatrixElement for $T {} + + // Defining left-hand multiplication in this form + // prevents errors for uncovered types + impl Mul<&Matrix<$T>> for $T { + type Output = Matrix<$T>; + + fn mul(self, rhs: &Matrix<$T>) -> Self::Output { + rhs * self + } + } + }; + + ($T: ty, $($Ti: ty),+) => { + // Decompose type definitions recursively + matrix_element_type_def!($T); + matrix_element_type_def!($($Ti),+); + }; +} + +matrix_element_type_def!(i16, i32, i64, i128, u8, u16, u32, u128, f32, f64); + +#[derive(PartialEq, Eq, Debug)] +pub struct Matrix { + data: Vec, + rows: usize, + cols: usize, +} + +impl Matrix { + pub fn new(data: Vec, rows: usize, cols: usize) -> Self { + // Build a matrix from the internal vector representation + if data.len() != rows * cols { + panic!("Inconsistent data and dimensions combination for matrix") + } + Matrix { data, rows, cols } } - for row in 0..summand0.len() { - if summand0[row].len() != summand1[row].len() { - panic!("Matrix dimensions do not match"); + + pub fn zero(rows: usize, cols: usize) -> Self { + // Build a matrix of zeros + Matrix { + data: vec![0.into(); rows * cols], + rows, + cols, + } + } + + pub fn identity(len: usize) -> Self { + // Build an identity matrix + let mut identity = Matrix::zero(len, len); + // Diagonal of ones + for i in 0..len { + identity[[i, i]] = 1.into(); } - result.push(vec![]); - for column in 0..summand1[0].len() { - result[row].push(summand0[row][column] + summand1[row][column]); + identity + } + + pub fn transpose(&self) -> Self { + // Transpose a matrix of any size + let mut result = Matrix::zero(self.cols, self.rows); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = self[[j, i]]; + } } + result } - result } -pub fn matrix_subtract(minuend: &[Vec], subtrahend: &[Vec]) -> Vec> { - // Subtract one matrix from another. They need to have identical dimensions. - let mut result: Vec> = vec![]; - if minuend.len() != subtrahend.len() { - panic!("Matrix dimensions do not match"); +impl Index<[usize; 2]> for Matrix { + type Output = T; + + fn index(&self, index: [usize; 2]) -> &Self::Output { + let [i, j] = index; + if i >= self.rows || j >= self.cols { + panic!("Matrix index out of bounds"); + } + + &self.data[(self.cols * i) + j] + } +} + +impl IndexMut<[usize; 2]> for Matrix { + fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output { + let [i, j] = index; + if i >= self.rows || j >= self.cols { + panic!("Matrix index out of bounds"); + } + + &mut self.data[(self.cols * i) + j] } - for row in 0..minuend.len() { - if minuend[row].len() != subtrahend[row].len() { +} + +impl Add<&Matrix> for &Matrix { + type Output = Matrix; + + fn add(self, rhs: &Matrix) -> Self::Output { + // Add two matrices. They need to have identical dimensions. + if self.rows != rhs.rows || self.cols != rhs.cols { panic!("Matrix dimensions do not match"); } - result.push(vec![]); - for column in 0..subtrahend[0].len() { - result[row].push(minuend[row][column] - subtrahend[row][column]); + + let mut result = Matrix::zero(self.rows, self.cols); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = self[[i, j]] + rhs[[i, j]]; + } } + result } - result } -// Disable cargo clippy warnings about needless range loops. -// As the iterating variable is used as index while multiplying, -// using the item itself would defeat the variables purpose. -#[allow(clippy::needless_range_loop)] -pub fn matrix_multiply(multiplier: &[Vec], multiplicand: &[Vec]) -> Vec> { - // Multiply two matching matrices. The multiplier needs to have the same amount - // of columns as the multiplicand has rows. - let mut result: Vec> = vec![]; - let mut temp; - // Using variable to compare lenghts of rows in multiplicand later - let row_right_length = multiplicand[0].len(); - for row_left in 0..multiplier.len() { - if multiplier[row_left].len() != multiplicand.len() { +impl Sub for &Matrix { + type Output = Matrix; + + fn sub(self, rhs: Self) -> Self::Output { + // Subtract one matrix from another. They need to have identical dimensions. + if self.rows != rhs.rows || self.cols != rhs.cols { panic!("Matrix dimensions do not match"); } - result.push(vec![]); - for column_right in 0..multiplicand[0].len() { - temp = 0; - for row_right in 0..multiplicand.len() { - if row_right_length != multiplicand[row_right].len() { - // If row is longer than a previous row cancel operation with error - panic!("Matrix dimensions do not match"); - } - temp += multiplier[row_left][row_right] * multiplicand[row_right][column_right]; + + let mut result = Matrix::zero(self.rows, self.cols); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = self[[i, j]] - rhs[[i, j]]; + } + } + result + } +} + +impl Mul for &Matrix { + type Output = Matrix; + + fn mul(self, rhs: Self) -> Self::Output { + // Multiply two matrices. The multiplier needs to have the same amount + // of columns as the multiplicand has rows. + if self.cols != rhs.rows { + panic!("Matrix dimensions do not match"); + } + + let mut result = Matrix::zero(self.rows, rhs.cols); + for i in 0..self.rows { + for j in 0..rhs.cols { + result[[i, j]] = { + let mut sum = 0.into(); + for k in 0..self.cols { + sum += self[[i, k]] * rhs[[k, j]]; + } + sum + }; } - result[row_left].push(temp); } + result } - result } -pub fn matrix_transpose(matrix: &[Vec]) -> Vec> { - // Transpose a matrix of any size - let mut result: Vec> = vec![Vec::with_capacity(matrix.len()); matrix[0].len()]; - for row in matrix { - for col in 0..row.len() { - result[col].push(row[col]); +impl Mul for &Matrix { + type Output = Matrix; + + fn mul(self, rhs: T) -> Self::Output { + // Multiply a matrix of any size with a scalar + let mut result = Matrix::zero(self.rows, self.cols); + for i in 0..self.rows { + for j in 0..self.cols { + result[[i, j]] = rhs * self[[i, j]]; + } } + result } - result } -pub fn matrix_scalar_multiplication(matrix: &[Vec], scalar: i32) -> Vec> { - // Multiply a matrix of any size with a scalar - let mut result: Vec> = vec![Vec::with_capacity(matrix.len()); matrix[0].len()]; - for row in 0..matrix.len() { - for column in 0..matrix[row].len() { - result[row].push(scalar * matrix[row][column]); +impl From>> for Matrix { + fn from(v: Vec>) -> Self { + let rows = v.len(); + let cols = v.first().map_or(0, |row| row.len()); + + // Ensure consistent dimensions + for row in v.iter().skip(1) { + if row.len() != cols { + panic!("Invalid matrix dimensions. Columns must be consistent."); + } + } + if rows != 0 && cols == 0 { + panic!("Invalid matrix dimensions. Multiple empty rows"); } + + let data = v.into_iter().flat_map(|row| row.into_iter()).collect(); + Self::new(data, rows, cols) } - result } #[cfg(test)] +// rustfmt skipped to prevent unformatting matrix definitions to a single line +#[rustfmt::skip] mod tests { - use super::matrix_add; - use super::matrix_multiply; - use super::matrix_scalar_multiplication; - use super::matrix_subtract; - use super::matrix_transpose; - - #[test] - fn test_add() { - let input0: Vec> = vec![vec![1, 0, 1], vec![0, 2, 0], vec![5, 0, 1]]; - let input1: Vec> = vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]]; - let input_wrong0: Vec> = vec![vec![1, 0, 0, 4], vec![0, 1, 0], vec![0, 0, 1]]; - let input_wrong1: Vec> = - vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1], vec![1, 1, 1]]; - let input_wrong2: Vec> = vec![vec![]]; - let exp_result: Vec> = vec![vec![2, 0, 1], vec![0, 3, 0], vec![5, 0, 2]]; - assert_eq!(matrix_add(&input0, &input1), exp_result); - let result0 = std::panic::catch_unwind(|| matrix_add(&input0, &input_wrong0)); - assert!(result0.is_err()); - let result1 = std::panic::catch_unwind(|| matrix_add(&input0, &input_wrong1)); - assert!(result1.is_err()); - let result2 = std::panic::catch_unwind(|| matrix_add(&input0, &input_wrong2)); - assert!(result2.is_err()); - } - - #[test] - fn test_subtract() { - let input0: Vec> = vec![vec![1, 0, 1], vec![0, 2, 0], vec![5, 0, 1]]; - let input1: Vec> = vec![vec![1, 0, 0], vec![0, 1, 3], vec![0, 0, 1]]; - let input_wrong0: Vec> = vec![vec![1, 0, 0, 4], vec![0, 1, 0], vec![0, 0, 1]]; - let input_wrong1: Vec> = - vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1], vec![1, 1, 1]]; - let input_wrong2: Vec> = vec![vec![]]; - let exp_result: Vec> = vec![vec![0, 0, 1], vec![0, 1, -3], vec![5, 0, 0]]; - assert_eq!(matrix_subtract(&input0, &input1), exp_result); - let result0 = std::panic::catch_unwind(|| matrix_subtract(&input0, &input_wrong0)); - assert!(result0.is_err()); - let result1 = std::panic::catch_unwind(|| matrix_subtract(&input0, &input_wrong1)); - assert!(result1.is_err()); - let result2 = std::panic::catch_unwind(|| matrix_subtract(&input0, &input_wrong2)); - assert!(result2.is_err()); - } - - #[test] - fn test_multiply() { - let input0: Vec> = - vec![vec![1, 2, 3], vec![4, 2, 6], vec![3, 4, 1], vec![2, 4, 8]]; - let input1: Vec> = vec![vec![1, 3, 3, 2], vec![7, 6, 2, 1], vec![3, 4, 2, 1]]; - let input_wrong0: Vec> = vec![ - vec![1, 3, 3, 2, 4, 6, 6], - vec![7, 6, 2, 1], - vec![3, 4, 2, 1], - ]; - let input_wrong1: Vec> = vec![ - vec![1, 3, 3, 2], - vec![7, 6, 2, 1], - vec![3, 4, 2, 1], - vec![3, 4, 2, 1], - ]; - let exp_result: Vec> = vec![ - vec![24, 27, 13, 7], - vec![36, 48, 28, 16], - vec![34, 37, 19, 11], - vec![54, 62, 30, 16], - ]; - assert_eq!(matrix_multiply(&input0, &input1), exp_result); - let result0 = std::panic::catch_unwind(|| matrix_multiply(&input0, &input_wrong0)); - assert!(result0.is_err()); - let result1 = std::panic::catch_unwind(|| matrix_multiply(&input0, &input_wrong1)); - assert!(result1.is_err()); - } - - #[test] - fn test_transpose() { - let input0: Vec> = vec![vec![1, 0, 1], vec![0, 2, 0], vec![5, 0, 1]]; - let input1: Vec> = vec![vec![3, 4, 2], vec![0, 1, 3], vec![3, 1, 1]]; - let exp_result1: Vec> = vec![vec![1, 0, 5], vec![0, 2, 0], vec![1, 0, 1]]; - let exp_result2: Vec> = vec![vec![3, 0, 3], vec![4, 1, 1], vec![2, 3, 1]]; - assert_eq!(matrix_transpose(&input0), exp_result1); - assert_eq!(matrix_transpose(&input1), exp_result2); - } - - #[test] - fn test_matrix_scalar_multiplication() { - let input0: Vec> = vec![vec![3, 2, 2], vec![0, 2, 0], vec![5, 4, 1]]; - let input1: Vec> = vec![vec![1, 0, 0], vec![0, 1, 0], vec![0, 0, 1]]; - let exp_result1: Vec> = vec![vec![9, 6, 6], vec![0, 6, 0], vec![15, 12, 3]]; - let exp_result2: Vec> = vec![vec![3, 0, 0], vec![0, 3, 0], vec![0, 0, 3]]; - assert_eq!(matrix_scalar_multiplication(&input0, 3), exp_result1); - assert_eq!(matrix_scalar_multiplication(&input1, 3), exp_result2); + use super::Matrix; + use std::panic; + + const DELTA: f64 = 1e-3; + + macro_rules! assert_f64_eq { + ($a:expr, $b:expr) => { + assert_eq!($a.data.len(), $b.data.len()); + if !$a + .data + .iter() + .zip($b.data.iter()) + .all(|(x, y)| (*x as f64 - *y as f64).abs() < DELTA) + { + panic!(); + } + }; + } + + #[test] + fn test_invalid_matrix() { + let result = panic::catch_unwind(|| matrix![ + [1, 0, 0, 4], + [0, 1, 0], + [0, 0, 1], + ]); + assert!(result.is_err()); + } + + #[test] + fn test_empty_matrix() { + let a: Matrix = matrix![]; + + let result = panic::catch_unwind(|| a[[0, 0]]); + assert!(result.is_err()); + } + + #[test] + fn test_zero_matrix() { + let a: Matrix = Matrix::zero(3, 5); + + let z = matrix![ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ]; + + assert_f64_eq!(a, z); + } + + #[test] + fn test_identity_matrix() { + let a: Matrix = Matrix::identity(5); + + let id = matrix![ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]; + + assert_f64_eq!(a, id); + } + + #[test] + fn test_invalid_add() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1] + ]; + + let err = matrix![ + [1, 2], + [2, 4], + ]; + + let result = panic::catch_unwind(|| &a + &err); + assert!(result.is_err()); + } + + #[test] + fn test_add_i32() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1] + ]; + + let b = matrix![ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ]; + + let add = matrix![ + [2, 0, 1], + [0, 3, 0], + [5, 0, 2], + ]; + + assert_eq!(&a + &b, add); + } + + #[test] + fn test_add_f64() { + let a = matrix![ + [1.0, 2.0, 1.0], + [3.0, 2.0, 0.0], + [5.0, 0.0, 1.0], + ]; + + let b = matrix![ + [1.0, 10.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0], + ]; + + let add = matrix![ + [2.0, 12.0, 1.0], + [3.0, 3.0, 0.0], + [5.0, 0.0, 2.0], + ]; + + assert_f64_eq!(&a + &b, add); + } + + #[test] + fn test_invalid_sub() { + let a = matrix![ + [2, 3], + [10, 2], + ]; + + let err = matrix![ + [5, 6, 10], + [7, 2, 2], + [12, 0, 1], + ]; + + let result = panic::catch_unwind(|| &a - &err); + assert!(result.is_err()); + } + + #[test] + fn test_subtract_i32() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1], + ]; + + let b = matrix![ + [1, 0, 0], + [0, 1, 3], + [0, 0, 1], + ]; + + let sub = matrix![ + [0, 0, 1], + [0, 1, -3], + [5, 0, 0], + ]; + + assert_eq!(&a - &b, sub); + } + + #[test] + fn test_subtract_f64() { + let a = matrix![ + [7.0, 2.0, 1.0], + [0.0, 3.0, 2.0], + [5.3, 8.8, std::f64::consts::PI], + ]; + + let b = matrix![ + [1.0, 0.0, 5.0], + [-2.0, 1.0, 3.0], + [0.0, 2.2, std::f64::consts::PI], + ]; + + let sub = matrix![ + [6.0, 2.0, -4.0], + [2.0, 2.0, -1.0], + [5.3, 6.6, 0.0], + ]; + + assert_f64_eq!(&a - &b, sub); + } + + #[test] + fn test_invalid_mul() { + let a = matrix![ + [1, 2, 3], + [4, 2, 6], + [3, 4, 1], + [2, 4, 8], + ]; + + let err = matrix![ + [1, 3, 3, 2], + [7, 6, 2, 1], + [3, 4, 2, 1], + [3, 4, 2, 1], + ]; + + let result = panic::catch_unwind(|| &a * &err); + assert!(result.is_err()); + } + + #[test] + fn test_mul_i32() { + let a = matrix![ + [1, 2, 3], + [4, 2, 6], + [3, 4, 1], + [2, 4, 8], + ]; + + let b = matrix![ + [1, 3, 3, 2], + [7, 6, 2, 1], + [3, 4, 2, 1], + ]; + + let mul = matrix![ + [24, 27, 13, 7], + [36, 48, 28, 16], + [34, 37, 19, 11], + [54, 62, 30, 16], + ]; + + assert_eq!(&a * &b, mul); + } + + #[test] + fn test_mul_f64() { + let a = matrix![ + [5.5, 2.9, 1.13, 9.0], + [0.0, 3.0, 11.0, 17.2], + [5.3, 8.8, 2.76, 3.3], + ]; + + let b = matrix![ + [1.0, 0.3, 5.0], + [-2.0, 1.0, 3.0], + [-3.6, 1.5, 3.0], + [0.0, 2.2, 2.0], + ]; + + let mul = matrix![ + [-4.368, 26.045, 57.59], + [-45.6, 57.34, 76.4], + [-22.236, 21.79, 67.78], + ]; + + assert_f64_eq!(&a * &b, mul); + } + + #[test] + fn test_transpose_i32() { + let a = matrix![ + [1, 0, 1], + [0, 2, 0], + [5, 0, 1], + ]; + + let t = matrix![ + [1, 0, 5], + [0, 2, 0], + [1, 0, 1], + ]; + + assert_eq!(a.transpose(), t); + } + + #[test] + fn test_transpose_f64() { + let a = matrix![ + [3.0, 4.0, 2.0], + [0.0, 1.0, 3.0], + [3.0, 1.0, 1.0], + ]; + + let t = matrix![ + [3.0, 0.0, 3.0], + [4.0, 1.0, 1.0], + [2.0, 3.0, 1.0], + ]; + + assert_eq!(a.transpose(), t); + } + + #[test] + fn test_matrix_scalar_zero_mul() { + let a = matrix![ + [3, 2, 2], + [0, 2, 0], + [5, 4, 1], + ]; + + let scalar = 0; + + let scalar_mul = Matrix::zero(3, 3); + + assert_eq!(scalar * &a, scalar_mul); + } + + #[test] + fn test_matrix_scalar_mul_i32() { + let a = matrix![ + [3, 2, 2], + [0, 2, 0], + [5, 4, 1], + ]; + + let scalar = 3; + + let scalar_mul = matrix![ + [9, 6, 6], + [0, 6, 0], + [15, 12, 3], + ]; + + assert_eq!(scalar * &a, scalar_mul); + } + + #[test] + fn test_matrix_scalar_mul_f64() { + let a = matrix![ + [3.2, 5.5, 9.2], + [1.1, 0.0, 2.3], + [0.3, 4.2, 0.0], + ]; + + let scalar = 1.5_f64; + + let scalar_mul = matrix![ + [4.8, 8.25, 13.8], + [1.65, 0.0, 3.45], + [0.45, 6.3, 0.0], + ]; + + assert_f64_eq!(scalar * &a, scalar_mul); } } diff --git a/src/math/miller_rabin.rs b/src/math/miller_rabin.rs index 650222e2a39..dbeeac5acbd 100644 --- a/src/math/miller_rabin.rs +++ b/src/math/miller_rabin.rs @@ -1,3 +1,7 @@ +use num_bigint::BigUint; +use num_traits::{One, ToPrimitive, Zero}; +use std::cmp::Ordering; + fn modulo_power(mut base: u64, mut power: u64, modulo: u64) -> u64 { base %= modulo; if base == 0 { @@ -43,8 +47,7 @@ pub fn miller_rabin(number: u64, bases: &[u64]) -> u64 { 0 => { panic!("0 is invalid input for Miller-Rabin. 0 is not prime by definition, but has no witness"); } - 2 => return 0, - 3 => return 0, + 2 | 3 => return 0, _ => return number, } } @@ -61,45 +64,174 @@ pub fn miller_rabin(number: u64, bases: &[u64]) -> u64 { 0 } +pub fn big_miller_rabin(number_ref: &BigUint, bases: &[u64]) -> u64 { + let number = number_ref.clone(); + + if BigUint::from(5u32).cmp(&number) == Ordering::Greater { + if number.eq(&BigUint::zero()) { + panic!("0 is invalid input for Miller-Rabin. 0 is not prime by definition, but has no witness"); + } else if number.eq(&BigUint::from(2u32)) || number.eq(&BigUint::from(3u32)) { + return 0; + } else { + return number.to_u64().unwrap(); + } + } + + if let Some(num) = number.to_u64() { + if bases.contains(&num) { + return 0; + } + } + + let num_minus_one = &number - BigUint::one(); + + let two_power: u64 = num_minus_one.trailing_zeros().unwrap(); + let odd_power: BigUint = &num_minus_one >> two_power; + for base in bases { + let mut x = BigUint::from(*base).modpow(&odd_power, &number); + + if x.eq(&BigUint::one()) || x.eq(&num_minus_one) { + continue; + } + + let mut not_a_witness = false; + + for _ in 1..two_power { + x = (&x * &x) % &number; + if x.eq(&num_minus_one) { + not_a_witness = true; + break; + } + } + + if not_a_witness { + continue; + } + + return *base; + } + + 0 +} + #[cfg(test)] mod tests { use super::*; + static DEFAULT_BASES: [u64; 12] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; + #[test] fn basic() { - let default_bases = vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]; // these bases make miller rabin deterministic for any number < 2 ^ 64 // can use smaller number of bases for deterministic performance for numbers < 2 ^ 32 - assert_eq!(miller_rabin(3, &default_bases), 0); - assert_eq!(miller_rabin(7, &default_bases), 0); - assert_eq!(miller_rabin(11, &default_bases), 0); - assert_eq!(miller_rabin(2003, &default_bases), 0); + assert_eq!(miller_rabin(3, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(7, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(11, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(2003, &DEFAULT_BASES), 0); - assert_ne!(miller_rabin(1, &default_bases), 0); - assert_ne!(miller_rabin(4, &default_bases), 0); - assert_ne!(miller_rabin(6, &default_bases), 0); - assert_ne!(miller_rabin(21, &default_bases), 0); - assert_ne!(miller_rabin(2004, &default_bases), 0); + assert_ne!(miller_rabin(1, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(4, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(6, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(21, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(2004, &DEFAULT_BASES), 0); // bigger test cases. // primes are generated using openssl // non primes are randomly picked and checked using openssl // primes: - assert_eq!(miller_rabin(3629611793, &default_bases), 0); - assert_eq!(miller_rabin(871594686869, &default_bases), 0); - assert_eq!(miller_rabin(968236663804121, &default_bases), 0); - assert_eq!(miller_rabin(6920153791723773023, &default_bases), 0); + assert_eq!(miller_rabin(3629611793, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(871594686869, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(968236663804121, &DEFAULT_BASES), 0); + assert_eq!(miller_rabin(6920153791723773023, &DEFAULT_BASES), 0); // random non primes: - assert_ne!(miller_rabin(4546167556336341257, &default_bases), 0); - assert_ne!(miller_rabin(4363186415423517377, &default_bases), 0); - assert_ne!(miller_rabin(815479701131020226, &default_bases), 0); + assert_ne!(miller_rabin(4546167556336341257, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(4363186415423517377, &DEFAULT_BASES), 0); + assert_ne!(miller_rabin(815479701131020226, &DEFAULT_BASES), 0); // these two are made of two 31 bit prime factors: // 1950202127 * 2058609037 = 4014703722618821699 - assert_ne!(miller_rabin(4014703722618821699, &default_bases), 0); + assert_ne!(miller_rabin(4014703722618821699, &DEFAULT_BASES), 0); // 1679076769 * 2076341633 = 3486337000477823777 - assert_ne!(miller_rabin(3486337000477823777, &default_bases), 0); + assert_ne!(miller_rabin(3486337000477823777, &DEFAULT_BASES), 0); + } + + #[test] + fn big_basic() { + assert_eq!(big_miller_rabin(&BigUint::from(3u32), &DEFAULT_BASES), 0); + assert_eq!(big_miller_rabin(&BigUint::from(7u32), &DEFAULT_BASES), 0); + assert_eq!(big_miller_rabin(&BigUint::from(11u32), &DEFAULT_BASES), 0); + assert_eq!(big_miller_rabin(&BigUint::from(2003u32), &DEFAULT_BASES), 0); + + assert_ne!(big_miller_rabin(&BigUint::from(1u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(4u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(6u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(21u32), &DEFAULT_BASES), 0); + assert_ne!(big_miller_rabin(&BigUint::from(2004u32), &DEFAULT_BASES), 0); + + assert_eq!( + big_miller_rabin(&BigUint::from(3629611793u64), &DEFAULT_BASES), + 0 + ); + assert_eq!( + big_miller_rabin(&BigUint::from(871594686869u64), &DEFAULT_BASES), + 0 + ); + assert_eq!( + big_miller_rabin(&BigUint::from(968236663804121u64), &DEFAULT_BASES), + 0 + ); + assert_eq!( + big_miller_rabin(&BigUint::from(6920153791723773023u64), &DEFAULT_BASES), + 0 + ); + + assert_ne!( + big_miller_rabin(&BigUint::from(4546167556336341257u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(4363186415423517377u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(815479701131020226u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(4014703722618821699u64), &DEFAULT_BASES), + 0 + ); + assert_ne!( + big_miller_rabin(&BigUint::from(3486337000477823777u64), &DEFAULT_BASES), + 0 + ); + } + + #[test] + #[ignore] + fn big_primes() { + let p1 = + BigUint::parse_bytes(b"4764862697132131451620315518348229845593592794669", 10).unwrap(); + assert_eq!(big_miller_rabin(&p1, &DEFAULT_BASES), 0); + + let p2 = BigUint::parse_bytes( + b"12550757946601963214089118080443488976766669415957018428703", + 10, + ) + .unwrap(); + assert_eq!(big_miller_rabin(&p2, &DEFAULT_BASES), 0); + + // An RSA-worthy prime + let p3 = BigUint::parse_bytes(b"157d6l5zkv45ve4azfw7nyyjt6rzir2gcjoytjev5iacnkaii8hlkyk3op7bx9qfqiie23vj9iw4qbp7zupydfq9ut6mq6m36etya6cshtqi1yi9q5xyiws92el79dqt8qk7l2pqmxaa0sxhmd2vpaibo9dkfd029j1rvkwlw4724ctgaqs5jzy0bqi5pqdjc2xerhn", 36).unwrap(); + assert_eq!(big_miller_rabin(&p3, &DEFAULT_BASES), 0); + + let n1 = BigUint::parse_bytes(b"coy6tkiaqswmce1r03ycdif3t796wzjwneewbe3cmncaplm85jxzcpdmvy0moic3lql70a81t5qdn2apac0dndhohewkspuk1wyndxsgxs3ux4a7730unru7dfmygh", 36).unwrap(); + assert_ne!(big_miller_rabin(&n1, &DEFAULT_BASES), 0); + + // RSA-2048 + let n2 = BigUint::parse_bytes(b"4l91lq4a2sgekpv8ukx1gxsk7mfeks46haggorlkazm0oufxwijid6q6v44u5me3kz3ne6yczp4fcvo62oej72oe7pjjtyxgid5b8xdz1e8daafspbzcy1hd8i4urjh9hm0tyylsgqsss3jn372d6fmykpw4bb9cr1ngxnncsbod3kg49o7owzqnsci5pwqt8bch0t60gq0st2gyx7ii3mzhb1pp1yvjyor35hwvok1sxj3ih46rpd27li8y5yli3mgdttcn65k3szfa6rbcnbgkojqjjq72gar6raslnh6sjd2fy7yj3bwo43obvbg3ws8y28kpol3okb5b3fld03sq1kgrj2fugiaxgplva6x5ssilqq4g0b21xy2kiou3sqsgonmqx55v", 36).unwrap(); + assert_ne!(big_miller_rabin(&n2, &DEFAULT_BASES), 0); } } diff --git a/src/math/mod.rs b/src/math/mod.rs index 821034673e1..7407465c3b0 100644 --- a/src/math/mod.rs +++ b/src/math/mod.rs @@ -1,68 +1,183 @@ +mod abs; +mod aliquot_sum; +mod amicable_numbers; +mod area_of_polygon; +mod area_under_curve; mod armstrong_number; +mod average; mod baby_step_giant_step; +mod bell_numbers; +mod binary_exponentiation; +mod binomial_coefficient; +mod catalan_numbers; +mod ceil; +mod chinese_remainder_theorem; +mod collatz_sequence; +mod combinations; +mod cross_entropy_loss; +mod decimal_to_fraction; +mod doomsday; +mod elliptic_curve; +mod euclidean_distance; +mod exponential_linear_unit; mod extended_euclidean_algorithm; +pub mod factorial; +mod factors; mod fast_fourier_transform; mod fast_power; mod faster_perfect_numbers; +mod field; +mod frizzy_number; mod gaussian_elimination; +mod gaussian_error_linear_unit; mod gcd_of_n_numbers; +mod geometric_series; mod greatest_common_divisor; +mod huber_loss; +mod infix_to_postfix; +mod interest; +mod interpolation; +mod interquartile_range; mod karatsuba_multiplication; mod lcm_of_n_numbers; +mod leaky_relu; +mod least_square_approx; mod linear_sieve; +mod logarithm; +mod lucas_series; mod matrix_ops; mod mersenne_primes; mod miller_rabin; +mod modular_exponential; mod newton_raphson; mod nthprime; mod pascal_triangle; +mod perfect_cube; mod perfect_numbers; +mod perfect_square; mod pollard_rho; +mod postfix_evaluation; mod prime_check; mod prime_factors; mod prime_numbers; mod quadratic_residue; mod random; +mod relu; mod sieve_of_eratosthenes; -mod simpson_integration; +mod sigmoid; +mod signum; +mod simpsons_integration; +mod softmax; +mod sprague_grundy_theorem; +mod square_pyramidal_numbers; mod square_root; +mod sum_of_digits; +mod sum_of_geometric_progression; +mod sum_of_harmonic_series; +mod sylvester_sequence; +mod tanh; +mod trapezoidal_integration; mod trial_division; +mod trig_functions; +mod vector_cross_product; mod zellers_congruence_algorithm; +pub use self::abs::abs; +pub use self::aliquot_sum::aliquot_sum; +pub use self::amicable_numbers::amicable_pairs_under_n; +pub use self::area_of_polygon::area_of_polygon; +pub use self::area_under_curve::area_under_curve; pub use self::armstrong_number::is_armstrong_number; +pub use self::average::{mean, median, mode}; pub use self::baby_step_giant_step::baby_step_giant_step; +pub use self::bell_numbers::bell_number; +pub use self::binary_exponentiation::binary_exponentiation; +pub use self::binomial_coefficient::binom; +pub use self::catalan_numbers::init_catalan; +pub use self::ceil::ceil; +pub use self::chinese_remainder_theorem::chinese_remainder_theorem; +pub use self::collatz_sequence::sequence; +pub use self::combinations::combinations; +pub use self::cross_entropy_loss::cross_entropy_loss; +pub use self::decimal_to_fraction::decimal_to_fraction; +pub use self::doomsday::get_week_day; +pub use self::elliptic_curve::EllipticCurve; +pub use self::euclidean_distance::euclidean_distance; +pub use self::exponential_linear_unit::exponential_linear_unit; pub use self::extended_euclidean_algorithm::extended_euclidean_algorithm; +pub use self::factorial::{factorial, factorial_bigmath, factorial_recursive}; +pub use self::factors::factors; pub use self::fast_fourier_transform::{ fast_fourier_transform, fast_fourier_transform_input_permutation, inverse_fast_fourier_transform, }; pub use self::fast_power::fast_power; pub use self::faster_perfect_numbers::generate_perfect_numbers; +pub use self::field::{Field, PrimeField}; +pub use self::frizzy_number::get_nth_frizzy; pub use self::gaussian_elimination::gaussian_elimination; +pub use self::gaussian_error_linear_unit::gaussian_error_linear_unit; pub use self::gcd_of_n_numbers::gcd; +pub use self::geometric_series::geometric_series; pub use self::greatest_common_divisor::{ greatest_common_divisor_iterative, greatest_common_divisor_recursive, + greatest_common_divisor_stein, }; +pub use self::huber_loss::huber_loss; +pub use self::infix_to_postfix::infix_to_postfix; +pub use self::interest::{compound_interest, simple_interest}; +pub use self::interpolation::{lagrange_polynomial_interpolation, linear_interpolation}; +pub use self::interquartile_range::interquartile_range; pub use self::karatsuba_multiplication::multiply; pub use self::lcm_of_n_numbers::lcm; +pub use self::leaky_relu::leaky_relu; +pub use self::least_square_approx::least_square_approx; pub use self::linear_sieve::LinearSieve; -pub use self::matrix_ops::{ - matrix_add, matrix_multiply, matrix_scalar_multiplication, matrix_subtract, matrix_transpose, -}; +pub use self::logarithm::log; +pub use self::lucas_series::dynamic_lucas_number; +pub use self::lucas_series::recursive_lucas_number; +pub use self::matrix_ops::Matrix; pub use self::mersenne_primes::{get_mersenne_primes, is_mersenne_prime}; -pub use self::miller_rabin::miller_rabin; +pub use self::miller_rabin::{big_miller_rabin, miller_rabin}; +pub use self::modular_exponential::{mod_inverse, modular_exponential}; pub use self::newton_raphson::find_root; pub use self::nthprime::nthprime; pub use self::pascal_triangle::pascal_triangle; +pub use self::perfect_cube::perfect_cube_binary_search; pub use self::perfect_numbers::perfect_numbers; +pub use self::perfect_square::perfect_square; +pub use self::perfect_square::perfect_square_binary_search; pub use self::pollard_rho::{pollard_rho_factorize, pollard_rho_get_one_factor}; +pub use self::postfix_evaluation::evaluate_postfix; pub use self::prime_check::prime_check; pub use self::prime_factors::prime_factors; pub use self::prime_numbers::prime_numbers; -pub use self::quadratic_residue::cipolla; +pub use self::quadratic_residue::{cipolla, tonelli_shanks}; pub use self::random::PCG32; +pub use self::relu::relu; pub use self::sieve_of_eratosthenes::sieve_of_eratosthenes; -pub use self::simpson_integration::simpson_integration; -pub use self::square_root::square_root; +pub use self::sigmoid::sigmoid; +pub use self::signum::signum; +pub use self::simpsons_integration::simpsons_integration; +pub use self::softmax::softmax; +pub use self::sprague_grundy_theorem::calculate_grundy_number; +pub use self::square_pyramidal_numbers::square_pyramidal_number; +pub use self::square_root::{fast_inv_sqrt, square_root}; +pub use self::sum_of_digits::{sum_digits_iterative, sum_digits_recursive}; +pub use self::sum_of_geometric_progression::sum_of_geometric_progression; +pub use self::sum_of_harmonic_series::sum_of_harmonic_progression; +pub use self::sylvester_sequence::sylvester; +pub use self::tanh::tanh; +pub use self::trapezoidal_integration::trapezoidal_integral; pub use self::trial_division::trial_division; +pub use self::trig_functions::cosine; +pub use self::trig_functions::cosine_no_radian_arg; +pub use self::trig_functions::cotan; +pub use self::trig_functions::cotan_no_radian_arg; +pub use self::trig_functions::sine; +pub use self::trig_functions::sine_no_radian_arg; +pub use self::trig_functions::tan; +pub use self::trig_functions::tan_no_radian_arg; +pub use self::vector_cross_product::cross_product; +pub use self::vector_cross_product::vector_magnitude; pub use self::zellers_congruence_algorithm::zellers_congruence_algorithm; diff --git a/src/math/modular_exponential.rs b/src/math/modular_exponential.rs new file mode 100644 index 00000000000..1e9a9f41cac --- /dev/null +++ b/src/math/modular_exponential.rs @@ -0,0 +1,136 @@ +/// Calculate the greatest common divisor (GCD) of two numbers and the +/// coefficients of Bézout's identity using the Extended Euclidean Algorithm. +/// +/// # Arguments +/// +/// * `a` - One of the numbers to find the GCD of +/// * `m` - The other number to find the GCD of +/// +/// # Returns +/// +/// A tuple (gcd, x1, x2) such that: +/// gcd - the greatest common divisor of a and m. +/// x1, x2 - the coefficients such that `a * x1 + m * x2` is equivalent to `gcd` modulo `m`. +pub fn gcd_extended(a: i64, m: i64) -> (i64, i64, i64) { + if a == 0 { + (m, 0, 1) + } else { + let (gcd, x1, x2) = gcd_extended(m % a, a); + let x = x2 - (m / a) * x1; + (gcd, x, x1) + } +} + +/// Find the modular multiplicative inverse of a number modulo `m`. +/// +/// # Arguments +/// +/// * `b` - The number to find the modular inverse of +/// * `m` - The modulus +/// +/// # Returns +/// +/// The modular inverse of `b` modulo `m`. +/// +/// # Panics +/// +/// Panics if the inverse does not exist (i.e., `b` and `m` are not coprime). +pub fn mod_inverse(b: i64, m: i64) -> i64 { + let (gcd, x, _) = gcd_extended(b, m); + if gcd != 1 { + panic!("Inverse does not exist"); + } else { + // Ensure the modular inverse is positive + (x % m + m) % m + } +} + +/// Perform modular exponentiation of a number raised to a power modulo `m`. +/// This function handles both positive and negative exponents. +/// +/// # Arguments +/// +/// * `base` - The base number to be raised to the `power` +/// * `power` - The exponent to raise the `base` to +/// * `modulus` - The modulus to perform the operation under +/// +/// # Returns +/// +/// The result of `base` raised to `power` modulo `modulus`. +pub fn modular_exponential(base: i64, mut power: i64, modulus: i64) -> i64 { + if modulus == 1 { + return 0; // Base case: any number modulo 1 is 0 + } + + // Adjust if the exponent is negative by finding the modular inverse + let mut base = if power < 0 { + mod_inverse(base, modulus) + } else { + base % modulus + }; + + let mut result = 1; // Initialize result + power = power.abs(); // Work with the absolute value of the exponent + + // Perform the exponentiation + while power > 0 { + if power & 1 == 1 { + result = (result * base) % modulus; + } + power >>= 1; // Divide the power by 2 + base = (base * base) % modulus; // Square the base + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_modular_exponential_positive() { + assert_eq!(modular_exponential(2, 3, 5), 3); // 2^3 % 5 = 8 % 5 = 3 + assert_eq!(modular_exponential(7, 2, 13), 10); // 7^2 % 13 = 49 % 13 = 10 + assert_eq!(modular_exponential(5, 5, 31), 25); // 5^5 % 31 = 3125 % 31 = 25 + assert_eq!(modular_exponential(10, 8, 11), 1); // 10^8 % 11 = 100000000 % 11 = 1 + assert_eq!(modular_exponential(123, 45, 67), 62); // 123^45 % 67 + } + + #[test] + fn test_modular_inverse() { + assert_eq!(mod_inverse(7, 13), 2); // Inverse of 7 mod 13 is 2 + assert_eq!(mod_inverse(5, 31), 25); // Inverse of 5 mod 31 is 25 + assert_eq!(mod_inverse(10, 11), 10); // Inverse of 10 mod 1 is 10 + assert_eq!(mod_inverse(123, 67), 6); // Inverse of 123 mod 67 is 6 + assert_eq!(mod_inverse(9, 17), 2); // Inverse of 9 mod 17 is 2 + } + + #[test] + fn test_modular_exponential_negative() { + assert_eq!( + modular_exponential(7, -2, 13), + mod_inverse(7, 13).pow(2) % 13 + ); // Inverse of 7 mod 13 is 2, 2^2 % 13 = 4 % 13 = 4 + assert_eq!( + modular_exponential(5, -5, 31), + mod_inverse(5, 31).pow(5) % 31 + ); // Inverse of 5 mod 31 is 25, 25^5 % 31 = 25 + assert_eq!( + modular_exponential(10, -8, 11), + mod_inverse(10, 11).pow(8) % 11 + ); // Inverse of 10 mod 11 is 10, 10^8 % 11 = 10 + assert_eq!( + modular_exponential(123, -5, 67), + mod_inverse(123, 67).pow(5) % 67 + ); // Inverse of 123 mod 67 is calculated via the function + } + + #[test] + fn test_modular_exponential_edge_cases() { + assert_eq!(modular_exponential(0, 0, 1), 0); // 0^0 % 1 should be 0 as the modulus is 1 + assert_eq!(modular_exponential(0, 10, 1), 0); // 0^n % 1 should be 0 for any n + assert_eq!(modular_exponential(10, 0, 1), 0); // n^0 % 1 should be 0 for any n + assert_eq!(modular_exponential(1, 1, 1), 0); // 1^1 % 1 should be 0 + assert_eq!(modular_exponential(-1, 2, 1), 0); // (-1)^2 % 1 should be 0 + } +} diff --git a/src/math/newton_raphson.rs b/src/math/newton_raphson.rs index ad45451396f..5a21ef625a9 100644 --- a/src/math/newton_raphson.rs +++ b/src/math/newton_raphson.rs @@ -15,10 +15,10 @@ mod tests { use super::*; fn math_fn(x: f64) -> f64 { - return x.cos() - (x * x * x); + x.cos() - (x * x * x) } fn math_fnd(x: f64) -> f64 { - return -x.sin() - 3.0 * (x * x); + -x.sin() - 3.0 * (x * x) } #[test] fn basic() { diff --git a/src/math/nthprime.rs b/src/math/nthprime.rs index 2802d3191ed..1b0e93c855b 100644 --- a/src/math/nthprime.rs +++ b/src/math/nthprime.rs @@ -1,5 +1,5 @@ // Generate the nth prime number. -// Algorithm is inspired by the the optimized version of the Sieve of Eratosthenes. +// Algorithm is inspired by the optimized version of the Sieve of Eratosthenes. pub fn nthprime(nth: u64) -> u64 { let mut total_prime: u64 = 0; let mut size_factor: u64 = 2; @@ -39,8 +39,8 @@ fn get_primes(s: u64) -> Vec { fn count_prime(primes: Vec, n: u64) -> Option { let mut counter: u64 = 0; - for i in 2..primes.len() { - counter += primes.get(i).unwrap(); + for (i, prime) in primes.iter().enumerate().skip(2) { + counter += prime; if counter == n { return Some(i as u64); } diff --git a/src/math/pascal_triangle.rs b/src/math/pascal_triangle.rs index 3e504801d58..34643029b6b 100644 --- a/src/math/pascal_triangle.rs +++ b/src/math/pascal_triangle.rs @@ -1,17 +1,18 @@ -/// ## Paslcal's triangle problem - -/// pascal_triangle(num_rows) returns the first num_rows of Pascal's triangle. -/// About Pascal's triangle: https://en.wikipedia.org/wiki/Pascal%27s_triangle +/// ## Pascal's triangle problem +/// +/// pascal_triangle(num_rows) returns the first num_rows of Pascal's triangle.\ +/// About Pascal's triangle: +/// +/// # Arguments: +/// * `num_rows`: number of rows of triangle /// -/// Arguments: -/// * `num_rows` - number of rows of triangle -/// Complexity -/// - time complexity: O(n^2), -/// - space complexity: O(n^2), +/// # Complexity +/// - time complexity: O(n^2), +/// - space complexity: O(n^2), pub fn pascal_triangle(num_rows: i32) -> Vec> { let mut ans: Vec> = vec![]; - for i in 1..num_rows + 1 { + for i in 1..=num_rows { let mut vec: Vec = vec![1]; let mut res: i32 = 1; diff --git a/src/math/perfect_cube.rs b/src/math/perfect_cube.rs new file mode 100644 index 00000000000..d4f2c7becef --- /dev/null +++ b/src/math/perfect_cube.rs @@ -0,0 +1,61 @@ +// Check if a number is a perfect cube using binary search. +pub fn perfect_cube_binary_search(n: i64) -> bool { + if n < 0 { + return perfect_cube_binary_search(-n); + } + + // Initialize left and right boundaries for binary search. + let mut left = 0; + let mut right = n.abs(); // Use the absolute value to handle negative numbers + + // Binary search loop to find the cube root. + while left <= right { + // Calculate the mid-point. + let mid = left + (right - left) / 2; + // Calculate the cube of the mid-point. + let cube = mid * mid * mid; + + // Check if the cube equals the original number. + match cube.cmp(&n) { + std::cmp::Ordering::Equal => return true, + std::cmp::Ordering::Less => left = mid + 1, + std::cmp::Ordering::Greater => right = mid - 1, + } + } + + // If no cube root is found, return false. + false +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_perfect_cube { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (n, expected) = $inputs; + assert_eq!(perfect_cube_binary_search(n), expected); + assert_eq!(perfect_cube_binary_search(-n), expected); + } + )* + } + } + + test_perfect_cube! { + num_0_a_perfect_cube: (0, true), + num_1_is_a_perfect_cube: (1, true), + num_27_is_a_perfect_cube: (27, true), + num_64_is_a_perfect_cube: (64, true), + num_8_is_a_perfect_cube: (8, true), + num_2_is_not_a_perfect_cube: (2, false), + num_3_is_not_a_perfect_cube: (3, false), + num_4_is_not_a_perfect_cube: (4, false), + num_5_is_not_a_perfect_cube: (5, false), + num_999_is_not_a_perfect_cube: (999, false), + num_1000_is_a_perfect_cube: (1000, true), + num_1001_is_not_a_perfect_cube: (1001, false), + } +} diff --git a/src/math/perfect_numbers.rs b/src/math/perfect_numbers.rs index b4b50b334c3..0d819d2b2f1 100644 --- a/src/math/perfect_numbers.rs +++ b/src/math/perfect_numbers.rs @@ -14,7 +14,7 @@ pub fn perfect_numbers(max: usize) -> Vec { let mut result: Vec = Vec::new(); // It is not known if there are any odd perfect numbers, so we go around all the numbers. - for i in 1..max + 1 { + for i in 1..=max { if is_perfect_number(i) { result.push(i); } @@ -29,15 +29,15 @@ mod tests { #[test] fn basic() { - assert_eq!(is_perfect_number(6), true); - assert_eq!(is_perfect_number(28), true); - assert_eq!(is_perfect_number(496), true); - assert_eq!(is_perfect_number(8128), true); - - assert_eq!(is_perfect_number(5), false); - assert_eq!(is_perfect_number(86), false); - assert_eq!(is_perfect_number(497), false); - assert_eq!(is_perfect_number(8120), false); + assert!(is_perfect_number(6)); + assert!(is_perfect_number(28)); + assert!(is_perfect_number(496)); + assert!(is_perfect_number(8128)); + + assert!(!is_perfect_number(5)); + assert!(!is_perfect_number(86)); + assert!(!is_perfect_number(497)); + assert!(!is_perfect_number(8120)); assert_eq!(perfect_numbers(10), vec![6]); assert_eq!(perfect_numbers(100), vec![6, 28]); diff --git a/src/math/perfect_square.rs b/src/math/perfect_square.rs new file mode 100644 index 00000000000..7b0f69976c4 --- /dev/null +++ b/src/math/perfect_square.rs @@ -0,0 +1,57 @@ +// Author : cyrixninja +// Perfect Square : Checks if a number is perfect square number or not +// https://en.wikipedia.org/wiki/Perfect_square +pub fn perfect_square(num: i32) -> bool { + if num < 0 { + return false; + } + let sqrt_num = (num as f64).sqrt() as i32; + sqrt_num * sqrt_num == num +} + +pub fn perfect_square_binary_search(n: i32) -> bool { + if n < 0 { + return false; + } + + let mut left = 0; + let mut right = n; + + while left <= right { + let mid = i32::midpoint(left, right); + let mid_squared = mid * mid; + + match mid_squared.cmp(&n) { + std::cmp::Ordering::Equal => return true, + std::cmp::Ordering::Greater => right = mid - 1, + std::cmp::Ordering::Less => left = mid + 1, + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_perfect_square() { + assert!(perfect_square(9)); + assert!(perfect_square(81)); + assert!(perfect_square(4)); + assert!(perfect_square(0)); + assert!(!perfect_square(3)); + assert!(!perfect_square(-19)); + } + + #[test] + fn test_perfect_square_binary_search() { + assert!(perfect_square_binary_search(9)); + assert!(perfect_square_binary_search(81)); + assert!(perfect_square_binary_search(4)); + assert!(perfect_square_binary_search(0)); + assert!(!perfect_square_binary_search(3)); + assert!(!perfect_square_binary_search(-19)); + } +} diff --git a/src/math/pollard_rho.rs b/src/math/pollard_rho.rs index bcd61cc534b..1ba7481e989 100644 --- a/src/math/pollard_rho.rs +++ b/src/math/pollard_rho.rs @@ -178,8 +178,7 @@ pub fn pollard_rho_factorize( return result; } let mut to_be_factored = vec![number]; - while !to_be_factored.is_empty() { - let last = to_be_factored.pop().unwrap(); + while let Some(last) = to_be_factored.pop() { if last < minimum_prime_factors.len() as u64 { result.append(&mut factor_using_mpf(last as usize, minimum_prime_factors)); continue; @@ -273,7 +272,7 @@ mod test { for num in numbers { assert!(check_factorization( num, - &pollard_rho_factorize(num, &mut seed, &vec![], &vec![]) + &pollard_rho_factorize(num, &mut seed, &[], &[]) )); } } diff --git a/src/math/postfix_evaluation.rs b/src/math/postfix_evaluation.rs new file mode 100644 index 00000000000..27bf4e3eacc --- /dev/null +++ b/src/math/postfix_evaluation.rs @@ -0,0 +1,105 @@ +//! This module provides a function to evaluate postfix (Reverse Polish Notation) expressions. +//! Postfix notation is a mathematical notation in which every operator follows all of its operands. +//! +//! The evaluator supports the four basic arithmetic operations: addition, subtraction, multiplication, and division. +//! It handles errors such as division by zero, invalid operators, insufficient operands, and invalid postfix expressions. + +/// Enumeration of errors that can occur when evaluating a postfix expression. +#[derive(Debug, PartialEq)] +pub enum PostfixError { + DivisionByZero, + InvalidOperator, + InsufficientOperands, + InvalidExpression, +} + +/// Evaluates a postfix expression and returns the result or an error. +/// +/// # Arguments +/// +/// * `expression` - A string slice that contains the postfix expression to be evaluated. +/// The tokens (numbers and operators) should be separated by whitespace. +/// +/// # Returns +/// +/// * `Ok(isize)` if the expression is valid and evaluates to an integer. +/// * `Err(PostfixError)` if the expression is invalid or encounters errors during evaluation. +/// +/// # Errors +/// +/// * `PostfixError::DivisionByZero` - If a division by zero is attempted. +/// * `PostfixError::InvalidOperator` - If an unknown operator is encountered. +/// * `PostfixError::InsufficientOperands` - If there are not enough operands for an operator. +/// * `PostfixError::InvalidExpression` - If the expression is malformed (e.g., multiple values are left on the stack). +pub fn evaluate_postfix(expression: &str) -> Result { + let mut stack: Vec = Vec::new(); + + for token in expression.split_whitespace() { + if let Ok(number) = token.parse::() { + // If the token is a number, push it onto the stack. + stack.push(number); + } else { + // If the token is an operator, pop the top two values from the stack, + // apply the operator, and push the result back onto the stack. + if let (Some(b), Some(a)) = (stack.pop(), stack.pop()) { + match token { + "+" => stack.push(a + b), + "-" => stack.push(a - b), + "*" => stack.push(a * b), + "/" => { + if b == 0 { + return Err(PostfixError::DivisionByZero); + } + stack.push(a / b); + } + _ => return Err(PostfixError::InvalidOperator), + } + } else { + return Err(PostfixError::InsufficientOperands); + } + } + } + // The final result should be the only element on the stack. + if stack.len() == 1 { + Ok(stack[0]) + } else { + Err(PostfixError::InvalidExpression) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! postfix_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(evaluate_postfix(input), expected); + } + )* + } + } + + postfix_tests! { + test_addition_of_two_numbers: ("2 3 +", Ok(5)), + test_multiplication_and_addition: ("5 2 * 4 +", Ok(14)), + test_simple_division: ("10 2 /", Ok(5)), + test_operator_without_operands: ("+", Err(PostfixError::InsufficientOperands)), + test_division_by_zero_error: ("5 0 /", Err(PostfixError::DivisionByZero)), + test_invalid_operator_in_expression: ("2 3 #", Err(PostfixError::InvalidOperator)), + test_missing_operator_for_expression: ("2 3", Err(PostfixError::InvalidExpression)), + test_extra_operands_in_expression: ("2 3 4 +", Err(PostfixError::InvalidExpression)), + test_empty_expression_error: ("", Err(PostfixError::InvalidExpression)), + test_single_number_expression: ("42", Ok(42)), + test_addition_of_negative_numbers: ("-3 -2 +", Ok(-5)), + test_complex_expression_with_multiplication_and_addition: ("3 5 8 * 7 + *", Ok(141)), + test_expression_with_extra_whitespace: (" 3 4 + ", Ok(7)), + test_valid_then_invalid_operator: ("5 2 + 1 #", Err(PostfixError::InvalidOperator)), + test_first_division_by_zero: ("5 0 / 6 0 /", Err(PostfixError::DivisionByZero)), + test_complex_expression_with_multiple_operators: ("5 1 2 + 4 * + 3 -", Ok(14)), + test_expression_with_only_whitespace: (" ", Err(PostfixError::InvalidExpression)), + } +} diff --git a/src/math/prime_check.rs b/src/math/prime_check.rs index 2daf4cc67c1..4902a65dbf7 100644 --- a/src/math/prime_check.rs +++ b/src/math/prime_check.rs @@ -20,14 +20,14 @@ mod tests { #[test] fn basic() { - assert_eq!(prime_check(3), true); - assert_eq!(prime_check(7), true); - assert_eq!(prime_check(11), true); - assert_eq!(prime_check(2003), true); + assert!(prime_check(3)); + assert!(prime_check(7)); + assert!(prime_check(11)); + assert!(prime_check(2003)); - assert_eq!(prime_check(4), false); - assert_eq!(prime_check(6), false); - assert_eq!(prime_check(21), false); - assert_eq!(prime_check(2004), false); + assert!(!prime_check(4)); + assert!(!prime_check(6)); + assert!(!prime_check(21)); + assert!(!prime_check(2004)); } } diff --git a/src/math/prime_factors.rs b/src/math/prime_factors.rs index c380fc58488..7b89b09c9b8 100644 --- a/src/math/prime_factors.rs +++ b/src/math/prime_factors.rs @@ -4,13 +4,6 @@ pub fn prime_factors(n: u64) -> Vec { let mut i = 2; let mut n = n; let mut factors = Vec::new(); - if n == 0 { - return factors; - } - if n == 1 { - factors.push(1); - return factors; - } while i * i <= n { if n % i != 0 { if i != 2 { @@ -34,6 +27,7 @@ mod tests { #[test] fn it_works() { assert_eq!(prime_factors(0), vec![]); + assert_eq!(prime_factors(1), vec![]); assert_eq!(prime_factors(11), vec![11]); assert_eq!(prime_factors(25), vec![5, 5]); assert_eq!(prime_factors(33), vec![3, 11]); diff --git a/src/math/prime_numbers.rs b/src/math/prime_numbers.rs index 1643340f8ff..f045133a168 100644 --- a/src/math/prime_numbers.rs +++ b/src/math/prime_numbers.rs @@ -4,9 +4,9 @@ pub fn prime_numbers(max: usize) -> Vec { if max >= 2 { result.push(2) } - for i in (3..max + 1).step_by(2) { + for i in (3..=max).step_by(2) { let stop: usize = (i as f64).sqrt() as usize + 1; - let mut status: bool = true; + let mut status = true; for j in (3..stop).step_by(2) { if i % j == 0 { diff --git a/src/math/quadratic_residue.rs b/src/math/quadratic_residue.rs index 44696d7d8fd..e3f2e6b819b 100644 --- a/src/math/quadratic_residue.rs +++ b/src/math/quadratic_residue.rs @@ -10,15 +10,17 @@ use std::rc::Rc; use std::time::{SystemTime, UNIX_EPOCH}; +use rand::Rng; + use super::{fast_power, PCG32}; #[derive(Debug)] -struct CustomFiniteFiled { +struct CustomFiniteField { modulus: u64, i_square: u64, } -impl CustomFiniteFiled { +impl CustomFiniteField { pub fn new(modulus: u64, i_square: u64) -> Self { Self { modulus, i_square } } @@ -28,11 +30,11 @@ impl CustomFiniteFiled { struct CustomComplexNumber { real: u64, imag: u64, - f: Rc, + f: Rc, } impl CustomComplexNumber { - pub fn new(real: u64, imag: u64, f: Rc) -> Self { + pub fn new(real: u64, imag: u64, f: Rc) -> Self { Self { real, imag, f } } @@ -70,6 +72,22 @@ fn is_residue(x: u64, modulus: u64) -> bool { x != 0 && fast_power(x as usize, power as usize, modulus as usize) == 1 } +/// The Legendre symbol `(a | p)` +/// +/// Returns 0 if a = 0 mod p, 1 if a is a square mod p, -1 if it not a square mod p. +/// +/// +pub fn legendre_symbol(a: u64, odd_prime: u64) -> i64 { + debug_assert!(odd_prime % 2 != 0, "prime must be odd"); + if a == 0 { + 0 + } else if is_residue(a, odd_prime) { + 1 + } else { + -1 + } +} + // return two solutions (x1, x2) for Quadratic Residue problem x^2 = a (mod p), where p is an odd prime // if a is Quadratic Nonresidues, return None pub fn cipolla(a: u32, p: u32, seed: Option) -> Option<(u32, u32)> { @@ -97,11 +115,11 @@ pub fn cipolla(a: u32, p: u32, seed: Option) -> Option<(u32, u32)> { break r; } }; - let filed = Rc::new(CustomFiniteFiled::new(p, (p + r * r - a) % p)); + let filed = Rc::new(CustomFiniteField::new(p, (p + r * r - a) % p)); let comp = CustomComplexNumber::new(r, 1, filed); let power = (p + 1) >> 1; let x0 = CustomComplexNumber::fast_power(comp, power).real as u32; - let x1 = p as u32 - x0 as u32; + let x1 = p as u32 - x0; if x0 < x1 { Some((x0, x1)) } else { @@ -109,19 +127,96 @@ pub fn cipolla(a: u32, p: u32, seed: Option) -> Option<(u32, u32)> { } } +/// Returns one of the two possible solutions of _x² = a mod p_, if any. +/// +/// The other solution is _-x mod p_. If there is no solution, returns `None`. +/// +/// Reference: H. Cohen, _A course in computational algebraic number theory_, Algorithm 1.4.3 +/// +/// ## Implementation details +/// +/// To avoid multiplication overflows, internally the algorithm uses the `128`-bit arithmetic. +/// +/// Also see [`cipolla`]. +pub fn tonelli_shanks(a: i64, odd_prime: u64) -> Option { + let p: u128 = odd_prime as u128; + let e = (p - 1).trailing_zeros(); + let q = (p - 1) >> e; // p = 2^e * q, with q odd + + let a = if a < 0 { + a.rem_euclid(p as i64) as u128 + } else { + a as u128 + }; + + let power_mod_p = |b, e| fast_power(b as usize, e as usize, p as usize) as u128; + + // find generator: choose a random non-residue n mod p + let mut rng = rand::rng(); + let n = loop { + let n = rng.random_range(0..p); + if legendre_symbol(n as u64, p as u64) == -1 { + break n; + } + }; + let z = power_mod_p(n, q); + + // init + let mut y = z; + let mut r = e; + let mut x = power_mod_p(a, (q - 1) / 2) % p; + let mut b = (a * x * x) % p; + x = (a * x) % p; + + while b % p != 1 { + // find exponent + let m = (1..r) + .scan(b, |prev, m| { + *prev = (*prev * *prev) % p; + Some((m, *prev == 1)) + }) + .find_map(|(m, cond)| cond.then_some(m)); + let Some(m) = m else { + return None; // non-residue + }; + + // reduce exponent + let t = power_mod_p(y as u128, 2_u128.pow(r - m - 1)); + y = (t * t) % p; + r = m; + x = (x * t) % p; + b = (b * y) % p; + } + + Some(x as u64) +} + #[cfg(test)] mod tests { use super::*; + fn tonelli_shanks_residues(x: u64, odd_prime: u64) -> Option<(u64, u64)> { + let x = tonelli_shanks(x as i64, odd_prime)?; + let x2 = (-(x as i64)).rem_euclid(odd_prime as i64) as u64; + Some(if x < x2 { (x, x2) } else { (x2, x) }) + } + #[test] - fn small_numbers() { + fn cipolla_small_numbers() { assert_eq!(cipolla(1, 43, None), Some((1, 42))); assert_eq!(cipolla(2, 23, None), Some((5, 18))); assert_eq!(cipolla(17, 83, Some(42)), Some((10, 73))); } #[test] - fn random_numbers() { + fn tonelli_shanks_small_numbers() { + assert_eq!(tonelli_shanks_residues(1, 43).unwrap(), (1, 42)); + assert_eq!(tonelli_shanks_residues(2, 23).unwrap(), (5, 18)); + assert_eq!(tonelli_shanks_residues(17, 83).unwrap(), (10, 73)); + } + + #[test] + fn cipolla_random_numbers() { assert_eq!(cipolla(392203, 852167, None), Some((413252, 438915))); assert_eq!( cipolla(379606557, 425172197, None), @@ -141,8 +236,33 @@ mod tests { ); } + #[test] + fn tonelli_shanks_random_numbers() { + assert_eq!( + tonelli_shanks_residues(392203, 852167), + Some((413252, 438915)) + ); + assert_eq!( + tonelli_shanks_residues(379606557, 425172197), + Some((143417827, 281754370)) + ); + assert_eq!( + tonelli_shanks_residues(585251669, 892950901), + Some((192354555, 700596346)) + ); + assert_eq!( + tonelli_shanks_residues(404690348, 430183399), + Some((57227138, 372956261)) + ); + assert_eq!( + tonelli_shanks_residues(210205747, 625380647), + Some((76810367, 548570280)) + ); + } + #[test] fn no_answer() { assert_eq!(cipolla(650927, 852167, None), None); + assert_eq!(tonelli_shanks(650927, 852167), None); } } diff --git a/src/math/random.rs b/src/math/random.rs index 88e87866b06..de218035484 100644 --- a/src/math/random.rs +++ b/src/math/random.rs @@ -107,7 +107,7 @@ impl PCG32 { } } -impl<'a> Iterator for IterMut<'a> { +impl Iterator for IterMut<'_> { type Item = u32; fn next(&mut self) -> Option { Some(self.pcg.get_u32()) diff --git a/src/math/relu.rs b/src/math/relu.rs new file mode 100644 index 00000000000..dacc2fe8289 --- /dev/null +++ b/src/math/relu.rs @@ -0,0 +1,34 @@ +//Rust implementation of the ReLU (rectified linear unit) activation function. +//The formula for ReLU is quite simple really: (if x>0 -> x, else -> 0) +//More information on the concepts of ReLU can be found here: +//https://en.wikipedia.org/wiki/Rectifier_(neural_networks) + +//The function below takes a reference to a mutable Vector as an argument +//and returns the vector with 'ReLU' applied to all values. +//Of course, these functions can be changed by the developer so that the input vector isn't manipulated. +//This is simply an implemenation of the formula. + +pub fn relu(array: &mut Vec) -> &mut Vec { + //note that these calculations are assuming the Vector values consists of real numbers in radians + for value in &mut *array { + if value <= &mut 0. { + *value = 0.; + } + } + + array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_relu() { + let mut test: Vec = Vec::from([1.0, 0.5, -1.0, 0.0, 0.3]); + assert_eq!( + relu(&mut test), + &mut Vec::::from([1.0, 0.5, 0.0, 0.0, 0.3]) + ); + } +} diff --git a/src/math/sieve_of_eratosthenes.rs b/src/math/sieve_of_eratosthenes.rs index 079201b26a4..ed331845317 100644 --- a/src/math/sieve_of_eratosthenes.rs +++ b/src/math/sieve_of_eratosthenes.rs @@ -1,53 +1,120 @@ +/// Implements the Sieve of Eratosthenes algorithm to find all prime numbers up to a given limit. +/// +/// # Arguments +/// +/// * `num` - The upper limit up to which to find prime numbers (inclusive). +/// +/// # Returns +/// +/// A vector containing all prime numbers up to the specified limit. pub fn sieve_of_eratosthenes(num: usize) -> Vec { let mut result: Vec = Vec::new(); - if num == 0 { - return result; + if num >= 2 { + let mut sieve: Vec = vec![true; num + 1]; + + // 0 and 1 are not prime numbers + sieve[0] = false; + sieve[1] = false; + + let end: usize = (num as f64).sqrt() as usize; + + // Mark non-prime numbers in the sieve and collect primes up to `end` + update_sieve(&mut sieve, end, num, &mut result); + + // Collect remaining primes beyond `end` + result.extend(extract_remaining_primes(&sieve, end + 1)); } - let mut start: usize = 2; - let end: usize = (num as f64).sqrt() as usize; - let mut sieve: Vec = vec![true; num + 1]; + result +} - while start <= end { +/// Marks non-prime numbers in the sieve and collects prime numbers up to `end`. +/// +/// # Arguments +/// +/// * `sieve` - A mutable slice of booleans representing the sieve. +/// * `end` - The square root of the upper limit, used to optimize the algorithm. +/// * `num` - The upper limit up to which to mark non-prime numbers. +/// * `result` - A mutable vector to store the prime numbers. +fn update_sieve(sieve: &mut [bool], end: usize, num: usize, result: &mut Vec) { + for start in 2..=end { if sieve[start] { - result.push(start); - for i in (start * start..num + 1).step_by(start) { - if sieve[i] { - sieve[i] = false; - } + result.push(start); // Collect prime numbers up to `end` + for i in (start * start..=num).step_by(start) { + sieve[i] = false; } } - start += 1; - } - for (i, item) in sieve.iter().enumerate().take(num + 1).skip(end + 1) { - if *item { - result.push(i) - } } - result +} + +/// Extracts remaining prime numbers from the sieve beyond the given start index. +/// +/// # Arguments +/// +/// * `sieve` - A slice of booleans representing the sieve with non-prime numbers marked as false. +/// * `start` - The index to start checking for primes (inclusive). +/// +/// # Returns +/// +/// A vector containing all remaining prime numbers extracted from the sieve. +fn extract_remaining_primes(sieve: &[bool], start: usize) -> Vec { + sieve[start..] + .iter() + .enumerate() + .filter_map(|(i, &is_prime)| if is_prime { Some(start + i) } else { None }) + .collect() } #[cfg(test)] mod tests { use super::*; - #[test] - fn basic() { - assert_eq!(sieve_of_eratosthenes(0), vec![]); - assert_eq!(sieve_of_eratosthenes(11), vec![2, 3, 5, 7, 11]); - assert_eq!( - sieve_of_eratosthenes(25), - vec![2, 3, 5, 7, 11, 13, 17, 19, 23] - ); - assert_eq!( - sieve_of_eratosthenes(33), - vec![2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] - ); - assert_eq!( - sieve_of_eratosthenes(100), - vec![ - 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, - 83, 89, 97 - ] - ); + const PRIMES_UP_TO_997: [usize; 168] = [ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, + 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, + 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, + 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, + 401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, + 509, 521, 523, 541, 547, 557, 563, 569, 571, 577, 587, 593, 599, 601, 607, 613, 617, 619, + 631, 641, 643, 647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733, 739, 743, + 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823, 827, 829, 839, 853, 857, 859, 863, + 877, 881, 883, 887, 907, 911, 919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997, + ]; + + macro_rules! sieve_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let input: usize = $test_case; + let expected: Vec = PRIMES_UP_TO_997.iter().cloned().filter(|&x| x <= input).collect(); + assert_eq!(sieve_of_eratosthenes(input), expected); + } + )* + } + } + + sieve_tests! { + test_0: 0, + test_1: 1, + test_2: 2, + test_3: 3, + test_4: 4, + test_5: 5, + test_6: 6, + test_7: 7, + test_11: 11, + test_23: 23, + test_24: 24, + test_25: 25, + test_26: 26, + test_27: 27, + test_28: 28, + test_29: 29, + test_33: 33, + test_100: 100, + test_997: 997, + test_998: 998, + test_999: 999, + test_1000: 1000, } } diff --git a/src/math/sigmoid.rs b/src/math/sigmoid.rs new file mode 100644 index 00000000000..bee6c0c6cb7 --- /dev/null +++ b/src/math/sigmoid.rs @@ -0,0 +1,34 @@ +//Rust implementation of the Sigmoid activation function. +//The formula for Sigmoid: 1 / (1 + e^(-x)) +//More information on the concepts of Sigmoid can be found here: +//https://en.wikipedia.org/wiki/Sigmoid_function + +//The function below takes a reference to a mutable Vector as an argument +//and returns the vector with 'Sigmoid' applied to all values. +//Of course, these functions can be changed by the developer so that the input vector isn't manipulated. +//This is simply an implemenation of the formula. + +use std::f32::consts::E; + +pub fn sigmoid(array: &mut Vec) -> &mut Vec { + //note that these calculations are assuming the Vector values consists of real numbers in radians + for value in &mut *array { + *value = 1. / (1. + E.powf(-1. * *value)); + } + + array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sigmoid() { + let mut test = Vec::from([1.0, 0.5, -1.0, 0.0, 0.3]); + assert_eq!( + sigmoid(&mut test), + &mut Vec::::from([0.7310586, 0.62245935, 0.26894143, 0.5, 0.5744425,]) + ); + } +} diff --git a/src/math/signum.rs b/src/math/signum.rs new file mode 100644 index 00000000000..83b7e7b5538 --- /dev/null +++ b/src/math/signum.rs @@ -0,0 +1,36 @@ +/// Signum function is a mathematical function that extracts +/// the sign of a real number. It is also known as the sign function, +/// and it is an odd piecewise function. +/// If a number is negative, i.e. it is less than zero, then sgn(x) = -1 +/// If a number is zero, then sgn(0) = 0 +/// If a number is positive, i.e. it is greater than zero, then sgn(x) = 1 + +pub fn signum(number: f64) -> i8 { + if number == 0.0 { + return 0; + } else if number > 0.0 { + return 1; + } + + -1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn positive_integer() { + assert_eq!(signum(15.0), 1); + } + + #[test] + fn negative_integer() { + assert_eq!(signum(-30.0), -1); + } + + #[test] + fn zero() { + assert_eq!(signum(0.0), 0); + } +} diff --git a/src/math/simpson_integration.rs b/src/math/simpson_integration.rs deleted file mode 100644 index a88b00fd9ab..00000000000 --- a/src/math/simpson_integration.rs +++ /dev/null @@ -1,54 +0,0 @@ -// This gives a better approximation than naive approach -// See https://en.wikipedia.org/wiki/Simpson%27s_rule -pub fn simpson_integration f64>( - start: f64, - end: f64, - steps: u64, - function: F, -) -> f64 { - let mut result = function(start) + function(end); - let step = (end - start) / steps as f64; - for i in 1..steps { - let x = start + step * i as f64; - match i % 2 { - 0 => result += function(x) * 2.0, - 1 => result += function(x) * 4.0, - _ => unreachable!(), - } - } - result *= step / 3.0; - result -} - -#[cfg(test)] -mod tests { - - use super::*; - const EPSILON: f64 = 1e-9; - - fn almost_equal(a: f64, b: f64, eps: f64) -> bool { - (a - b).abs() < eps - } - - #[test] - fn parabola_curve_length() { - // Calculate the length of the curve f(x) = x^2 for -5 <= x <= 5 - // We should integrate sqrt(1 + (f'(x))^2) - let function = |x: f64| -> f64 { (1.0 + 4.0 * x * x).sqrt() }; - let result = simpson_integration(-5.0, 5.0, 1_000, function); - let integrated = |x: f64| -> f64 { (x * function(x) / 2.0) + ((2.0 * x).asinh() / 4.0) }; - let expected = integrated(5.0) - integrated(-5.0); - assert!(almost_equal(result, expected, EPSILON)); - } - - #[test] - fn area_under_cosine() { - use std::f64::consts::PI; - // Calculate area under f(x) = cos(x) + 5 for -pi <= x <= pi - // cosine should cancel out and the answer should be 2pi * 5 - let function = |x: f64| -> f64 { x.cos() + 5.0 }; - let result = simpson_integration(-PI, PI, 1_000, function); - let expected = 2.0 * PI * 5.0; - assert!(almost_equal(result, expected, EPSILON)); - } -} diff --git a/src/math/simpsons_integration.rs b/src/math/simpsons_integration.rs new file mode 100644 index 00000000000..57b173a136b --- /dev/null +++ b/src/math/simpsons_integration.rs @@ -0,0 +1,127 @@ +pub fn simpsons_integration(f: F, a: f64, b: f64, n: usize) -> f64 +where + F: Fn(f64) -> f64, +{ + let h = (b - a) / n as f64; + (0..n) + .map(|i| { + let x0 = a + i as f64 * h; + let x1 = x0 + h / 2.0; + let x2 = x0 + h; + (h / 6.0) * (f(x0) + 4.0 * f(x1) + f(x2)) + }) + .sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_simpsons_integration() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 1.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_error() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + let error = (1.0 / 3.0 - result).abs(); + assert!(error < 1e-6); + } + + #[test] + fn test_convergence() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result1 = simpsons_integration(f, a, b, n); + let result2 = simpsons_integration(f, a, b, 2 * n); + let result3 = simpsons_integration(f, a, b, 4 * n); + let result4 = simpsons_integration(f, a, b, 8 * n); + assert!((result1 - result2).abs() < 1e-6); + assert!((result2 - result3).abs() < 1e-6); + assert!((result3 - result4).abs() < 1e-6); + } + + #[test] + fn test_negative() { + let f = |x: f64| -x.powi(2); + let a = 0.0; + let b = 1.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result + 1.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_lower_bound() { + let f = |x: f64| x.powi(2); + let a = 1.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 7.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_upper_bound() { + let f = |x: f64| x.powi(2); + let a = 0.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 8.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_lower_and_upper_bound() { + let f = |x: f64| x.powi(2); + let a = 1.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result - 7.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn test_non_zero_lower_and_upper_bound_negative() { + let f = |x: f64| -x.powi(2); + let a = 1.0; + let b = 2.0; + let n = 100; + let result = simpsons_integration(f, a, b, n); + assert!((result + 7.0 / 3.0).abs() < 1e-6); + } + + #[test] + fn parabola_curve_length() { + // Calculate the length of the curve f(x) = x^2 for -5 <= x <= 5 + // We should integrate sqrt(1 + (f'(x))^2) + let function = |x: f64| -> f64 { (1.0 + 4.0 * x * x).sqrt() }; + let result = simpsons_integration(function, -5.0, 5.0, 1_000); + let integrated = |x: f64| -> f64 { (x * function(x) / 2.0) + ((2.0 * x).asinh() / 4.0) }; + let expected = integrated(5.0) - integrated(-5.0); + assert!((result - expected).abs() < 1e-9); + } + + #[test] + fn area_under_cosine() { + use std::f64::consts::PI; + // Calculate area under f(x) = cos(x) + 5 for -pi <= x <= pi + // cosine should cancel out and the answer should be 2pi * 5 + let function = |x: f64| -> f64 { x.cos() + 5.0 }; + let result = simpsons_integration(function, -PI, PI, 1_000); + let expected = 2.0 * PI * 5.0; + assert!((result - expected).abs() < 1e-9); + } +} diff --git a/src/math/softmax.rs b/src/math/softmax.rs new file mode 100644 index 00000000000..582bf452ef5 --- /dev/null +++ b/src/math/softmax.rs @@ -0,0 +1,56 @@ +//! # Softmax Function +//! +//! The `softmax` function computes the softmax values of a given array of f32 numbers. +//! +//! The softmax operation is often used in machine learning for converting a vector of real numbers into a +//! probability distribution. It exponentiates each element in the input array, and then normalizes the +//! results so that they sum to 1. +//! +//! ## Formula +//! +//! For a given input array `x`, the softmax function computes the output `y` as follows: +//! +//! `y_i = e^(x_i) / sum(e^(x_j) for all j)` +//! +//! ## Softmax Function Implementation +//! +//! This implementation uses the `std::f32::consts::E` constant for the base of the exponential function. and +//! f32 vectors to compute the values. The function creates a new vector and not altering the input vector. +//! +use std::f32::consts::E; + +pub fn softmax(array: Vec) -> Vec { + let mut softmax_array = array; + + for value in &mut softmax_array { + *value = E.powf(*value); + } + + let sum: f32 = softmax_array.iter().sum(); + + for value in &mut softmax_array { + *value /= sum; + } + + softmax_array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_softmax() { + let test = vec![9.0, 0.5, -3.0, 0.0, 3.0]; + assert_eq!( + softmax(test), + vec![ + 0.9971961, + 0.00020289792, + 6.126987e-6, + 0.00012306382, + 0.0024718025 + ] + ); + } +} diff --git a/src/math/sprague_grundy_theorem.rs b/src/math/sprague_grundy_theorem.rs new file mode 100644 index 00000000000..4006d047ba6 --- /dev/null +++ b/src/math/sprague_grundy_theorem.rs @@ -0,0 +1,70 @@ +/** + * Sprague Grundy Theorem for combinatorial games like Nim + * + * The Sprague Grundy Theorem is a fundamental concept in combinatorial game theory, commonly used to analyze + * games like Nim. It calculates the Grundy number (also known as the nimber) for a position in a game. + * The Grundy number represents the game's position, and it helps determine the winning strategy. + * + * The Grundy number of a terminal state is 0; otherwise, it is recursively defined as the minimum + * excludant (mex) of the Grundy values of possible next states. + * + * For more details on Sprague Grundy Theorem, you can visit:(https://en.wikipedia.org/wiki/Sprague%E2%80%93Grundy_theorem) + * + * Author : [Gyandeep](https://github.com/Gyan172004) + */ + +pub fn calculate_grundy_number( + position: i64, + grundy_numbers: &mut [i64], + possible_moves: &[i64], +) -> i64 { + // Check if we've already calculated the Grundy number for this position. + if grundy_numbers[position as usize] != -1 { + return grundy_numbers[position as usize]; + } + + // Base case: terminal state + if position == 0 { + grundy_numbers[0] = 0; + return 0; + } + + // Calculate Grundy values for possible next states. + let mut next_state_grundy_values: Vec = vec![]; + for move_size in possible_moves.iter() { + if position - move_size >= 0 { + next_state_grundy_values.push(calculate_grundy_number( + position - move_size, + grundy_numbers, + possible_moves, + )); + } + } + + // Sort the Grundy values and find the minimum excludant. + next_state_grundy_values.sort_unstable(); + let mut mex: i64 = 0; + for grundy_value in next_state_grundy_values.iter() { + if *grundy_value != mex { + break; + } + mex += 1; + } + + // Store the calculated Grundy number and return it. + grundy_numbers[position as usize] = mex; + mex +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn calculate_grundy_number_test() { + let mut grundy_numbers: Vec = vec![-1; 7]; + let possible_moves: Vec = vec![1, 4]; + calculate_grundy_number(6, &mut grundy_numbers, &possible_moves); + assert_eq!(grundy_numbers, [0, 1, 0, 1, 2, 0, 1]); + } +} diff --git a/src/math/square_pyramidal_numbers.rs b/src/math/square_pyramidal_numbers.rs new file mode 100644 index 00000000000..2a6659e09c8 --- /dev/null +++ b/src/math/square_pyramidal_numbers.rs @@ -0,0 +1,20 @@ +// https://en.wikipedia.org/wiki/Square_pyramidal_number +// 1² + 2² + ... = ... (total) + +pub fn square_pyramidal_number(n: u64) -> u64 { + n * (n + 1) * (2 * n + 1) / 6 +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test0() { + assert_eq!(0, square_pyramidal_number(0)); + assert_eq!(1, square_pyramidal_number(1)); + assert_eq!(5, square_pyramidal_number(2)); + assert_eq!(14, square_pyramidal_number(3)); + } +} diff --git a/src/math/square_root.rs b/src/math/square_root.rs index 908f8d55987..4f858ad90b5 100644 --- a/src/math/square_root.rs +++ b/src/math/square_root.rs @@ -1,5 +1,5 @@ /// squre_root returns the square root -/// of a f64 number using Newtons method +/// of a f64 number using Newton's method pub fn square_root(num: f64) -> f64 { if num < 0.0_f64 { return f64::NAN; @@ -14,12 +14,43 @@ pub fn square_root(num: f64) -> f64 { root } +// fast_inv_sqrt returns an approximation of the inverse square root +// This algorithm was first used in Quake and has been reimplemented in a few other languages +// This crate implements it more thoroughly: https://docs.rs/quake-inverse-sqrt/latest/quake_inverse_sqrt/ +pub fn fast_inv_sqrt(num: f32) -> f32 { + // If you are confident in your input this can be removed for speed + if num < 0.0f32 { + return f32::NAN; + } + + let i = num.to_bits(); + let i = 0x5f3759df - (i >> 1); + let y = f32::from_bits(i); + + println!("num: {:?}, out: {:?}", num, y * (1.5 - 0.5 * num * y * y)); + // First iteration of Newton's approximation + y * (1.5 - 0.5 * num * y * y) + // The above can be repeated for more precision +} + #[cfg(test)] mod tests { use super::*; #[test] - fn test() { + fn test_fast_inv_sqrt() { + // Negatives don't have square roots: + assert!(fast_inv_sqrt(-1.0f32).is_nan()); + + // Test a few cases, expect less than 1% error: + let test_pairs = [(4.0, 0.5), (16.0, 0.25), (25.0, 0.2)]; + for pair in test_pairs { + assert!((fast_inv_sqrt(pair.0) - pair.1).abs() <= (0.01 * pair.0)); + } + } + + #[test] + fn test_sqare_root() { assert!((square_root(4.0_f64) - 2.0_f64).abs() <= 1e-10_f64); assert!(square_root(-4.0_f64).is_nan()); } diff --git a/src/math/sum_of_digits.rs b/src/math/sum_of_digits.rs new file mode 100644 index 00000000000..1da42ff20d9 --- /dev/null +++ b/src/math/sum_of_digits.rs @@ -0,0 +1,116 @@ +/// Iteratively sums the digits of a signed integer +/// +/// ## Arguments +/// +/// * `num` - The number to sum the digits of +/// +/// ## Examples +/// +/// ``` +/// use the_algorithms_rust::math::sum_digits_iterative; +/// +/// assert_eq!(10, sum_digits_iterative(1234)); +/// assert_eq!(12, sum_digits_iterative(-246)); +/// ``` +pub fn sum_digits_iterative(num: i32) -> u32 { + // convert to unsigned integer + let mut num = num.unsigned_abs(); + // initialize sum + let mut result: u32 = 0; + + // iterate through digits + while num > 0 { + // extract next digit and add to sum + result += num % 10; + num /= 10; // chop off last digit + } + result +} + +/// Recursively sums the digits of a signed integer +/// +/// ## Arguments +/// +/// * `num` - The number to sum the digits of +/// +/// ## Examples +/// +/// ``` +/// use the_algorithms_rust::math::sum_digits_recursive; +/// +/// assert_eq!(10, sum_digits_recursive(1234)); +/// assert_eq!(12, sum_digits_recursive(-246)); +/// ``` +pub fn sum_digits_recursive(num: i32) -> u32 { + // convert to unsigned integer + let num = num.unsigned_abs(); + // base case + if num < 10 { + return num; + } + // recursive case: add last digit to sum of remaining digits + num % 10 + sum_digits_recursive((num / 10) as i32) +} + +#[cfg(test)] +mod tests { + mod iterative { + // import relevant sum_digits function + use super::super::sum_digits_iterative as sum_digits; + + #[test] + fn zero() { + assert_eq!(0, sum_digits(0)); + } + #[test] + fn positive_number() { + assert_eq!(1, sum_digits(1)); + assert_eq!(10, sum_digits(1234)); + assert_eq!(14, sum_digits(42161)); + assert_eq!(6, sum_digits(500010)); + } + #[test] + fn negative_number() { + assert_eq!(1, sum_digits(-1)); + assert_eq!(12, sum_digits(-246)); + assert_eq!(2, sum_digits(-11)); + assert_eq!(14, sum_digits(-42161)); + assert_eq!(6, sum_digits(-500010)); + } + #[test] + fn trailing_zeros() { + assert_eq!(1, sum_digits(1000000000)); + assert_eq!(3, sum_digits(300)); + } + } + + mod recursive { + // import relevant sum_digits function + use super::super::sum_digits_recursive as sum_digits; + + #[test] + fn zero() { + assert_eq!(0, sum_digits(0)); + } + #[test] + fn positive_number() { + assert_eq!(1, sum_digits(1)); + assert_eq!(10, sum_digits(1234)); + assert_eq!(14, sum_digits(42161)); + assert_eq!(6, sum_digits(500010)); + } + #[test] + fn negative_number() { + assert_eq!(1, sum_digits(-1)); + assert_eq!(12, sum_digits(-246)); + assert_eq!(2, sum_digits(-11)); + assert_eq!(14, sum_digits(-42161)); + assert_eq!(6, sum_digits(-500010)); + } + #[test] + fn trailing_zeros() { + assert_eq!(1, sum_digits(1000000000)); + assert_eq!(3, sum_digits(300)); + } + } +} diff --git a/src/math/sum_of_geometric_progression.rs b/src/math/sum_of_geometric_progression.rs new file mode 100644 index 00000000000..30401a3c2d2 --- /dev/null +++ b/src/math/sum_of_geometric_progression.rs @@ -0,0 +1,37 @@ +// Author : cyrixninja +// Find the Sum of Geometric Progression +// Wikipedia: https://en.wikipedia.org/wiki/Geometric_progression + +pub fn sum_of_geometric_progression(first_term: f64, common_ratio: f64, num_of_terms: i32) -> f64 { + if common_ratio == 1.0 { + // Formula for sum if the common ratio is 1 + return (num_of_terms as f64) * first_term; + } + + // Formula for finding the sum of n terms of a Geometric Progression + (first_term / (1.0 - common_ratio)) * (1.0 - common_ratio.powi(num_of_terms)) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_sum_of_geometric_progression { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (first_term, common_ratio, num_of_terms, expected) = $inputs; + assert_eq!(sum_of_geometric_progression(first_term, common_ratio, num_of_terms), expected); + } + )* + } + } + + test_sum_of_geometric_progression! { + regular_input_0: (1.0, 2.0, 10, 1023.0), + regular_input_1: (1.0, 10.0, 5, 11111.0), + regular_input_2: (9.0, 2.5, 5, 579.9375), + common_ratio_one: (10.0, 1.0, 3, 30.0), + } +} diff --git a/src/math/sum_of_harmonic_series.rs b/src/math/sum_of_harmonic_series.rs new file mode 100644 index 00000000000..553ba23e9c6 --- /dev/null +++ b/src/math/sum_of_harmonic_series.rs @@ -0,0 +1,40 @@ +// Author : cyrixninja +// Sum of Harmonic Series : Find the sum of n terms in an harmonic progression. The calculation starts with the +// first_term and loops adding the common difference of Arithmetic Progression by which +// the given Harmonic Progression is linked. +// Wikipedia Reference : https://en.wikipedia.org/wiki/Interquartile_range +// Other References : https://the-algorithms.com/algorithm/sum-of-harmonic-series?lang=python + +pub fn sum_of_harmonic_progression( + first_term: f64, + common_difference: f64, + number_of_terms: i32, +) -> f64 { + let mut arithmetic_progression = vec![1.0 / first_term]; + let mut current_term = 1.0 / first_term; + + for _ in 0..(number_of_terms - 1) { + current_term += common_difference; + arithmetic_progression.push(current_term); + } + + let harmonic_series: Vec = arithmetic_progression + .into_iter() + .map(|step| 1.0 / step) + .collect(); + harmonic_series.iter().sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sum_of_harmonic_progression() { + assert_eq!(sum_of_harmonic_progression(1.0 / 2.0, 2.0, 2), 0.75); + assert_eq!( + sum_of_harmonic_progression(1.0 / 5.0, 5.0, 5), + 0.45666666666666667 + ); + } +} diff --git a/src/math/sylvester_sequence.rs b/src/math/sylvester_sequence.rs new file mode 100644 index 00000000000..7ce7e4534d0 --- /dev/null +++ b/src/math/sylvester_sequence.rs @@ -0,0 +1,33 @@ +// Author : cyrixninja +// Sylvester Series : Calculates the nth number in Sylvester's sequence. +// Wikipedia Reference : https://en.wikipedia.org/wiki/Sylvester%27s_sequence +// Other References : https://the-algorithms.com/algorithm/sylvester-sequence?lang=python + +pub fn sylvester(number: i32) -> i128 { + assert!(number > 0, "The input value of [n={number}] has to be > 0"); + + if number == 1 { + 2 + } else { + let num = sylvester(number - 1); + let lower = num - 1; + let upper = num; + lower * upper + 1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sylvester() { + assert_eq!(sylvester(8), 113423713055421844361000443_i128); + } + + #[test] + #[should_panic(expected = "The input value of [n=-1] has to be > 0")] + fn test_sylvester_negative() { + sylvester(-1); + } +} diff --git a/src/math/tanh.rs b/src/math/tanh.rs new file mode 100644 index 00000000000..e7a9a785366 --- /dev/null +++ b/src/math/tanh.rs @@ -0,0 +1,34 @@ +//Rust implementation of the Tanh (hyperbolic tangent) activation function. +//The formula for Tanh: (e^x - e^(-x))/(e^x + e^(-x)) OR (2/(1+e^(-2x))-1 +//More information on the concepts of Sigmoid can be found here: +//https://en.wikipedia.org/wiki/Hyperbolic_functions + +//The function below takes a reference to a mutable Vector as an argument +//and returns the vector with 'Tanh' applied to all values. +//Of course, these functions can be changed by the developer so that the input vector isn't manipulated. +//This is simply an implemenation of the formula. + +use std::f32::consts::E; + +pub fn tanh(array: &mut Vec) -> &mut Vec { + //note that these calculations are assuming the Vector values consists of real numbers in radians + for value in &mut *array { + *value = (2. / (1. + E.powf(-2. * *value))) - 1.; + } + + array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tanh() { + let mut test = Vec::from([1.0, 0.5, -1.0, 0.0, 0.3]); + assert_eq!( + tanh(&mut test), + &mut Vec::::from([0.76159406, 0.4621172, -0.7615941, 0.0, 0.29131258,]) + ); + } +} diff --git a/src/math/trapezoidal_integration.rs b/src/math/trapezoidal_integration.rs new file mode 100644 index 00000000000..f9cda7088c5 --- /dev/null +++ b/src/math/trapezoidal_integration.rs @@ -0,0 +1,42 @@ +pub fn trapezoidal_integral(a: f64, b: f64, f: F, precision: u32) -> f64 +where + F: Fn(f64) -> f64, +{ + let delta = (b - a) / precision as f64; + + (0..precision) + .map(|trapezoid| { + let left_side = a + (delta * trapezoid as f64); + let right_side = left_side + delta; + + 0.5 * (f(left_side) + f(right_side)) * delta + }) + .sum() +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_trapezoidal_integral { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (a, b, f, prec, expected, eps) = $inputs; + let actual = trapezoidal_integral(a, b, f, prec); + assert!((actual - expected).abs() < eps); + } + )* + } + } + + test_trapezoidal_integral! { + basic_0: (0.0, 1.0, |x: f64| x.powi(2), 1000, 1.0/3.0, 0.0001), + basic_0_higher_prec: (0.0, 1.0, |x: f64| x.powi(2), 10000, 1.0/3.0, 0.00001), + basic_1: (-1.0, 1.0, |x: f64| x.powi(2), 10000, 2.0/3.0, 0.00001), + basic_1_higher_prec: (-1.0, 1.0, |x: f64| x.powi(2), 100000, 2.0/3.0, 0.000001), + flipped_limits: (1.0, 0.0, |x: f64| x.powi(2), 10000, -1.0/3.0, 0.00001), + empty_range: (0.5, 0.5, |x: f64| x.powi(2), 100, 0.0, 0.0000001), + } +} diff --git a/src/math/trial_division.rs b/src/math/trial_division.rs index f38d990cea6..882e6f72262 100644 --- a/src/math/trial_division.rs +++ b/src/math/trial_division.rs @@ -8,8 +8,15 @@ fn double_to_int(amount: f64) -> i128 { } pub fn trial_division(mut num: i128) -> Vec { + if num < 0 { + return trial_division(-num); + } let mut result: Vec = vec![]; + if num == 0 { + return result; + } + while num % 2 == 0 { result.push(2); num /= 2; @@ -39,7 +46,10 @@ mod tests { #[test] fn basic() { + assert_eq!(trial_division(0), vec![]); + assert_eq!(trial_division(1), vec![]); assert_eq!(trial_division(9), vec!(3, 3)); + assert_eq!(trial_division(-9), vec!(3, 3)); assert_eq!(trial_division(10), vec!(2, 5)); assert_eq!(trial_division(11), vec!(11)); assert_eq!(trial_division(33), vec!(3, 11)); diff --git a/src/math/trig_functions.rs b/src/math/trig_functions.rs new file mode 100644 index 00000000000..e18b5154029 --- /dev/null +++ b/src/math/trig_functions.rs @@ -0,0 +1,263 @@ +/// Function that contains the similarities of the sine and cosine implementations +/// +/// Both of them are calculated using their MacLaurin Series +/// +/// Because there is just a '+1' that differs in their formula, this function has been +/// created for not repeating +fn template>(x: T, tol: f64, kind: i32) -> f64 { + use std::f64::consts::PI; + const PERIOD: f64 = 2.0 * PI; + /* Sometimes, this function is called for a big 'n'(when tol is very small) */ + fn factorial(n: i128) -> i128 { + (1..=n).product() + } + + /* Function to round up to the 'decimal'th decimal of the number 'x' */ + fn round_up_to_decimal(x: f64, decimal: i32) -> f64 { + let multiplier = 10f64.powi(decimal); + (x * multiplier).round() / multiplier + } + + let mut value: f64 = x.into(); //<-- This is the line for which the trait 'Into' is required + + /* Check for invalid arguments */ + if !value.is_finite() || value.is_nan() { + eprintln!("This function does not accept invalid arguments."); + return f64::NAN; + } + + /* + The argument to sine could be bigger than the sine's PERIOD + To prevent overflowing, strip the value off relative to the PERIOD + */ + while value >= PERIOD { + value -= PERIOD; + } + /* For cases when the value is smaller than the -PERIOD (e.g. sin(-3π) <=> sin(-π)) */ + while value <= -PERIOD { + value += PERIOD; + } + + let mut rez = 0f64; + let mut prev_rez = 1f64; + let mut step: i32 = 0; + /* + This while instruction is the MacLaurin Series for sine / cosine + sin(x) = Σ (-1)^n * x^2n+1 / (2n+1)!, for n >= 0 and x a Real number + cos(x) = Σ (-1)^n * x^2n / (2n)!, for n >= 0 and x a Real number + + '+1' in sine's formula is replaced with 'kind', which values are: + -> kind = 0, for cosine + -> kind = 1, for sine + */ + while (prev_rez - rez).abs() > tol { + prev_rez = rez; + rez += (-1f64).powi(step) * value.powi(2 * step + kind) + / factorial((2 * step + kind) as i128) as f64; + step += 1; + } + + /* Round up to the 6th decimal */ + round_up_to_decimal(rez, 6) +} + +/// Returns the value of sin(x), approximated with the given tolerance +/// +/// This function supposes the argument is in radians +/// +/// ### Example +/// +/// sin(1) == sin(1 rad) == sin(π/180) +pub fn sine>(x: T, tol: f64) -> f64 { + template(x, tol, 1) +} + +/// Returns the value of cos, approximated with the given tolerance, for +/// an angle 'x' in radians +pub fn cosine>(x: T, tol: f64) -> f64 { + template(x, tol, 0) +} + +/// Cosine of 'x' in degrees, with the given tolerance +pub fn cosine_no_radian_arg>(x: T, tol: f64) -> f64 { + use std::f64::consts::PI; + let val: f64 = x.into(); + cosine(val * PI / 180., tol) +} + +/// Sine function for non radian angle +/// +/// Interprets the argument in degrees, not in radians +/// +/// ### Example +/// +/// sin(1o) != \[ sin(1 rad) == sin(π/180) \] +pub fn sine_no_radian_arg>(x: T, tol: f64) -> f64 { + use std::f64::consts::PI; + let val: f64 = x.into(); + sine(val * PI / 180f64, tol) +} + +/// Tangent of angle 'x' in radians, calculated with the given tolerance +pub fn tan + Copy>(x: T, tol: f64) -> f64 { + let cos_val = cosine(x, tol); + + /* Cover special cases for division */ + if cos_val != 0f64 { + let sin_val = sine(x, tol); + sin_val / cos_val + } else { + f64::NAN + } +} + +/// Cotangent of angle 'x' in radians, calculated with the given tolerance +pub fn cotan + Copy>(x: T, tol: f64) -> f64 { + let sin_val = sine(x, tol); + + /* Cover special cases for division */ + if sin_val != 0f64 { + let cos_val = cosine(x, tol); + cos_val / sin_val + } else { + f64::NAN + } +} + +/// Tangent of 'x' in degrees, approximated with the given tolerance +pub fn tan_no_radian_arg + Copy>(x: T, tol: f64) -> f64 { + let angle: f64 = x.into(); + + use std::f64::consts::PI; + tan(angle * PI / 180., tol) +} + +/// Cotangent of 'x' in degrees, approximated with the given tolerance +pub fn cotan_no_radian_arg + Copy>(x: T, tol: f64) -> f64 { + let angle: f64 = x.into(); + + use std::f64::consts::PI; + cotan(angle * PI / 180., tol) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::f64::consts::PI; + + enum TrigFuncType { + Sine, + Cosine, + Tan, + Cotan, + } + + const TOL: f64 = 1e-10; + + impl TrigFuncType { + fn verify + Copy>(&self, angle: T, expected_result: f64, is_radian: bool) { + let value = match self { + TrigFuncType::Sine => { + if is_radian { + sine(angle, TOL) + } else { + sine_no_radian_arg(angle, TOL) + } + } + TrigFuncType::Cosine => { + if is_radian { + cosine(angle, TOL) + } else { + cosine_no_radian_arg(angle, TOL) + } + } + TrigFuncType::Tan => { + if is_radian { + tan(angle, TOL) + } else { + tan_no_radian_arg(angle, TOL) + } + } + TrigFuncType::Cotan => { + if is_radian { + cotan(angle, TOL) + } else { + cotan_no_radian_arg(angle, TOL) + } + } + }; + + assert_eq!(format!("{value:.5}"), format!("{:.5}", expected_result)); + } + } + + #[test] + fn test_sine() { + let sine_id = TrigFuncType::Sine; + sine_id.verify(0.0, 0.0, true); + sine_id.verify(-PI, 0.0, true); + sine_id.verify(-PI / 2.0, -1.0, true); + sine_id.verify(0.5, 0.4794255386, true); + /* Same tests, but angle is now in degrees */ + sine_id.verify(0, 0.0, false); + sine_id.verify(-180, 0.0, false); + sine_id.verify(-180 / 2, -1.0, false); + sine_id.verify(0.5, 0.00872654, false); + } + + #[test] + fn test_sine_bad_arg() { + assert!(sine(f64::NEG_INFINITY, 1e-1).is_nan()); + assert!(sine_no_radian_arg(f64::NAN, 1e-1).is_nan()); + } + + #[test] + fn test_cosine_bad_arg() { + assert!(cosine(f64::INFINITY, 1e-1).is_nan()); + assert!(cosine_no_radian_arg(f64::NAN, 1e-1).is_nan()); + } + + #[test] + fn test_cosine() { + let cosine_id = TrigFuncType::Cosine; + cosine_id.verify(0, 1., true); + cosine_id.verify(0, 1., false); + cosine_id.verify(45, 1. / f64::sqrt(2.), false); + cosine_id.verify(PI / 4., 1. / f64::sqrt(2.), true); + cosine_id.verify(360, 1., false); + cosine_id.verify(2. * PI, 1., true); + cosine_id.verify(15. * PI / 2., 0.0, true); + cosine_id.verify(-855, -1. / f64::sqrt(2.), false); + } + + #[test] + fn test_tan_bad_arg() { + assert!(tan(PI / 2., TOL).is_nan()); + assert!(tan(3. * PI / 2., TOL).is_nan()); + } + + #[test] + fn test_tan() { + let tan_id = TrigFuncType::Tan; + tan_id.verify(PI / 4., 1f64, true); + tan_id.verify(45, 1f64, false); + tan_id.verify(PI, 0f64, true); + tan_id.verify(180 + 45, 1f64, false); + tan_id.verify(60 - 2 * 180, 1.7320508075, false); + tan_id.verify(30 + 180 - 180, 0.57735026919, false); + } + + #[test] + fn test_cotan_bad_arg() { + assert!(cotan(tan(PI / 2., TOL), TOL).is_nan()); + assert!(!cotan(0, TOL).is_finite()); + } + + #[test] + fn test_cotan() { + let cotan_id = TrigFuncType::Cotan; + cotan_id.verify(PI / 4., 1f64, true); + cotan_id.verify(90 + 10 * 180, 0f64, false); + cotan_id.verify(30 - 5 * 180, f64::sqrt(3.), false); + } +} diff --git a/src/math/vector_cross_product.rs b/src/math/vector_cross_product.rs new file mode 100644 index 00000000000..582470d2bc7 --- /dev/null +++ b/src/math/vector_cross_product.rs @@ -0,0 +1,98 @@ +/// Cross Product and Magnitude Calculation +/// +/// This program defines functions to calculate the cross product of two 3D vectors +/// and the magnitude of a vector from its direction ratios. The main purpose is +/// to demonstrate the mathematical concepts and provide test cases for the functions. +/// +/// Time Complexity: +/// - Calculating the cross product and magnitude of a vector each takes O(1) time +/// since we are working with fixed-size arrays and performing a fixed number of +/// mathematical operations. + +/// Function to calculate the cross product of two vectors +pub fn cross_product(vec1: [f64; 3], vec2: [f64; 3]) -> [f64; 3] { + let x = vec1[1] * vec2[2] - vec1[2] * vec2[1]; + let y = -(vec1[0] * vec2[2] - vec1[2] * vec2[0]); + let z = vec1[0] * vec2[1] - vec1[1] * vec2[0]; + [x, y, z] +} + +/// Function to calculate the magnitude of a vector +pub fn vector_magnitude(vec: [f64; 3]) -> f64 { + (vec[0].powi(2) + vec[1].powi(2) + vec[2].powi(2)).sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cross_product_and_magnitude_1() { + // Test case with non-trivial vectors + let vec1 = [1.0, 2.0, 3.0]; + let vec2 = [4.0, 5.0, 6.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results with a tolerance for floating-point comparisons + assert_eq!(cross_product, [-3.0, 6.0, -3.0]); + assert!((magnitude - 7.34847).abs() < 1e-5); + } + + #[test] + fn test_cross_product_and_magnitude_2() { + // Test case with orthogonal vectors + let vec1 = [1.0, 0.0, 0.0]; + let vec2 = [0.0, 1.0, 0.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 1.0]); + assert_eq!(magnitude, 1.0); + } + + #[test] + fn test_cross_product_and_magnitude_3() { + // Test case with vectors along the axes + let vec1 = [2.0, 0.0, 0.0]; + let vec2 = [0.0, 3.0, 0.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 6.0]); + assert_eq!(magnitude, 6.0); + } + + #[test] + fn test_cross_product_and_magnitude_4() { + // Test case with parallel vectors + let vec1 = [1.0, 2.0, 3.0]; + let vec2 = [2.0, 4.0, 6.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 0.0]); + assert_eq!(magnitude, 0.0); + } + + #[test] + fn test_cross_product_and_magnitude_5() { + // Test case with zero vectors + let vec1 = [0.0, 0.0, 0.0]; + let vec2 = [0.0, 0.0, 0.0]; + + let cross_product = cross_product(vec1, vec2); + let magnitude = vector_magnitude(cross_product); + + // Check the expected results + assert_eq!(cross_product, [0.0, 0.0, 0.0]); + assert_eq!(magnitude, 0.0); + } +} diff --git a/src/math/zellers_congruence_algorithm.rs b/src/math/zellers_congruence_algorithm.rs index b6bdc2f7f98..43bf49e732f 100644 --- a/src/math/zellers_congruence_algorithm.rs +++ b/src/math/zellers_congruence_algorithm.rs @@ -2,12 +2,11 @@ pub fn zellers_congruence_algorithm(date: i32, month: i32, year: i32, as_string: bool) -> String { let q = date; - let mut m = month; - let mut y = year; - if month < 3 { - m = month + 12; - y = year - 1; - } + let (m, y) = if month < 3 { + (month + 12, year - 1) + } else { + (month, year) + }; let day: i32 = (q + (26 * (m + 1) / 10) + (y % 100) + ((y % 100) / 4) + ((y / 100) / 4) + (5 * (y / 100))) % 7; diff --git a/src/navigation/bearing.rs b/src/navigation/bearing.rs new file mode 100644 index 00000000000..6efec578b26 --- /dev/null +++ b/src/navigation/bearing.rs @@ -0,0 +1,40 @@ +use std::f64::consts::PI; + +pub fn bearing(lat1: f64, lng1: f64, lat2: f64, lng2: f64) -> f64 { + let lat1 = lat1 * PI / 180.0; + let lng1 = lng1 * PI / 180.0; + + let lat2 = lat2 * PI / 180.0; + let lng2 = lng2 * PI / 180.0; + + let delta_longitude = lng2 - lng1; + + let y = delta_longitude.sin() * lat2.cos(); + let x = lat1.cos() * lat2.sin() - lat1.sin() * lat2.cos() * delta_longitude.cos(); + + let mut brng = y.atan2(x); + brng = brng.to_degrees(); + + (brng + 360.0) % 360.0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn testing() { + assert_eq!( + format!( + "{:.0}º", + bearing( + -27.2020447088982, + -49.631891179172555, + -3.106362, + -60.025826, + ) + ), + "336º" + ); + } +} diff --git a/src/navigation/haversine.rs b/src/navigation/haversine.rs new file mode 100644 index 00000000000..27e61eb535c --- /dev/null +++ b/src/navigation/haversine.rs @@ -0,0 +1,35 @@ +use std::f64::consts::PI; + +const EARTH_RADIUS: f64 = 6371000.00; + +pub fn haversine(lat1: f64, lng1: f64, lat2: f64, lng2: f64) -> f64 { + let delta_dist_lat = (lat2 - lat1) * PI / 180.0; + let delta_dist_lng = (lng2 - lng1) * PI / 180.0; + + let cos1 = lat1 * PI / 180.0; + let cos2 = lat2 * PI / 180.0; + + let delta_lat = (delta_dist_lat / 2.0).sin().powf(2.0); + let delta_lng = (delta_dist_lng / 2.0).sin().powf(2.0); + + let a = delta_lat + delta_lng * cos1.cos() * cos2.cos(); + let result = 2.0 * a.asin().sqrt(); + + result * EARTH_RADIUS +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn testing() { + assert_eq!( + format!( + "{:.2}km", + haversine(52.375603, 4.903206, 52.366059, 4.926692) / 1000.0 + ), + "1.92km" + ); + } +} diff --git a/src/navigation/mod.rs b/src/navigation/mod.rs new file mode 100644 index 00000000000..e62be90acbc --- /dev/null +++ b/src/navigation/mod.rs @@ -0,0 +1,5 @@ +mod bearing; +mod haversine; + +pub use self::bearing::bearing; +pub use self::haversine::haversine; diff --git a/src/number_theory/compute_totient.rs b/src/number_theory/compute_totient.rs new file mode 100644 index 00000000000..88af0649fcd --- /dev/null +++ b/src/number_theory/compute_totient.rs @@ -0,0 +1,60 @@ +// Totient function for +// all numbers smaller than +// or equal to n. + +// Computes and prints +// totient of all numbers +// smaller than or equal to n + +use std::vec; + +pub fn compute_totient(n: i32) -> vec::Vec { + let mut phi: Vec = Vec::new(); + + // initialize phi[i] = i + for i in 0..=n { + phi.push(i); + } + + // Compute other Phi values + for p in 2..=n { + // If phi[p] is not computed already, + // then number p is prime + if phi[(p) as usize] == p { + // Phi of a prime number p is + // always equal to p-1. + phi[(p) as usize] = p - 1; + + // Update phi values of all + // multiples of p + for i in ((2 * p)..=n).step_by(p as usize) { + phi[(i) as usize] = (phi[i as usize] / p) * (p - 1); + } + } + } + + phi[1..].to_vec() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_1() { + assert_eq!( + compute_totient(12), + vec![1, 1, 2, 2, 4, 2, 6, 4, 6, 4, 10, 4] + ); + } + + #[test] + fn test_2() { + assert_eq!(compute_totient(7), vec![1, 1, 2, 2, 4, 2, 6]); + } + + #[test] + fn test_3() { + assert_eq!(compute_totient(4), vec![1, 1, 2, 2]); + } +} diff --git a/src/number_theory/euler_totient.rs b/src/number_theory/euler_totient.rs new file mode 100644 index 00000000000..69c0694a335 --- /dev/null +++ b/src/number_theory/euler_totient.rs @@ -0,0 +1,74 @@ +pub fn euler_totient(n: u64) -> u64 { + let mut result = n; + let mut num = n; + let mut p = 2; + + // Find all prime factors and apply formula + while p * p <= num { + // Check if p is a divisor of n + if num % p == 0 { + // If yes, then it is a prime factor + // Apply the formula: result = result * (1 - 1/p) + while num % p == 0 { + num /= p; + } + result -= result / p; + } + p += 1; + } + + // If num > 1, then it is a prime factor + if num > 1 { + result -= result / num; + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + macro_rules! test_euler_totient { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(euler_totient(input), expected) + } + )* + }; + } + + test_euler_totient! { + prime_2: (2, 1), + prime_3: (3, 2), + prime_5: (5, 4), + prime_7: (7, 6), + prime_11: (11, 10), + prime_13: (13, 12), + prime_17: (17, 16), + prime_19: (19, 18), + + composite_6: (6, 2), // 2 * 3 + composite_10: (10, 4), // 2 * 5 + composite_15: (15, 8), // 3 * 5 + composite_12: (12, 4), // 2^2 * 3 + composite_18: (18, 6), // 2 * 3^2 + composite_20: (20, 8), // 2^2 * 5 + composite_30: (30, 8), // 2 * 3 * 5 + + prime_power_2_to_2: (4, 2), + prime_power_2_to_3: (8, 4), + prime_power_3_to_2: (9, 6), + prime_power_2_to_4: (16, 8), + prime_power_5_to_2: (25, 20), + prime_power_3_to_3: (27, 18), + prime_power_2_to_5: (32, 16), + + // Large numbers + large_50: (50, 20), // 2 * 5^2 + large_100: (100, 40), // 2^2 * 5^2 + large_1000: (1000, 400), // 2^3 * 5^3 + } +} diff --git a/src/number_theory/kth_factor.rs b/src/number_theory/kth_factor.rs new file mode 100644 index 00000000000..abfece86bb9 --- /dev/null +++ b/src/number_theory/kth_factor.rs @@ -0,0 +1,41 @@ +// Kth Factor of N +// The idea is to check for each number in the range [N, 1], and print the Kth number that divides N completely. + +pub fn kth_factor(n: i32, k: i32) -> i32 { + let mut factors: Vec = Vec::new(); + let k = (k as usize) - 1; + for i in 1..=n { + if n % i == 0 { + factors.push(i); + } + if let Some(number) = factors.get(k) { + return *number; + } + } + -1 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_1() { + assert_eq!(kth_factor(12, 3), 3); + } + + #[test] + fn test_2() { + assert_eq!(kth_factor(7, 2), 7); + } + + #[test] + fn test_3() { + assert_eq!(kth_factor(4, 4), -1); + } + + #[test] + fn test_4() { + assert_eq!(kth_factor(950, 5), 19); + } +} diff --git a/src/number_theory/mod.rs b/src/number_theory/mod.rs new file mode 100644 index 00000000000..0500ad775d1 --- /dev/null +++ b/src/number_theory/mod.rs @@ -0,0 +1,7 @@ +mod compute_totient; +mod euler_totient; +mod kth_factor; + +pub use self::compute_totient::compute_totient; +pub use self::euler_totient::euler_totient; +pub use self::kth_factor::kth_factor; diff --git a/src/searching/binary_search.rs b/src/searching/binary_search.rs index 2c822ed59ba..4c64c58217c 100644 --- a/src/searching/binary_search.rs +++ b/src/searching/binary_search.rs @@ -1,106 +1,153 @@ +//! This module provides an implementation of a binary search algorithm that +//! works for both ascending and descending ordered arrays. The binary search +//! function returns the index of the target element if it is found, or `None` +//! if the target is not present in the array. + use std::cmp::Ordering; +/// Performs a binary search for a specified item within a sorted array. +/// +/// This function can handle both ascending and descending ordered arrays. It +/// takes a reference to the item to search for and a slice of the array. If +/// the item is found, it returns the index of the item within the array. If +/// the item is not found, it returns `None`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the sorted array in which to search. +/// +/// # Returns +/// +/// An `Option` which is: +/// - `Some(index)` if the item is found at the given index. +/// - `None` if the item is not found in the array. pub fn binary_search(item: &T, arr: &[T]) -> Option { - let mut is_asc = true; - if arr.len() > 1 { - is_asc = arr[0] < arr[(arr.len() - 1)]; - } + let is_asc = is_asc_arr(arr); + let mut left = 0; let mut right = arr.len(); while left < right { - let mid = left + (right - left) / 2; - - if is_asc { - match item.cmp(&arr[mid]) { - Ordering::Less => right = mid, - Ordering::Equal => return Some(mid), - Ordering::Greater => left = mid + 1, - } - } else { - match item.cmp(&arr[mid]) { - Ordering::Less => left = mid + 1, - Ordering::Equal => return Some(mid), - Ordering::Greater => right = mid, - } + if match_compare(item, arr, &mut left, &mut right, is_asc) { + return Some(left); } } + None } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn empty() { - let index = binary_search(&"a", &vec![]); - assert_eq!(index, None); - } - - #[test] - fn one_item() { - let index = binary_search(&"a", &vec!["a"]); - assert_eq!(index, Some(0)); - } - - #[test] - fn search_strings_asc() { - let index = binary_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(0)); - - let index = binary_search(&"google", &vec!["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(4)); - } - - #[test] - fn search_strings_desc() { - let index = binary_search(&"a", &vec!["zoo", "google", "d", "c", "b", "a"]); - assert_eq!(index, Some(5)); - - let index = binary_search(&"zoo", &vec!["zoo", "google", "d", "c", "b", "a"]); - assert_eq!(index, Some(0)); - - let index = binary_search(&"google", &vec!["zoo", "google", "d", "c", "b", "a"]); - assert_eq!(index, Some(1)); - } - - #[test] - fn search_ints_asc() { - let index = binary_search(&4, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(3)); - - let index = binary_search(&3, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(2)); - - let index = binary_search(&2, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(1)); - - let index = binary_search(&1, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(0)); +/// Compares the item with the middle element of the current search range and +/// updates the search bounds accordingly. This function handles both ascending +/// and descending ordered arrays. It calculates the middle index of the +/// current search range and compares the item with the element at +/// this index. It then updates the search bounds (`left` and `right`) based on +/// the result of this comparison. If the item is found, it updates `left` to +/// the index of the found item and returns `true`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the array in which to search. +/// - `left`: A mutable reference to the left bound of the search range. +/// - `right`: A mutable reference to the right bound of the search range. +/// - `is_asc`: A boolean indicating whether the array is sorted in ascending order. +/// +/// # Returns +/// +/// A `bool` indicating whether the item was found. +fn match_compare( + item: &T, + arr: &[T], + left: &mut usize, + right: &mut usize, + is_asc: bool, +) -> bool { + let mid = *left + (*right - *left) / 2; + let cmp_result = item.cmp(&arr[mid]); + + match (is_asc, cmp_result) { + (true, Ordering::Less) | (false, Ordering::Greater) => { + *right = mid; + } + (true, Ordering::Greater) | (false, Ordering::Less) => { + *left = mid + 1; + } + (_, Ordering::Equal) => { + *left = mid; + return true; + } } - #[test] - fn search_ints_desc() { - let index = binary_search(&4, &vec![4, 3, 2, 1]); - assert_eq!(index, Some(0)); + false +} - let index = binary_search(&3, &vec![4, 3, 2, 1]); - assert_eq!(index, Some(1)); +/// Determines if the given array is sorted in ascending order. +/// +/// This helper function checks if the first element of the array is less than the +/// last element, indicating an ascending order. It returns `false` if the array +/// has fewer than two elements. +/// +/// # Parameters +/// +/// - `arr`: A slice of the array to check. +/// +/// # Returns +/// +/// A `bool` indicating whether the array is sorted in ascending order. +fn is_asc_arr(arr: &[T]) -> bool { + arr.len() > 1 && arr[0] < arr[arr.len() - 1] +} - let index = binary_search(&2, &vec![4, 3, 2, 1]); - assert_eq!(index, Some(2)); +#[cfg(test)] +mod tests { + use super::*; - let index = binary_search(&1, &vec![4, 3, 2, 1]); - assert_eq!(index, Some(3)); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $test_case; + assert_eq!(binary_search(&item, arr), expected); + } + )* + }; } - #[test] - fn not_found() { - let index = binary_search(&5, &vec![1, 2, 3, 4]); - assert_eq!(index, None); - - let index = binary_search(&5, &vec![4, 3, 2, 1]); - assert_eq!(index, None); + test_cases! { + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/searching/binary_search_recursive.rs b/src/searching/binary_search_recursive.rs index 14740e4800d..e83fa2f48d5 100644 --- a/src/searching/binary_search_recursive.rs +++ b/src/searching/binary_search_recursive.rs @@ -1,31 +1,42 @@ use std::cmp::Ordering; -pub fn binary_search_rec( - list_of_items: &[T], - target: &T, - left: &usize, - right: &usize, -) -> Option { +/// Recursively performs a binary search for a specified item within a sorted array. +/// +/// This function can handle both ascending and descending ordered arrays. It +/// takes a reference to the item to search for and a slice of the array. If +/// the item is found, it returns the index of the item within the array. If +/// the item is not found, it returns `None`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the sorted array in which to search. +/// - `left`: The left bound of the current search range. +/// - `right`: The right bound of the current search range. +/// - `is_asc`: A boolean indicating whether the array is sorted in ascending order. +/// +/// # Returns +/// +/// An `Option` which is: +/// - `Some(index)` if the item is found at the given index. +/// - `None` if the item is not found in the array. +pub fn binary_search_rec(item: &T, arr: &[T], left: usize, right: usize) -> Option { if left >= right { return None; } - let is_asc = list_of_items[0] < list_of_items[list_of_items.len() - 1]; + let is_asc = arr.len() > 1 && arr[0] < arr[arr.len() - 1]; + let mid = left + (right - left) / 2; + let cmp_result = item.cmp(&arr[mid]); - let middle: usize = left + (right - left) / 2; - - if is_asc { - match target.cmp(&list_of_items[middle]) { - Ordering::Less => binary_search_rec(list_of_items, target, left, &middle), - Ordering::Greater => binary_search_rec(list_of_items, target, &(middle + 1), right), - Ordering::Equal => Some(middle), + match (is_asc, cmp_result) { + (true, Ordering::Less) | (false, Ordering::Greater) => { + binary_search_rec(item, arr, left, mid) } - } else { - match target.cmp(&list_of_items[middle]) { - Ordering::Less => binary_search_rec(list_of_items, target, &(middle + 1), right), - Ordering::Greater => binary_search_rec(list_of_items, target, left, &middle), - Ordering::Equal => Some(middle), + (true, Ordering::Greater) | (false, Ordering::Less) => { + binary_search_rec(item, arr, mid + 1, right) } + (_, Ordering::Equal) => Some(mid), } } @@ -33,124 +44,51 @@ pub fn binary_search_rec( mod tests { use super::*; - const LEFT: usize = 0; - - #[test] - fn fail_empty_list() { - let list_of_items = vec![]; - assert_eq!( - binary_search_rec(&list_of_items, &1, &LEFT, &list_of_items.len()), - None - ); - } - - #[test] - fn success_one_item() { - let list_of_items = vec![30]; - assert_eq!( - binary_search_rec(&list_of_items, &30, &LEFT, &list_of_items.len()), - Some(0) - ); - } - - #[test] - fn success_search_strings_asc() { - let say_hello_list = vec!["hi", "olá", "salut"]; - let right = say_hello_list.len(); - assert_eq!( - binary_search_rec(&say_hello_list, &"hi", &LEFT, &right), - Some(0) - ); - assert_eq!( - binary_search_rec(&say_hello_list, &"salut", &LEFT, &right), - Some(2) - ); - } - - #[test] - fn success_search_strings_desc() { - let say_hello_list = vec!["salut", "olá", "hi"]; - let right = say_hello_list.len(); - assert_eq!( - binary_search_rec(&say_hello_list, &"hi", &LEFT, &right), - Some(2) - ); - assert_eq!( - binary_search_rec(&say_hello_list, &"salut", &LEFT, &right), - Some(0) - ); - } - - #[test] - fn fail_search_strings_asc() { - let say_hello_list = vec!["hi", "olá", "salut"]; - for target in &["adiós", "你好"] { - assert_eq!( - binary_search_rec(&say_hello_list, target, &LEFT, &say_hello_list.len()), - None - ); - } - } - - #[test] - fn fail_search_strings_desc() { - let say_hello_list = vec!["salut", "olá", "hi"]; - for target in &["adiós", "你好"] { - assert_eq!( - binary_search_rec(&say_hello_list, target, &LEFT, &say_hello_list.len()), - None - ); - } - } - - #[test] - fn success_search_integers_asc() { - let integers = vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]; - for (index, target) in integers.iter().enumerate() { - assert_eq!( - binary_search_rec(&integers, target, &LEFT, &integers.len()), - Some(index) - ) - } - } - - #[test] - fn success_search_integers_desc() { - let integers = vec![90, 80, 70, 60, 50, 40, 30, 20, 10, 0]; - for (index, target) in integers.iter().enumerate() { - assert_eq!( - binary_search_rec(&integers, target, &LEFT, &integers.len()), - Some(index) - ) - } - } - - #[test] - fn fail_search_integers() { - let integers = vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]; - for target in &[100, 444, 336] { - assert_eq!( - binary_search_rec(&integers, target, &LEFT, &integers.len()), - None - ); - } - } - - #[test] - fn success_search_string_in_middle_of_unsorted_list() { - let unsorted_strings = vec!["salut", "olá", "hi"]; - assert_eq!( - binary_search_rec(&unsorted_strings, &"olá", &LEFT, &unsorted_strings.len()), - Some(1) - ); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $test_case; + assert_eq!(binary_search_rec(&item, arr, 0, arr.len()), expected); + } + )* + }; } - #[test] - fn success_search_integer_in_middle_of_unsorted_list() { - let unsorted_integers = vec![90, 80, 70]; - assert_eq!( - binary_search_rec(&unsorted_integers, &80, &LEFT, &unsorted_integers.len()), - Some(1) - ); + test_cases! { + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/searching/exponential_search.rs b/src/searching/exponential_search.rs index 7cf78981859..be700956149 100644 --- a/src/searching/exponential_search.rs +++ b/src/searching/exponential_search.rs @@ -33,40 +33,40 @@ mod tests { #[test] fn empty() { - let index = exponential_search(&"a", &vec![]); + let index = exponential_search(&"a", &[]); assert_eq!(index, None); } #[test] fn one_item() { - let index = exponential_search(&"a", &vec!["a"]); + let index = exponential_search(&"a", &["a"]); assert_eq!(index, Some(0)); } #[test] fn search_strings() { - let index = exponential_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); + let index = exponential_search(&"a", &["a", "b", "c", "d", "google", "zoo"]); assert_eq!(index, Some(0)); } #[test] fn search_ints() { - let index = exponential_search(&4, &vec![1, 2, 3, 4]); + let index = exponential_search(&4, &[1, 2, 3, 4]); assert_eq!(index, Some(3)); - let index = exponential_search(&3, &vec![1, 2, 3, 4]); + let index = exponential_search(&3, &[1, 2, 3, 4]); assert_eq!(index, Some(2)); - let index = exponential_search(&2, &vec![1, 2, 3, 4]); + let index = exponential_search(&2, &[1, 2, 3, 4]); assert_eq!(index, Some(1)); - let index = exponential_search(&1, &vec![1, 2, 3, 4]); + let index = exponential_search(&1, &[1, 2, 3, 4]); assert_eq!(index, Some(0)); } #[test] fn not_found() { - let index = exponential_search(&5, &vec![1, 2, 3, 4]); + let index = exponential_search(&5, &[1, 2, 3, 4]); assert_eq!(index, None); } } diff --git a/src/searching/fibonacci_search.rs b/src/searching/fibonacci_search.rs index bd90eaec4d0..dc33fbba884 100644 --- a/src/searching/fibonacci_search.rs +++ b/src/searching/fibonacci_search.rs @@ -45,40 +45,40 @@ mod tests { #[test] fn empty() { - let index = fibonacci_search(&"a", &vec![]); + let index = fibonacci_search(&"a", &[]); assert_eq!(index, None); } #[test] fn one_item() { - let index = fibonacci_search(&"a", &vec!["a"]); + let index = fibonacci_search(&"a", &["a"]); assert_eq!(index, Some(0)); } #[test] fn search_strings() { - let index = fibonacci_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); + let index = fibonacci_search(&"a", &["a", "b", "c", "d", "google", "zoo"]); assert_eq!(index, Some(0)); } #[test] fn search_ints() { - let index = fibonacci_search(&4, &vec![1, 2, 3, 4]); + let index = fibonacci_search(&4, &[1, 2, 3, 4]); assert_eq!(index, Some(3)); - let index = fibonacci_search(&3, &vec![1, 2, 3, 4]); + let index = fibonacci_search(&3, &[1, 2, 3, 4]); assert_eq!(index, Some(2)); - let index = fibonacci_search(&2, &vec![1, 2, 3, 4]); + let index = fibonacci_search(&2, &[1, 2, 3, 4]); assert_eq!(index, Some(1)); - let index = fibonacci_search(&1, &vec![1, 2, 3, 4]); + let index = fibonacci_search(&1, &[1, 2, 3, 4]); assert_eq!(index, Some(0)); } #[test] fn not_found() { - let index = fibonacci_search(&5, &vec![1, 2, 3, 4]); + let index = fibonacci_search(&5, &[1, 2, 3, 4]); assert_eq!(index, None); } } diff --git a/src/searching/jump_search.rs b/src/searching/jump_search.rs index 296bd421be5..64d49331a30 100644 --- a/src/searching/jump_search.rs +++ b/src/searching/jump_search.rs @@ -17,9 +17,6 @@ pub fn jump_search(item: &T, arr: &[T]) -> Option { } while &arr[prev] < item { prev += 1; - if prev == min(step, len) { - return None; - } } if &arr[prev] == item { return Some(prev); @@ -33,40 +30,36 @@ mod tests { #[test] fn empty() { - let index = jump_search(&"a", &vec![]); - assert_eq!(index, None); + assert!(jump_search(&"a", &[]).is_none()); } #[test] fn one_item() { - let index = jump_search(&"a", &vec!["a"]); - assert_eq!(index, Some(0)); + assert_eq!(jump_search(&"a", &["a"]).unwrap(), 0); } #[test] fn search_strings() { - let index = jump_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(0)); + assert_eq!( + jump_search(&"a", &["a", "b", "c", "d", "google", "zoo"]).unwrap(), + 0 + ); } #[test] fn search_ints() { - let index = jump_search(&4, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(3)); - - let index = jump_search(&3, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(2)); - - let index = jump_search(&2, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(1)); - - let index = jump_search(&1, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(0)); + let arr = [1, 2, 3, 4]; + assert_eq!(jump_search(&4, &arr).unwrap(), 3); + assert_eq!(jump_search(&3, &arr).unwrap(), 2); + assert_eq!(jump_search(&2, &arr).unwrap(), 1); + assert_eq!(jump_search(&1, &arr).unwrap(), 0); } #[test] fn not_found() { - let index = jump_search(&5, &vec![1, 2, 3, 4]); - assert_eq!(index, None); + let arr = [1, 2, 3, 4]; + + assert!(jump_search(&5, &arr).is_none()); + assert!(jump_search(&0, &arr).is_none()); } } diff --git a/src/searching/kth_smallest.rs b/src/searching/kth_smallest.rs index 5e208177397..39c77fd7412 100644 --- a/src/searching/kth_smallest.rs +++ b/src/searching/kth_smallest.rs @@ -24,7 +24,7 @@ where return input[lo]; } - let pivot = partition(input, lo as isize, hi as isize) as usize; + let pivot = partition(input, lo, hi); let i = pivot - lo + 1; match k.cmp(&i) { diff --git a/src/searching/kth_smallest_heap.rs b/src/searching/kth_smallest_heap.rs index 2c7612d516b..fe2a3c15a5f 100644 --- a/src/searching/kth_smallest_heap.rs +++ b/src/searching/kth_smallest_heap.rs @@ -1,4 +1,4 @@ -use crate::data_structures::MaxHeap; +use crate::data_structures::Heap; use std::cmp::{Ord, Ordering}; /// Returns k-th smallest element of an array. @@ -12,7 +12,7 @@ use std::cmp::{Ord, Ordering}; /// operation's complexity is O(log(k)). pub fn kth_smallest_heap(input: &[T], k: usize) -> Option where - T: Default + Ord + Copy, + T: Ord + Copy, { if input.len() < k { return None; @@ -28,7 +28,7 @@ where // than it // otherwise, E_large cannot be the kth smallest, and should // be removed from the heap and E_new should be added - let mut heap = MaxHeap::new(); + let mut heap = Heap::new_max(); // first k elements goes to the heap as the baseline for &val in input.iter().take(k) { @@ -37,7 +37,7 @@ where for &val in input.iter().skip(k) { // compare new value to the current kth smallest value - let cur_big = heap.next().unwrap(); // heap.next() can't be None + let cur_big = heap.pop().unwrap(); // heap.pop() can't be None match val.cmp(&cur_big) { Ordering::Greater => { heap.add(cur_big); @@ -48,7 +48,7 @@ where } } - heap.next() + heap.pop() } #[cfg(test)] diff --git a/src/searching/linear_search.rs b/src/searching/linear_search.rs index d3a0be48042..d38b224d0a6 100644 --- a/src/searching/linear_search.rs +++ b/src/searching/linear_search.rs @@ -1,6 +1,15 @@ -use std::cmp::PartialEq; - -pub fn linear_search(item: &T, arr: &[T]) -> Option { +/// Performs a linear search on the given array, returning the index of the first occurrence of the item. +/// +/// # Arguments +/// +/// * `item` - A reference to the item to search for in the array. +/// * `arr` - A slice of items to search within. +/// +/// # Returns +/// +/// * `Some(usize)` - The index of the first occurrence of the item, if found. +/// * `None` - If the item is not found in the array. +pub fn linear_search(item: &T, arr: &[T]) -> Option { for (i, data) in arr.iter().enumerate() { if item == data { return Some(i); @@ -14,36 +23,54 @@ pub fn linear_search(item: &T, arr: &[T]) -> Option { mod tests { use super::*; - #[test] - fn search_strings() { - let index = linear_search(&"a", &vec!["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(0)); - } - - #[test] - fn search_ints() { - let index = linear_search(&4, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(3)); - - let index = linear_search(&3, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(2)); - - let index = linear_search(&2, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(1)); - - let index = linear_search(&1, &vec![1, 2, 3, 4]); - assert_eq!(index, Some(0)); - } - - #[test] - fn not_found() { - let index = linear_search(&5, &vec![1, 2, 3, 4]); - assert_eq!(index, None); + macro_rules! test_cases { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $tc; + if let Some(expected_index) = expected { + assert_eq!(arr[expected_index], item); + } + assert_eq!(linear_search(&item, arr), expected); + } + )* + } } - #[test] - fn empty() { - let index = linear_search(&1, &vec![]); - assert_eq!(index, None); + test_cases! { + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/searching/mod.rs b/src/searching/mod.rs index 146276da793..94f65988195 100644 --- a/src/searching/mod.rs +++ b/src/searching/mod.rs @@ -7,7 +7,9 @@ mod jump_search; mod kth_smallest; mod kth_smallest_heap; mod linear_search; +mod moore_voting; mod quick_select; +mod saddleback_search; mod ternary_search; mod ternary_search_min_max; mod ternary_search_min_max_recursive; @@ -22,7 +24,9 @@ pub use self::jump_search::jump_search; pub use self::kth_smallest::kth_smallest; pub use self::kth_smallest_heap::kth_smallest_heap; pub use self::linear_search::linear_search; +pub use self::moore_voting::moore_voting; pub use self::quick_select::quick_select; +pub use self::saddleback_search::saddleback_search; pub use self::ternary_search::ternary_search; pub use self::ternary_search_min_max::ternary_search_max; pub use self::ternary_search_min_max::ternary_search_min; diff --git a/src/searching/moore_voting.rs b/src/searching/moore_voting.rs new file mode 100644 index 00000000000..8acf0dd8b36 --- /dev/null +++ b/src/searching/moore_voting.rs @@ -0,0 +1,80 @@ +/* + + Moore's voting algorithm finds out the strictly majority-occurring element + without using extra space + and O(n) + O(n) time complexity + + It is built on the intuition that a strictly major element will always have a net occurrence as 1. + Say, array given: 9 1 8 1 1 + Here, the algorithm will work as: + + (for finding element present >(n/2) times) + (assumed: all elements are >0) + + Initialisation: ele=0, cnt=0 + Loop beings. + + loop 1: arr[0]=9 + ele = 9 + cnt=1 (since cnt = 0, cnt increments to 1 and ele = 9) + + loop 2: arr[1]=1 + ele = 9 + cnt= 0 (since in this turn of the loop, the array[i] != ele, cnt decrements by 1) + + loop 3: arr[2]=8 + ele = 8 + cnt=1 (since cnt = 0, cnt increments to 1 and ele = 8) + + loop 4: arr[3]=1 + ele = 8 + cnt= 0 (since in this turn of the loop, the array[i] != ele, cnt decrements by 1) + + loop 5: arr[4]=1 + ele = 9 + cnt=1 (since cnt = 0, cnt increments to 1 and ele = 1) + + Now, this ele should be the majority element if there's any + To check, a quick O(n) loop is run to check if the count of ele is >(n/2), n being the length of the array + + -1 is returned when no such element is found. + +*/ + +pub fn moore_voting(arr: &[i32]) -> i32 { + let n = arr.len(); + let mut cnt = 0; // initializing cnt + let mut ele = 0; // initializing ele + + for &item in arr.iter() { + if cnt == 0 { + cnt = 1; + ele = item; + } else if item == ele { + cnt += 1; + } else { + cnt -= 1; + } + } + + let cnt_check = arr.iter().filter(|&&x| x == ele).count(); + + if cnt_check > (n / 2) { + ele + } else { + -1 + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_moore_voting() { + let arr1: Vec = vec![9, 1, 8, 1, 1]; + assert!(moore_voting(&arr1) == 1); + let arr2: Vec = vec![1, 2, 3, 4]; + assert!(moore_voting(&arr2) == -1); + } +} diff --git a/src/searching/quick_select.rs b/src/searching/quick_select.rs index 5b971e47642..592c4ceb2f4 100644 --- a/src/searching/quick_select.rs +++ b/src/searching/quick_select.rs @@ -4,13 +4,13 @@ fn partition(list: &mut [i32], left: usize, right: usize, pivot_index: usize) -> let pivot_value = list[pivot_index]; list.swap(pivot_index, right); // Move pivot to end let mut store_index = left; - for i in left..(right + 1) { + for i in left..right { if list[i] < pivot_value { list.swap(store_index, i); store_index += 1; } - list.swap(right, store_index); // Move pivot to its final place } + list.swap(right, store_index); // Move pivot to its final place store_index } @@ -19,7 +19,7 @@ pub fn quick_select(list: &mut [i32], left: usize, right: usize, index: usize) - // If the list contains only one element, return list[left]; } // return that element - let mut pivot_index = ((left + right) / 2) + 1; // select a pivotIndex between left and right + let mut pivot_index = left + (right - left) / 2; // select a pivotIndex between left and right pivot_index = partition(list, left, right, pivot_index); // The pivot is in its final sorted position match index { @@ -37,7 +37,7 @@ mod tests { let mut arr1 = [2, 3, 4, 5]; assert_eq!(quick_select(&mut arr1, 0, 3, 1), 3); let mut arr2 = [2, 5, 9, 12, 16]; - assert_eq!(quick_select(&mut arr2, 1, 3, 2), 12); + assert_eq!(quick_select(&mut arr2, 1, 3, 2), 9); let mut arr2 = [0, 3, 8]; assert_eq!(quick_select(&mut arr2, 0, 0, 0), 0); } diff --git a/src/searching/saddleback_search.rs b/src/searching/saddleback_search.rs new file mode 100644 index 00000000000..1a722a29b1c --- /dev/null +++ b/src/searching/saddleback_search.rs @@ -0,0 +1,84 @@ +// Saddleback search is a technique used to find an element in a sorted 2D matrix in O(m + n) time, +// where m is the number of rows, and n is the number of columns. It works by starting from the +// top-right corner of the matrix and moving left or down based on the comparison of the current +// element with the target element. +use std::cmp::Ordering; + +pub fn saddleback_search(matrix: &[Vec], element: i32) -> (usize, usize) { + // Initialize left and right indices + let mut left_index = 0; + let mut right_index = matrix[0].len() - 1; + + // Start searching + while left_index < matrix.len() { + match element.cmp(&matrix[left_index][right_index]) { + // If the current element matches the target element, return its position (indices are 1-based) + Ordering::Equal => return (left_index + 1, right_index + 1), + Ordering::Greater => { + // If the target element is greater, move to the next row (downwards) + left_index += 1; + } + Ordering::Less => { + // If the target element is smaller, move to the previous column (leftwards) + if right_index == 0 { + break; // If we reach the left-most column, exit the loop + } + right_index -= 1; + } + } + } + + // If the element is not found, return (0, 0) + (0, 0) +} + +#[cfg(test)] +mod tests { + use super::*; + + // Test when the element is not present in the matrix + #[test] + fn test_element_not_found() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 123), (0, 0)); + } + + // Test when the element is at the top-left corner of the matrix + #[test] + fn test_element_at_top_left() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 1), (1, 1)); + } + + // Test when the element is at the bottom-right corner of the matrix + #[test] + fn test_element_at_bottom_right() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 300), (3, 3)); + } + + // Test when the element is at the top-right corner of the matrix + #[test] + fn test_element_at_top_right() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 100), (1, 3)); + } + + // Test when the element is at the bottom-left corner of the matrix + #[test] + fn test_element_at_bottom_left() { + let matrix = vec![vec![1, 10, 100], vec![2, 20, 200], vec![3, 30, 300]]; + assert_eq!(saddleback_search(&matrix, 3), (3, 1)); + } + + // Additional test case: Element in the middle of the matrix + #[test] + fn test_element_in_middle() { + let matrix = vec![ + vec![1, 10, 100, 1000], + vec![2, 20, 200, 2000], + vec![3, 30, 300, 3000], + ]; + assert_eq!(saddleback_search(&matrix, 200), (2, 3)); + } +} diff --git a/src/searching/ternary_search.rs b/src/searching/ternary_search.rs index 345f979d479..cb9b5bee477 100644 --- a/src/searching/ternary_search.rs +++ b/src/searching/ternary_search.rs @@ -1,91 +1,195 @@ +//! This module provides an implementation of a ternary search algorithm that +//! works for both ascending and descending ordered arrays. The ternary search +//! function returns the index of the target element if it is found, or `None` +//! if the target is not present in the array. + use std::cmp::Ordering; -pub fn ternary_search( - target: &T, - list: &[T], - mut start: usize, - mut end: usize, -) -> Option { - if list.is_empty() { +/// Performs a ternary search for a specified item within a sorted array. +/// +/// This function can handle both ascending and descending ordered arrays. It +/// takes a reference to the item to search for and a slice of the array. If +/// the item is found, it returns the index of the item within the array. If +/// the item is not found, it returns `None`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the sorted array in which to search. +/// +/// # Returns +/// +/// An `Option` which is: +/// - `Some(index)` if the item is found at the given index. +/// - `None` if the item is not found in the array. +pub fn ternary_search(item: &T, arr: &[T]) -> Option { + if arr.is_empty() { return None; } - while start <= end { - let mid1: usize = start + (end - start) / 3; - let mid2: usize = end - (end - start) / 3; + let is_asc = is_asc_arr(arr); + let mut left = 0; + let mut right = arr.len() - 1; - match target.cmp(&list[mid1]) { - Ordering::Less => end = mid1 - 1, - Ordering::Equal => return Some(mid1), - Ordering::Greater => match target.cmp(&list[mid2]) { - Ordering::Greater => start = mid2 + 1, - Ordering::Equal => return Some(mid2), - Ordering::Less => { - start = mid1 + 1; - end = mid2 - 1; - } - }, + while left <= right { + if match_compare(item, arr, &mut left, &mut right, is_asc) { + return Some(left); } } None } -#[cfg(test)] -mod tests { - use super::*; +/// Compares the item with two middle elements of the current search range and +/// updates the search bounds accordingly. This function handles both ascending +/// and descending ordered arrays. It calculates two middle indices of the +/// current search range and compares the item with the elements at these +/// indices. It then updates the search bounds (`left` and `right`) based on +/// the result of these comparisons. If the item is found, it returns `true`. +/// +/// # Parameters +/// +/// - `item`: A reference to the item to search for. +/// - `arr`: A slice of the array in which to search. +/// - `left`: A mutable reference to the left bound of the search range. +/// - `right`: A mutable reference to the right bound of the search range. +/// - `is_asc`: A boolean indicating whether the array is sorted in ascending order. +/// +/// # Returns +/// +/// A `bool` indicating: +/// - `true` if the item was found in the array. +/// - `false` if the item was not found in the array. +fn match_compare( + item: &T, + arr: &[T], + left: &mut usize, + right: &mut usize, + is_asc: bool, +) -> bool { + let first_mid = *left + (*right - *left) / 3; + let second_mid = *right - (*right - *left) / 3; - #[test] - fn returns_none_if_empty_list() { - let index = ternary_search(&"a", &vec![], 1, 10); - assert_eq!(index, None); + // Handling the edge case where the search narrows down to a single element + if first_mid == second_mid && first_mid == *left { + return match &arr[*left] { + x if x == item => true, + _ => { + *left += 1; + false + } + }; } - #[test] - fn returns_none_if_range_is_invalid() { - let index = ternary_search(&1, &vec![1, 2, 3], 2, 1); - assert_eq!(index, None); - } + let cmp_first_mid = item.cmp(&arr[first_mid]); + let cmp_second_mid = item.cmp(&arr[second_mid]); - #[test] - fn returns_index_if_list_has_one_item() { - let index = ternary_search(&1, &vec![1], 0, 1); - assert_eq!(index, Some(0)); - } - - #[test] - fn returns_first_index() { - let index = ternary_search(&1, &vec![1, 2, 3], 0, 2); - assert_eq!(index, Some(0)); + match (is_asc, cmp_first_mid, cmp_second_mid) { + // If the item matches either midpoint, it returns the index + (_, Ordering::Equal, _) => { + *left = first_mid; + return true; + } + (_, _, Ordering::Equal) => { + *left = second_mid; + return true; + } + // If the item is smaller than the element at first_mid (in ascending order) + // or greater than it (in descending order), it narrows the search to the first third. + (true, Ordering::Less, _) | (false, Ordering::Greater, _) => { + *right = first_mid.saturating_sub(1) + } + // If the item is greater than the element at second_mid (in ascending order) + // or smaller than it (in descending order), it narrows the search to the last third. + (true, _, Ordering::Greater) | (false, _, Ordering::Less) => *left = second_mid + 1, + // Otherwise, it searches the middle third. + (_, _, _) => { + *left = first_mid + 1; + *right = second_mid - 1; + } } - #[test] - fn returns_first_index_if_end_out_of_bounds() { - let index = ternary_search(&1, &vec![1, 2, 3], 0, 3); - assert_eq!(index, Some(0)); - } + false +} - #[test] - fn returns_last_index() { - let index = ternary_search(&3, &vec![1, 2, 3], 0, 2); - assert_eq!(index, Some(2)); - } +/// Determines if the given array is sorted in ascending order. +/// +/// This helper function checks if the first element of the array is less than the +/// last element, indicating an ascending order. It returns `false` if the array +/// has fewer than two elements. +/// +/// # Parameters +/// +/// - `arr`: A slice of the array to check. +/// +/// # Returns +/// +/// A `bool` indicating whether the array is sorted in ascending order. +fn is_asc_arr(arr: &[T]) -> bool { + arr.len() > 1 && arr[0] < arr[arr.len() - 1] +} - #[test] - fn returns_last_index_if_end_out_of_bounds() { - let index = ternary_search(&3, &vec![1, 2, 3], 0, 3); - assert_eq!(index, Some(2)); - } +#[cfg(test)] +mod tests { + use super::*; - #[test] - fn returns_middle_index() { - let index = ternary_search(&2, &vec![1, 2, 3], 0, 2); - assert_eq!(index, Some(1)); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $test_case; + if let Some(expected_index) = expected { + assert_eq!(arr[expected_index], item); + } + assert_eq!(ternary_search(&item, arr), expected); + } + )* + }; } - #[test] - fn returns_middle_index_if_end_out_of_bounds() { - let index = ternary_search(&2, &vec![1, 2, 3], 0, 3); - assert_eq!(index, Some(1)); + test_cases! { + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_two_elements_found_at_start: (1, &[1, 2], Some(0)), + search_two_elements_found_at_end: (2, &[1, 2], Some(1)), + search_two_elements_not_found_start: (0, &[1, 2], None), + search_two_elements_not_found_end: (3, &[1, 2], None), + search_three_elements_found_start: (1, &[1, 2, 3], Some(0)), + search_three_elements_found_middle: (2, &[1, 2, 3], Some(1)), + search_three_elements_found_end: (3, &[1, 2, 3], Some(2)), + search_three_elements_not_found_start: (0, &[1, 2, 3], None), + search_three_elements_not_found_end: (4, &[1, 2, 3], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/searching/ternary_search_min_max_recursive.rs b/src/searching/ternary_search_min_max_recursive.rs index 1e5941441e5..88d3a0a7b1b 100644 --- a/src/searching/ternary_search_min_max_recursive.rs +++ b/src/searching/ternary_search_min_max_recursive.rs @@ -16,9 +16,8 @@ pub fn ternary_search_max_rec( return ternary_search_max_rec(f, mid1, end, absolute_precision); } else if r1 > r2 { return ternary_search_max_rec(f, start, mid2, absolute_precision); - } else { - return ternary_search_max_rec(f, mid1, mid2, absolute_precision); } + return ternary_search_max_rec(f, mid1, mid2, absolute_precision); } f(start) } @@ -41,9 +40,8 @@ pub fn ternary_search_min_rec( return ternary_search_min_rec(f, start, mid2, absolute_precision); } else if r1 > r2 { return ternary_search_min_rec(f, mid1, end, absolute_precision); - } else { - return ternary_search_min_rec(f, mid1, mid2, absolute_precision); } + return ternary_search_min_rec(f, mid1, mid2, absolute_precision); } f(start) } diff --git a/src/searching/ternary_search_recursive.rs b/src/searching/ternary_search_recursive.rs index e033d67f5dd..045df86e3eb 100644 --- a/src/searching/ternary_search_recursive.rs +++ b/src/searching/ternary_search_recursive.rs @@ -34,55 +34,55 @@ mod tests { #[test] fn returns_none_if_empty_list() { - let index = ternary_search_rec(&"a", &vec![], 1, 10); + let index = ternary_search_rec(&"a", &[], 1, 10); assert_eq!(index, None); } #[test] fn returns_none_if_range_is_invalid() { - let index = ternary_search_rec(&1, &vec![1, 2, 3], 2, 1); + let index = ternary_search_rec(&1, &[1, 2, 3], 2, 1); assert_eq!(index, None); } #[test] fn returns_index_if_list_has_one_item() { - let index = ternary_search_rec(&1, &vec![1], 0, 1); + let index = ternary_search_rec(&1, &[1], 0, 1); assert_eq!(index, Some(0)); } #[test] fn returns_first_index() { - let index = ternary_search_rec(&1, &vec![1, 2, 3], 0, 2); + let index = ternary_search_rec(&1, &[1, 2, 3], 0, 2); assert_eq!(index, Some(0)); } #[test] fn returns_first_index_if_end_out_of_bounds() { - let index = ternary_search_rec(&1, &vec![1, 2, 3], 0, 3); + let index = ternary_search_rec(&1, &[1, 2, 3], 0, 3); assert_eq!(index, Some(0)); } #[test] fn returns_last_index() { - let index = ternary_search_rec(&3, &vec![1, 2, 3], 0, 2); + let index = ternary_search_rec(&3, &[1, 2, 3], 0, 2); assert_eq!(index, Some(2)); } #[test] fn returns_last_index_if_end_out_of_bounds() { - let index = ternary_search_rec(&3, &vec![1, 2, 3], 0, 3); + let index = ternary_search_rec(&3, &[1, 2, 3], 0, 3); assert_eq!(index, Some(2)); } #[test] fn returns_middle_index() { - let index = ternary_search_rec(&2, &vec![1, 2, 3], 0, 2); + let index = ternary_search_rec(&2, &[1, 2, 3], 0, 2); assert_eq!(index, Some(1)); } #[test] fn returns_middle_index_if_end_out_of_bounds() { - let index = ternary_search_rec(&2, &vec![1, 2, 3], 0, 3); + let index = ternary_search_rec(&2, &[1, 2, 3], 0, 3); assert_eq!(index, Some(1)); } } diff --git a/src/sorting/README.md b/src/sorting/README.md index ed7feb42cac..4b0e248db35 100644 --- a/src/sorting/README.md +++ b/src/sorting/README.md @@ -180,8 +180,30 @@ __Properties__ From [Wikipedia][tim-wiki]: Timsort is a hybrid stable sorting algorithm, derived from merge sort and insertion sort, designed to perform well on many kinds of real-world data. It was implemented by Tim Peters in 2002 for use in the Python programming language. The algorithm finds subsequences of the data that are already ordered (runs) and uses them to sort the remainder more efficiently. This is done by merging runs until certain criteria are fulfilled. Timsort has been Python's standard sorting algorithm since version 2.3. It is also used to sort arrays of non-primitive type in Java SE 7, on the Android platform, in GNU Octave, on V8, Swift, and Rust. __Properties__ -* Worst-case performance O(n log n) -* Best-case performance O(n) +* Worst-case performance O(max element size(ms)) +* Best-case performance O(max element size(ms)) + +### [Sleep](./sleep_sort.rs) +![alt text][sleep-image] + +From [Wikipedia][bucket-sort-wiki]: This is an idea that was originally posted on the message board 4chan, replacing the bucket in bucket sort with time instead of memory space. +It is actually possible to sort by "maximum of all elements x unit time to sleep". The only case where this would be useful would be in examples. + +### [Patience](./patience_sort.rs) +[patience-video] + + +From [Wikipedia][patience-sort-wiki]: The algorithm's name derives from a simplified variant of the patience card game. The game begins with a shuffled deck of cards. The cards are dealt one by one into a sequence of piles on the table, according to the following rules. + +1. Initially, there are no piles. The first card dealt forms a new pile consisting of the single card. +2. Each subsequent card is placed on the leftmost existing pile whose top card has a value greater than or equal to the new card's value, or to the right of all of the existing piles, thus forming a new pile. +3. When there are no more cards remaining to deal, the game ends. + +This card game is turned into a two-phase sorting algorithm, as follows. Given an array of n elements from some totally ordered domain, consider this array as a collection of cards and simulate the patience sorting game. When the game is over, recover the sorted sequence by repeatedly picking off the minimum visible card; in other words, perform a k-way merge of the p piles, each of which is internally sorted. + +__Properties__ +* Worst case performance O(n log n) +* Best case performance O(n) [bogo-wiki]: https://en.wikipedia.org/wiki/Bogosort [bogo-image]: https://upload.wikimedia.org/wikipedia/commons/7/7b/Bogo_sort_animation.gif @@ -235,3 +257,9 @@ __Properties__ [comb-sort]: https://upload.wikimedia.org/wikipedia/commons/4/46/Comb_sort_demo.gif [comb-sort-wiki]: https://en.wikipedia.org/wiki/Comb_sort + +[sleep-sort]: +[sleep-sort-wiki]: https://ja.m.wikipedia.org/wiki/バケットソート#.E3.82.B9.E3.83.AA.E3.83.BC.E3.83.97.E3.82.BD.E3.83.BC.E3.83.88 + +[patience-sort-wiki]: https://en.wikipedia.org/wiki/Patience_sorting +[patience-video]: https://user-images.githubusercontent.com/67539676/212542208-d3f7a824-60d8-467c-8097-841945514ae9.mp4 diff --git a/src/sorting/bead_sort.rs b/src/sorting/bead_sort.rs new file mode 100644 index 00000000000..c8c4017942e --- /dev/null +++ b/src/sorting/bead_sort.rs @@ -0,0 +1,59 @@ +//Bead sort only works for sequences of non-negative integers. +//https://en.wikipedia.org/wiki/Bead_sort +pub fn bead_sort(a: &mut [usize]) { + // Find the maximum element + let mut max = a[0]; + (1..a.len()).for_each(|i| { + if a[i] > max { + max = a[i]; + } + }); + + // allocating memory + let mut beads = vec![vec![0; max]; a.len()]; + + // mark the beads + for i in 0..a.len() { + for j in (0..a[i]).rev() { + beads[i][j] = 1; + } + } + + // move down the beads + for j in 0..max { + let mut sum = 0; + (0..a.len()).for_each(|i| { + sum += beads[i][j]; + beads[i][j] = 0; + }); + + for k in ((a.len() - sum)..a.len()).rev() { + a[k] = j + 1; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn descending() { + //descending + let mut ve1: [usize; 5] = [5, 4, 3, 2, 1]; + let cloned = ve1; + bead_sort(&mut ve1); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); + } + + #[test] + fn mix_values() { + //pre-sorted + let mut ve2: [usize; 5] = [7, 9, 6, 2, 3]; + let cloned = ve2; + bead_sort(&mut ve2); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); + } +} diff --git a/src/sorting/binary_insertion_sort.rs b/src/sorting/binary_insertion_sort.rs new file mode 100644 index 00000000000..3ecb47456e8 --- /dev/null +++ b/src/sorting/binary_insertion_sort.rs @@ -0,0 +1,51 @@ +fn _binary_search(arr: &[T], target: &T) -> usize { + let mut low = 0; + let mut high = arr.len(); + + while low < high { + let mid = low + (high - low) / 2; + + if arr[mid] < *target { + low = mid + 1; + } else { + high = mid; + } + } + + low +} + +pub fn binary_insertion_sort(arr: &mut [T]) { + let len = arr.len(); + + for i in 1..len { + let key = arr[i].clone(); + let index = _binary_search(&arr[..i], &key); + + arr[index..=i].rotate_right(1); + arr[index] = key; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_binary_insertion_sort() { + let mut arr1 = vec![64, 25, 12, 22, 11]; + let mut arr2 = vec![5, 4, 3, 2, 1]; + let mut arr3 = vec![1, 2, 3, 4, 5]; + let mut arr4: Vec = vec![]; // Explicitly specify the type for arr4 + + binary_insertion_sort(&mut arr1); + binary_insertion_sort(&mut arr2); + binary_insertion_sort(&mut arr3); + binary_insertion_sort(&mut arr4); + + assert_eq!(arr1, vec![11, 12, 22, 25, 64]); + assert_eq!(arr2, vec![1, 2, 3, 4, 5]); + assert_eq!(arr3, vec![1, 2, 3, 4, 5]); + assert_eq!(arr4, Vec::::new()); + } +} diff --git a/src/sorting/bingo_sort.rs b/src/sorting/bingo_sort.rs new file mode 100644 index 00000000000..0c113fe86d5 --- /dev/null +++ b/src/sorting/bingo_sort.rs @@ -0,0 +1,105 @@ +use std::cmp::{max, min}; + +// Function for finding the maximum and minimum element of the Array +fn max_min(vec: &[i32], bingo: &mut i32, next_bingo: &mut i32) { + for &element in vec.iter().skip(1) { + *bingo = min(*bingo, element); + *next_bingo = max(*next_bingo, element); + } +} + +pub fn bingo_sort(vec: &mut [i32]) { + if vec.is_empty() { + return; + } + + let mut bingo = vec[0]; + let mut next_bingo = vec[0]; + + max_min(vec, &mut bingo, &mut next_bingo); + + let largest_element = next_bingo; + let mut next_element_pos = 0; + + for (bingo, _next_bingo) in (bingo..=largest_element).zip(bingo..=largest_element) { + let start_pos = next_element_pos; + + for i in start_pos..vec.len() { + if vec[i] == bingo { + vec.swap(i, next_element_pos); + next_element_pos += 1; + } + } + } +} + +#[allow(dead_code)] +fn print_array(arr: &[i32]) { + print!("Sorted Array: "); + for &element in arr { + print!("{element} "); + } + println!(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_bingo_sort() { + let mut arr = vec![5, 4, 8, 5, 4, 8, 5, 4, 4, 4]; + bingo_sort(&mut arr); + assert_eq!(arr, vec![4, 4, 4, 4, 4, 5, 5, 5, 8, 8]); + + let mut arr2 = vec![10, 9, 8, 7, 6, 5, 4, 3, 2, 1]; + bingo_sort(&mut arr2); + assert_eq!(arr2, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + let mut arr3 = vec![0, 1, 0, 1, 0, 1]; + bingo_sort(&mut arr3); + assert_eq!(arr3, vec![0, 0, 0, 1, 1, 1]); + } + + #[test] + fn test_empty_array() { + let mut arr = Vec::new(); + bingo_sort(&mut arr); + assert_eq!(arr, Vec::new()); + } + + #[test] + fn test_single_element_array() { + let mut arr = vec![42]; + bingo_sort(&mut arr); + assert_eq!(arr, vec![42]); + } + + #[test] + fn test_negative_numbers() { + let mut arr = vec![-5, -4, -3, -2, -1]; + bingo_sort(&mut arr); + assert_eq!(arr, vec![-5, -4, -3, -2, -1]); + } + + #[test] + fn test_already_sorted() { + let mut arr = vec![1, 2, 3, 4, 5]; + bingo_sort(&mut arr); + assert_eq!(arr, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_reverse_sorted() { + let mut arr = vec![5, 4, 3, 2, 1]; + bingo_sort(&mut arr); + assert_eq!(arr, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_duplicates() { + let mut arr = vec![1, 2, 3, 4, 5, 1, 2, 3, 4, 5]; + bingo_sort(&mut arr); + assert_eq!(arr, vec![1, 1, 2, 2, 3, 3, 4, 4, 5, 5]); + } +} diff --git a/src/sorting/bitonic_sort.rs b/src/sorting/bitonic_sort.rs new file mode 100644 index 00000000000..8b1dca0c6c0 --- /dev/null +++ b/src/sorting/bitonic_sort.rs @@ -0,0 +1,52 @@ +fn _comp_and_swap(array: &mut [T], left: usize, right: usize, ascending: bool) { + if (ascending && array[left] > array[right]) || (!ascending && array[left] < array[right]) { + array.swap(left, right); + } +} + +fn _bitonic_merge(array: &mut [T], low: usize, length: usize, ascending: bool) { + if length > 1 { + let middle = length / 2; + for i in low..(low + middle) { + _comp_and_swap(array, i, i + middle, ascending); + } + _bitonic_merge(array, low, middle, ascending); + _bitonic_merge(array, low + middle, middle, ascending); + } +} + +pub fn bitonic_sort(array: &mut [T], low: usize, length: usize, ascending: bool) { + if length > 1 { + let middle = length / 2; + bitonic_sort(array, low, middle, true); + bitonic_sort(array, low + middle, middle, false); + _bitonic_merge(array, low, length, ascending); + } +} + +//Note that this program works only when size of input is a power of 2. +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_descending_sorted; + use crate::sorting::is_sorted; + + #[test] + fn descending() { + //descending + let mut ve1 = vec![6, 5, 4, 3]; + let cloned = ve1.clone(); + bitonic_sort(&mut ve1, 0, 4, true); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); + } + + #[test] + fn ascending() { + //pre-sorted + let mut ve2 = vec![1, 2, 3, 4]; + let cloned = ve2.clone(); + bitonic_sort(&mut ve2, 0, 4, false); + assert!(is_descending_sorted(&ve2) && have_same_elements(&ve2, &cloned)); + } +} diff --git a/src/sorting/bubble_sort.rs b/src/sorting/bubble_sort.rs index 6c881c92144..0df7cec07a1 100644 --- a/src/sorting/bubble_sort.rs +++ b/src/sorting/bubble_sort.rs @@ -18,28 +18,32 @@ pub fn bubble_sort(arr: &mut [T]) { #[cfg(test)] mod tests { - use super::super::is_sorted; use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn descending() { //descending let mut ve1 = vec![6, 5, 4, 3, 2, 1]; + let cloned = ve1.clone(); bubble_sort(&mut ve1); - assert!(is_sorted(&ve1)); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn ascending() { //pre-sorted let mut ve2 = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); bubble_sort(&mut ve2); - assert!(is_sorted(&ve2)); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); } #[test] fn empty() { let mut ve3: Vec = vec![]; + let cloned = ve3.clone(); bubble_sort(&mut ve3); - assert!(is_sorted(&ve3)); + assert!(is_sorted(&ve3) && have_same_elements(&ve3, &cloned)); } } diff --git a/src/sorting/bucket_sort.rs b/src/sorting/bucket_sort.rs index fae0a0d0174..05fb30272ad 100644 --- a/src/sorting/bucket_sort.rs +++ b/src/sorting/bucket_sort.rs @@ -33,48 +33,55 @@ pub fn bucket_sort(arr: &[usize]) -> Vec { #[cfg(test)] mod tests { - use super::super::is_sorted; use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn empty() { let arr: [usize; 0] = []; + let cloned = arr; let res = bucket_sort(&arr); - assert!(is_sorted(&res)); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn one_element() { let arr: [usize; 1] = [4]; + let cloned = arr; let res = bucket_sort(&arr); - assert!(is_sorted(&res)); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn already_sorted() { - let arr: [usize; 3] = [10, 9, 105]; + let arr: [usize; 3] = [10, 19, 105]; + let cloned = arr; let res = bucket_sort(&arr); - assert!(is_sorted(&res)); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn basic() { let arr: [usize; 4] = [35, 53, 1, 0]; + let cloned = arr; let res = bucket_sort(&arr); - assert!(is_sorted(&res)); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn odd_number_of_elements() { - let arr: Vec = vec![1, 21, 5, 11, 58]; + let arr: [usize; 5] = [1, 21, 5, 11, 58]; + let cloned = arr; let res = bucket_sort(&arr); - assert!(is_sorted(&res)); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn repeated_elements() { - let arr: Vec = vec![542, 542, 542, 542]; + let arr: [usize; 4] = [542, 542, 542, 542]; + let cloned = arr; let res = bucket_sort(&arr); - assert!(is_sorted(&res)); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } } diff --git a/src/sorting/cocktail_shaker_sort.rs b/src/sorting/cocktail_shaker_sort.rs index dc3af99f617..dd65fc0fa99 100644 --- a/src/sorting/cocktail_shaker_sort.rs +++ b/src/sorting/cocktail_shaker_sort.rs @@ -37,32 +37,38 @@ pub fn cocktail_shaker_sort(arr: &mut [T]) { #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut arr = vec![5, 2, 1, 3, 4, 6]; + let cloned = arr.clone(); cocktail_shaker_sort(&mut arr); - assert_eq!(arr, vec![1, 2, 3, 4, 5, 6]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn empty() { let mut arr = Vec::::new(); + let cloned = arr.clone(); cocktail_shaker_sort(&mut arr); - assert_eq!(arr, vec![]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn one_element() { let mut arr = vec![1]; + let cloned = arr.clone(); cocktail_shaker_sort(&mut arr); - assert_eq!(arr, vec![1]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn pre_sorted() { let mut arr = vec![1, 2, 3, 4, 5, 6]; + let cloned = arr.clone(); cocktail_shaker_sort(&mut arr); - assert_eq!(arr, vec![1, 2, 3, 4, 5, 6]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } } diff --git a/src/sorting/comb_sort.rs b/src/sorting/comb_sort.rs index a6cd89fc827..d84522ce2ee 100644 --- a/src/sorting/comb_sort.rs +++ b/src/sorting/comb_sort.rs @@ -22,24 +22,33 @@ pub fn comb_sort(arr: &mut [T]) { #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn descending() { //descending let mut ve1 = vec![6, 5, 4, 3, 2, 1]; + let cloned = ve1.clone(); comb_sort(&mut ve1); - for i in 0..ve1.len() - 1 { - assert!(ve1[i] <= ve1[i + 1]); - } + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn ascending() { //pre-sorted let mut ve2 = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); comb_sort(&mut ve2); - for i in 0..ve2.len() - 1 { - assert!(ve2[i] <= ve2[i + 1]); - } + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); + } + + #[test] + fn duplicates() { + //pre-sorted + let mut ve3 = vec![2, 2, 2, 2, 2, 1]; + let cloned = ve3.clone(); + comb_sort(&mut ve3); + assert!(is_sorted(&ve3) && have_same_elements(&ve3, &cloned)); } } diff --git a/src/sorting/counting_sort.rs b/src/sorting/counting_sort.rs index 6488ed5a5ae..e1c22373f7c 100644 --- a/src/sorting/counting_sort.rs +++ b/src/sorting/counting_sort.rs @@ -51,38 +51,43 @@ pub fn generic_counting_sort + From + AddAssign + Copy>( #[cfg(test)] mod test { - use super::super::is_sorted; use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn counting_sort_descending() { let mut ve1 = vec![6, 5, 4, 3, 2, 1]; + let cloned = ve1.clone(); counting_sort(&mut ve1, 6); - assert!(is_sorted(&ve1)); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn counting_sort_pre_sorted() { let mut ve2 = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); counting_sort(&mut ve2, 6); - assert!(is_sorted(&ve2)); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); } #[test] fn generic_counting_sort() { let mut ve1: Vec = vec![100, 30, 60, 10, 20, 120, 1]; + let cloned = ve1.clone(); super::generic_counting_sort(&mut ve1, 120); - assert!(is_sorted(&ve1)); + assert!(is_sorted(&ve1) && have_same_elements(&ve1, &cloned)); } #[test] fn presorted_u64_counting_sort() { let mut ve2: Vec = vec![1, 2, 3, 4, 5, 6]; + let cloned = ve2.clone(); super::generic_counting_sort(&mut ve2, 6); - assert!(is_sorted(&ve2)); + assert!(is_sorted(&ve2) && have_same_elements(&ve2, &cloned)); } } diff --git a/src/sorting/cycle_sort.rs b/src/sorting/cycle_sort.rs index 99f4b260e0f..44c0947613c 100644 --- a/src/sorting/cycle_sort.rs +++ b/src/sorting/cycle_sort.rs @@ -33,21 +33,28 @@ pub fn cycle_sort(arr: &mut [i32]) { #[cfg(test)] mod tests { - use super::super::is_sorted; + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + #[test] fn it_works() { let mut arr1 = [6, 5, 4, 3, 2, 1]; + let cloned = arr1; cycle_sort(&mut arr1); - assert!(is_sorted(&arr1)); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); arr1 = [12, 343, 21, 90, 3, 21]; + let cloned = arr1; cycle_sort(&mut arr1); - assert!(is_sorted(&arr1)); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); let mut arr2 = [1]; + let cloned = arr2; cycle_sort(&mut arr2); - assert!(is_sorted(&arr2)); + assert!(is_sorted(&arr2) && have_same_elements(&arr2, &cloned)); let mut arr3 = [213, 542, 90, -23412, -32, 324, -34, 3324, 54]; + let cloned = arr3; cycle_sort(&mut arr3); - assert!(is_sorted(&arr3)); + assert!(is_sorted(&arr3) && have_same_elements(&arr3, &cloned)); } } diff --git a/src/sorting/dutch_national_flag_sort.rs b/src/sorting/dutch_national_flag_sort.rs new file mode 100644 index 00000000000..7d24d6d0321 --- /dev/null +++ b/src/sorting/dutch_national_flag_sort.rs @@ -0,0 +1,67 @@ +/* +A Rust implementation of the Dutch National Flag sorting algorithm. + +Reference implementation: https://github.com/TheAlgorithms/Python/blob/master/sorts/dutch_national_flag_sort.py +More info: https://en.wikipedia.org/wiki/Dutch_national_flag_problem +*/ + +#[derive(PartialOrd, PartialEq, Eq)] +pub enum Colors { + Red, // \ + White, // | Define the three colors of the Dutch Flag: 🇳🇱 + Blue, // / +} +use Colors::{Blue, Red, White}; + +// Algorithm implementation +pub fn dutch_national_flag_sort(mut sequence: Vec) -> Vec { + // We take ownership of `sequence` because the original `sequence` will be modified and then returned + let length = sequence.len(); + if length <= 1 { + return sequence; // Arrays of length 0 or 1 are automatically sorted + } + let mut low = 0; + let mut mid = 0; + let mut high = length - 1; + while mid <= high { + match sequence[mid] { + Red => { + sequence.swap(low, mid); + low += 1; + mid += 1; + } + White => { + mid += 1; + } + Blue => { + sequence.swap(mid, high); + high -= 1; + } + } + } + sequence +} + +#[cfg(test)] +mod tests { + use super::super::is_sorted; + use super::*; + + #[test] + fn random_array() { + let arr = vec![ + Red, Blue, White, White, Blue, Blue, Red, Red, White, Blue, White, Red, White, Blue, + ]; + let arr = dutch_national_flag_sort(arr); + assert!(is_sorted(&arr)) + } + + #[test] + fn sorted_array() { + let arr = vec![ + Red, Red, Red, Red, Red, White, White, White, White, White, Blue, Blue, Blue, Blue, + ]; + let arr = dutch_national_flag_sort(arr); + assert!(is_sorted(&arr)) + } +} diff --git a/src/sorting/exchange_sort.rs b/src/sorting/exchange_sort.rs index 762df304faa..bbd0348b9a5 100644 --- a/src/sorting/exchange_sort.rs +++ b/src/sorting/exchange_sort.rs @@ -13,21 +13,26 @@ pub fn exchange_sort(arr: &mut [i32]) { #[cfg(test)] mod tests { - use super::super::is_sorted; use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn it_works() { let mut arr1 = [6, 5, 4, 3, 2, 1]; + let cloned = arr1; exchange_sort(&mut arr1); - assert!(is_sorted(&arr1)); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); arr1 = [12, 343, 21, 90, 3, 21]; + let cloned = arr1; exchange_sort(&mut arr1); - assert!(is_sorted(&arr1)); + assert!(is_sorted(&arr1) && have_same_elements(&arr1, &cloned)); let mut arr2 = [1]; + let cloned = arr2; exchange_sort(&mut arr2); - assert!(is_sorted(&arr2)); + assert!(is_sorted(&arr2) && have_same_elements(&arr2, &cloned)); let mut arr3 = [213, 542, 90, -23412, -32, 324, -34, 3324, 54]; + let cloned = arr3; exchange_sort(&mut arr3); - assert!(is_sorted(&arr3)); + assert!(is_sorted(&arr3) && have_same_elements(&arr3, &cloned)); } } diff --git a/src/sorting/gnome_sort.rs b/src/sorting/gnome_sort.rs index bf73e635dce..cc43ad7963b 100644 --- a/src/sorting/gnome_sort.rs +++ b/src/sorting/gnome_sort.rs @@ -27,34 +27,41 @@ where #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { - let res = gnome_sort(&vec![6, 5, -8, 3, 2, 3]); - assert_eq!(res, vec![-8, 2, 3, 3, 5, 6]); + let original = [6, 5, -8, 3, 2, 3]; + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); } #[test] fn already_sorted() { - let res = gnome_sort(&vec!["a", "b", "c"]); - assert_eq!(res, vec!["a", "b", "c"]); + let original = gnome_sort(&["a", "b", "c"]); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); } #[test] fn odd_number_of_elements() { - let res = gnome_sort(&vec!["d", "a", "c", "e", "b"]); - assert_eq!(res, vec!["a", "b", "c", "d", "e"]); + let original = gnome_sort(&["d", "a", "c", "e", "b"]); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); } #[test] fn one_element() { - let res = gnome_sort(&vec![3]); - assert_eq!(res, vec![3]); + let original = gnome_sort(&[3]); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); } #[test] fn empty() { - let res = gnome_sort(&Vec::::new()); - assert_eq!(res, vec![]); + let original = gnome_sort(&Vec::::new()); + let res = gnome_sort(&original); + assert!(is_sorted(&res) && have_same_elements(&res, &original)); } } diff --git a/src/sorting/heap_sort.rs b/src/sorting/heap_sort.rs index d860dad0465..7b37a7c5149 100644 --- a/src/sorting/heap_sort.rs +++ b/src/sorting/heap_sort.rs @@ -1,139 +1,114 @@ -/// Sort a mutable slice using heap sort. -/// -/// Heap sort is an in-place O(n log n) sorting algorithm. It is based on a -/// max heap, a binary tree data structure whose main feature is that -/// parent nodes are always greater or equal to their child nodes. -/// -/// # Max Heap Implementation +//! This module provides functions for heap sort algorithm. + +use std::cmp::Ordering; + +/// Builds a heap from the provided array. /// -/// A max heap can be efficiently implemented with an array. -/// For example, the binary tree: -/// ```text -/// 1 -/// 2 3 -/// 4 5 6 7 -/// ``` +/// This function builds either a max heap or a min heap based on the `is_max_heap` parameter. /// -/// ... is represented by the following array: -/// ```text -/// 1 23 4567 -/// ``` +/// # Arguments /// -/// Given the index `i` of a node, parent and child indices can be calculated -/// as follows: -/// ```text -/// parent(i) = (i-1) / 2 -/// left_child(i) = 2*i + 1 -/// right_child(i) = 2*i + 2 -/// ``` +/// * `arr` - A mutable reference to the array to be sorted. +/// * `is_max_heap` - A boolean indicating whether to build a max heap (`true`) or a min heap (`false`). +fn build_heap(arr: &mut [T], is_max_heap: bool) { + let mut i = (arr.len() - 1) / 2; + while i > 0 { + heapify(arr, i, is_max_heap); + i -= 1; + } + heapify(arr, 0, is_max_heap); +} -/// # Algorithm +/// Fixes a heap violation starting at the given index. /// -/// Heap sort has two steps: -/// 1. Convert the input array to a max heap. -/// 2. Partition the array into heap part and sorted part. Initially the -/// heap consists of the whole array and the sorted part is empty: -/// ```text -/// arr: [ heap |] -/// ``` +/// This function adjusts the heap rooted at index `i` to fix the heap property violation. +/// It assumes that the subtrees rooted at left and right children of `i` are already heaps. /// -/// Repeatedly swap the root (i.e. the largest) element of the heap with -/// the last element of the heap and increase the sorted part by one: -/// ```text -/// arr: [ root ... last | sorted ] -/// --> [ last ... | root sorted ] -/// ``` +/// # Arguments /// -/// After each swap, fix the heap to make it a valid max heap again. -/// Once the heap is empty, `arr` is completely sorted. -pub fn heap_sort(arr: &mut [T]) { - if arr.len() <= 1 { - return; // already sorted - } +/// * `arr` - A mutable reference to the array representing the heap. +/// * `i` - The index to start fixing the heap violation. +/// * `is_max_heap` - A boolean indicating whether to maintain a max heap or a min heap. +fn heapify(arr: &mut [T], i: usize, is_max_heap: bool) { + let comparator: fn(&T, &T) -> Ordering = if !is_max_heap { + |a, b| b.cmp(a) + } else { + |a, b| a.cmp(b) + }; - heapify(arr); + let mut idx = i; + let l = 2 * i + 1; + let r = 2 * i + 2; - for end in (1..arr.len()).rev() { - arr.swap(0, end); - move_down(&mut arr[..end], 0); + if l < arr.len() && comparator(&arr[l], &arr[idx]) == Ordering::Greater { + idx = l; } -} -/// Convert `arr` into a max heap. -fn heapify(arr: &mut [T]) { - let last_parent = (arr.len() - 2) / 2; - for i in (0..=last_parent).rev() { - move_down(arr, i); + if r < arr.len() && comparator(&arr[r], &arr[idx]) == Ordering::Greater { + idx = r; + } + + if idx != i { + arr.swap(i, idx); + heapify(arr, idx, is_max_heap); } } -/// Move the element at `root` down until `arr` is a max heap again. +/// Sorts the given array using heap sort algorithm. /// -/// This assumes that the subtrees under `root` are valid max heaps already. -fn move_down(arr: &mut [T], mut root: usize) { - let last = arr.len() - 1; - loop { - let left = 2 * root + 1; - if left > last { - break; - } - let right = left + 1; - let max = if right <= last && arr[right] > arr[left] { - right - } else { - left - }; +/// This function sorts the array either in ascending or descending order based on the `ascending` parameter. +/// +/// # Arguments +/// +/// * `arr` - A mutable reference to the array to be sorted. +/// * `ascending` - A boolean indicating whether to sort in ascending order (`true`) or descending order (`false`). +pub fn heap_sort(arr: &mut [T], ascending: bool) { + if arr.len() <= 1 { + return; + } - if arr[max] > arr[root] { - arr.swap(root, max); - } - root = max; + // Build heap based on the order + build_heap(arr, ascending); + + let mut end = arr.len() - 1; + while end > 0 { + arr.swap(0, end); + heapify(&mut arr[..end], 0, ascending); + end -= 1; } } #[cfg(test)] mod tests { - use super::*; - - #[test] - fn empty() { - let mut arr: Vec = Vec::new(); - heap_sort(&mut arr); - assert_eq!(&arr, &[]); - } - - #[test] - fn single_element() { - let mut arr = vec![1]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1]); - } + use crate::sorting::{have_same_elements, heap_sort, is_descending_sorted, is_sorted}; - #[test] - fn sorted_array() { - let mut arr = vec![1, 2, 3, 4]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1, 2, 3, 4]); - } + macro_rules! test_heap_sort { + ($($name:ident: $input:expr,)*) => { + $( + #[test] + fn $name() { + let input_array = $input; + let mut arr_asc = input_array.clone(); + heap_sort(&mut arr_asc, true); + assert!(is_sorted(&arr_asc) && have_same_elements(&arr_asc, &input_array)); - #[test] - fn unsorted_array() { - let mut arr = vec![3, 4, 2, 1]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1, 2, 3, 4]); - } - - #[test] - fn odd_number_of_elements() { - let mut arr = vec![3, 4, 2, 1, 7]; - heap_sort(&mut arr); - assert_eq!(&arr, &[1, 2, 3, 4, 7]); + let mut arr_dsc = input_array.clone(); + heap_sort(&mut arr_dsc, false); + assert!(is_descending_sorted(&arr_dsc) && have_same_elements(&arr_dsc, &input_array)); + } + )* + } } - #[test] - fn repeated_elements() { - let mut arr = vec![542, 542, 542, 542]; - heap_sort(&mut arr); - assert_eq!(&arr, &vec![542, 542, 542, 542]); + test_heap_sort! { + empty_array: Vec::::new(), + single_element_array: vec![5], + sorted: vec![1, 2, 3, 4, 5], + sorted_desc: vec![5, 4, 3, 2, 1, 0], + basic_0: vec![9, 8, 7, 6, 5], + basic_1: vec![8, 3, 1, 5, 7], + basic_2: vec![4, 5, 7, 1, 2, 3, 2, 8, 5, 4, 9, 9, 100, 1, 2, 3, 6, 4, 3], + duplicated_elements: vec![5, 5, 5, 5, 5], + strings: vec!["aa", "a", "ba", "ab"], } } diff --git a/src/sorting/insertion_sort.rs b/src/sorting/insertion_sort.rs index b5550c75110..ab33241b045 100644 --- a/src/sorting/insertion_sort.rs +++ b/src/sorting/insertion_sort.rs @@ -1,79 +1,72 @@ -use std::cmp; - /// Sorts a mutable slice using in-place insertion sort algorithm. /// /// Time complexity is `O(n^2)`, where `n` is the number of elements. /// Space complexity is `O(1)` as it sorts elements in-place. -pub fn insertion_sort(arr: &mut [T]) -where - T: cmp::PartialOrd + Copy, -{ +pub fn insertion_sort(arr: &mut [T]) { for i in 1..arr.len() { + let mut j = i; let cur = arr[i]; - let mut j = i - 1; - while arr[j] > cur { - arr[j + 1] = arr[j]; - if j == 0 { - break; - } + while j > 0 && cur < arr[j - 1] { + arr[j] = arr[j - 1]; j -= 1; } - // we exit the loop from that break statement - if j == 0 && arr[0] > cur { - arr[0] = cur; - } else { - // `arr[j] > cur` is not satsified, exit from condition judgement - arr[j + 1] = cur; - } + arr[j] = cur; } } #[cfg(test)] mod tests { - use super::super::is_sorted; use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn empty() { let mut arr: [u8; 0] = []; + let cloned = arr; insertion_sort(&mut arr); - assert!(is_sorted(&arr)); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn one_element() { let mut arr: [char; 1] = ['a']; + let cloned = arr; insertion_sort(&mut arr); - assert!(is_sorted(&arr)); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn already_sorted() { let mut arr: [&str; 3] = ["a", "b", "c"]; + let cloned = arr; insertion_sort(&mut arr); - assert!(is_sorted(&arr)); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn basic() { let mut arr: [&str; 4] = ["d", "a", "c", "b"]; + let cloned = arr; insertion_sort(&mut arr); - assert!(is_sorted(&arr)); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn odd_number_of_elements() { let mut arr: Vec<&str> = vec!["d", "a", "c", "e", "b"]; + let cloned = arr.clone(); insertion_sort(&mut arr); - assert!(is_sorted(&arr)); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn repeated_elements() { let mut arr: Vec = vec![542, 542, 542, 542]; + let cloned = arr.clone(); insertion_sort(&mut arr); - assert!(is_sorted(&arr)); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } } diff --git a/src/sorting/intro_sort.rs b/src/sorting/intro_sort.rs new file mode 100755 index 00000000000..d6b31c17a02 --- /dev/null +++ b/src/sorting/intro_sort.rs @@ -0,0 +1,107 @@ +// Intro Sort (Also known as Introspective Sort) +// Introspective Sort is hybrid sort (Quick Sort + Heap Sort + Insertion Sort) +// https://en.wikipedia.org/wiki/Introsort +fn insertion_sort(arr: &mut [T]) { + for i in 1..arr.len() { + let mut j = i; + while j > 0 && arr[j] < arr[j - 1] { + arr.swap(j, j - 1); + j -= 1; + } + } +} + +fn heapify(arr: &mut [T], n: usize, i: usize) { + let mut largest = i; + let left = 2 * i + 1; + let right = 2 * i + 2; + + if left < n && arr[left] > arr[largest] { + largest = left; + } + + if right < n && arr[right] > arr[largest] { + largest = right; + } + + if largest != i { + arr.swap(i, largest); + heapify(arr, n, largest); + } +} + +fn heap_sort(arr: &mut [T]) { + let n = arr.len(); + + // Build a max-heap + for i in (0..n / 2).rev() { + heapify(arr, n, i); + } + + // Extract elements from the heap one by one + for i in (0..n).rev() { + arr.swap(0, i); + heapify(arr, i, 0); + } +} + +pub fn intro_sort(arr: &mut [T]) { + let len = arr.len(); + let max_depth = (2.0 * len as f64).log2() as usize + 1; + + fn intro_sort_recursive(arr: &mut [T], max_depth: usize) { + let len = arr.len(); + + if len <= 16 { + insertion_sort(arr); + } else if max_depth == 0 { + heap_sort(arr); + } else { + let pivot = partition(arr); + intro_sort_recursive(&mut arr[..pivot], max_depth - 1); + intro_sort_recursive(&mut arr[pivot + 1..], max_depth - 1); + } + } + + fn partition(arr: &mut [T]) -> usize { + let len = arr.len(); + let pivot_index = len / 2; + arr.swap(pivot_index, len - 1); + + let mut i = 0; + for j in 0..len - 1 { + if arr[j] <= arr[len - 1] { + arr.swap(i, j); + i += 1; + } + } + + arr.swap(i, len - 1); + i + } + + intro_sort_recursive(arr, max_depth); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_intro_sort() { + // Test with integers + let mut arr1 = vec![67, 34, 29, 15, 21, 9, 99]; + intro_sort(&mut arr1); + assert_eq!(arr1, vec![9, 15, 21, 29, 34, 67, 99]); + + // Test with strings + let mut arr2 = vec!["sydney", "london", "tokyo", "beijing", "mumbai"]; + intro_sort(&mut arr2); + assert_eq!(arr2, vec!["beijing", "london", "mumbai", "sydney", "tokyo"]); + + // Test with an empty array + let mut arr3: Vec = vec![]; + intro_sort(&mut arr3); + assert_eq!(arr3, vec![]); + } +} diff --git a/src/sorting/merge_sort.rs b/src/sorting/merge_sort.rs index 317c2e40b76..4c184c86110 100644 --- a/src/sorting/merge_sort.rs +++ b/src/sorting/merge_sort.rs @@ -19,61 +19,151 @@ fn merge(arr: &mut [T], mid: usize) { } } -pub fn merge_sort(arr: &mut [T]) { +pub fn top_down_merge_sort(arr: &mut [T]) { if arr.len() > 1 { let mid = arr.len() / 2; // Sort the left half recursively. - merge_sort(&mut arr[..mid]); + top_down_merge_sort(&mut arr[..mid]); // Sort the right half recursively. - merge_sort(&mut arr[mid..]); + top_down_merge_sort(&mut arr[mid..]); // Combine the two halves. merge(arr, mid); } } +pub fn bottom_up_merge_sort(a: &mut [T]) { + if a.len() > 1 { + let len: usize = a.len(); + let mut sub_array_size: usize = 1; + while sub_array_size < len { + let mut start_index: usize = 0; + // still have more than one sub-arrays to merge + while len - start_index > sub_array_size { + let end_idx: usize = if start_index + 2 * sub_array_size > len { + len + } else { + start_index + 2 * sub_array_size + }; + // merge a[start_index..start_index+sub_array_size] and a[start_index+sub_array_size..end_idx] + // NOTE: mid is a relative index number starting from `start_index` + merge(&mut a[start_index..end_idx], sub_array_size); + // update `start_index` to merge the next sub-arrays + start_index = end_idx; + } + sub_array_size *= 2; + } + } +} + #[cfg(test)] mod tests { - use super::*; + #[cfg(test)] + mod top_down_merge_sort { + use super::super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; - #[test] - fn basic() { - let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; - merge_sort(&mut res); - assert_eq!(res, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); - } + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn basic_string() { - let mut res = vec!["a", "bb", "d", "cc"]; - merge_sort(&mut res); - assert_eq!(res, vec!["a", "bb", "cc", "d"]); - } + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn empty() { - let mut res = Vec::::new(); - merge_sort(&mut res); - assert_eq!(res, vec![]); - } + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn one_element() { - let mut res = vec![1]; - merge_sort(&mut res); - assert_eq!(res, vec![1]); - } + #[test] + fn one_element() { + let mut res = vec![1]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut res = vec![1, 2, 3, 4]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } - #[test] - fn pre_sorted() { - let mut res = vec![1, 2, 3, 4]; - merge_sort(&mut res); - assert_eq!(res, vec![1, 2, 3, 4]); + #[test] + fn reverse_sorted() { + let mut res = vec![4, 3, 2, 1]; + let cloned = res.clone(); + top_down_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } } - #[test] - fn reverse_sorted() { - let mut res = vec![4, 3, 2, 1]; - merge_sort(&mut res); - assert_eq!(res, vec![1, 2, 3, 4]); + #[cfg(test)] + mod bottom_up_merge_sort { + use super::super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let mut res = vec![1]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut res = vec![1, 2, 3, 4]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = vec![4, 3, 2, 1]; + let cloned = res.clone(); + bottom_up_merge_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } } } diff --git a/src/sorting/mod.rs b/src/sorting/mod.rs index f78e6c96c2e..79be2b0b9e6 100644 --- a/src/sorting/mod.rs +++ b/src/sorting/mod.rs @@ -1,3 +1,7 @@ +mod bead_sort; +mod binary_insertion_sort; +mod bingo_sort; +mod bitonic_sort; mod bogo_sort; mod bubble_sort; mod bucket_sort; @@ -5,21 +9,35 @@ mod cocktail_shaker_sort; mod comb_sort; mod counting_sort; mod cycle_sort; +mod dutch_national_flag_sort; mod exchange_sort; mod gnome_sort; mod heap_sort; mod insertion_sort; +mod intro_sort; mod merge_sort; mod odd_even_sort; mod pancake_sort; +mod patience_sort; mod pigeonhole_sort; mod quick_sort; +mod quick_sort_3_ways; mod radix_sort; mod selection_sort; mod shell_sort; +mod sleep_sort; +#[cfg(test)] +mod sort_utils; mod stooge_sort; mod tim_sort; +mod tree_sort; +mod wave_sort; +mod wiggle_sort; +pub use self::bead_sort::bead_sort; +pub use self::binary_insertion_sort::binary_insertion_sort; +pub use self::bingo_sort::bingo_sort; +pub use self::bitonic_sort::bitonic_sort; pub use self::bogo_sort::bogo_sort; pub use self::bubble_sort::bubble_sort; pub use self::bucket_sort::bucket_sort; @@ -28,42 +46,69 @@ pub use self::comb_sort::comb_sort; pub use self::counting_sort::counting_sort; pub use self::counting_sort::generic_counting_sort; pub use self::cycle_sort::cycle_sort; +pub use self::dutch_national_flag_sort::dutch_national_flag_sort; pub use self::exchange_sort::exchange_sort; pub use self::gnome_sort::gnome_sort; pub use self::heap_sort::heap_sort; pub use self::insertion_sort::insertion_sort; -pub use self::merge_sort::merge_sort; +pub use self::intro_sort::intro_sort; +pub use self::merge_sort::bottom_up_merge_sort; +pub use self::merge_sort::top_down_merge_sort; pub use self::odd_even_sort::odd_even_sort; pub use self::pancake_sort::pancake_sort; +pub use self::patience_sort::patience_sort; pub use self::pigeonhole_sort::pigeonhole_sort; pub use self::quick_sort::{partition, quick_sort}; +pub use self::quick_sort_3_ways::quick_sort_3_ways; pub use self::radix_sort::radix_sort; pub use self::selection_sort::selection_sort; pub use self::shell_sort::shell_sort; +pub use self::sleep_sort::sleep_sort; pub use self::stooge_sort::stooge_sort; pub use self::tim_sort::tim_sort; +pub use self::tree_sort::tree_sort; +pub use self::wave_sort::wave_sort; +pub use self::wiggle_sort::wiggle_sort; +#[cfg(test)] use std::cmp; -pub fn is_sorted(arr: &[T]) -> bool +#[cfg(test)] +pub fn have_same_elements(a: &[T], b: &[T]) -> bool where - T: cmp::PartialOrd, + // T: cmp::PartialOrd, + // If HashSet is used + T: cmp::PartialOrd + cmp::Eq + std::hash::Hash, { - if arr.is_empty() { - return true; - } - - let mut prev = &arr[0]; + use std::collections::HashSet; - for item in arr.iter().skip(1) { - if prev > item { - return false; - } + if a.len() == b.len() { + // This is O(n^2) but performs better on smaller data sizes + //b.iter().all(|item| a.contains(item)) - prev = item; + // This is O(n), performs well on larger data sizes + let set_a: HashSet<&T> = a.iter().collect(); + let set_b: HashSet<&T> = b.iter().collect(); + set_a == set_b + } else { + false } +} - true +#[cfg(test)] +pub fn is_sorted(arr: &[T]) -> bool +where + T: cmp::PartialOrd, +{ + arr.windows(2).all(|w| w[0] <= w[1]) +} + +#[cfg(test)] +pub fn is_descending_sorted(arr: &[T]) -> bool +where + T: cmp::PartialOrd, +{ + arr.windows(2).all(|w| w[0] >= w[1]) } #[cfg(test)] @@ -77,7 +122,7 @@ mod tests { assert!(is_sorted(&[1, 2, 3])); assert!(is_sorted(&[0, 1, 1])); - assert_eq!(is_sorted(&[1, 0]), false); - assert_eq!(is_sorted(&[2, 3, 1, -1, 5]), false); + assert!(!is_sorted(&[1, 0])); + assert!(!is_sorted(&[2, 3, 1, -1, 5])); } } diff --git a/src/sorting/odd_even_sort.rs b/src/sorting/odd_even_sort.rs index 8a7db51c31b..c22a1c4daa1 100644 --- a/src/sorting/odd_even_sort.rs +++ b/src/sorting/odd_even_sort.rs @@ -27,32 +27,38 @@ pub fn odd_even_sort(arr: &mut [T]) { #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut arr = vec![3, 5, 1, 2, 4, 6]; + let cloned = arr.clone(); odd_even_sort(&mut arr); - assert_eq!(arr, vec![1, 2, 3, 4, 5, 6]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn empty() { let mut arr = Vec::::new(); + let cloned = arr.clone(); odd_even_sort(&mut arr); - assert_eq!(arr, vec![]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn one_element() { let mut arr = vec![3]; + let cloned = arr.clone(); odd_even_sort(&mut arr); - assert_eq!(arr, vec![3]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } #[test] fn pre_sorted() { let mut arr = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let cloned = arr.clone(); odd_even_sort(&mut arr); - assert_eq!(arr, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + assert!(is_sorted(&arr) && have_same_elements(&arr, &cloned)); } } diff --git a/src/sorting/pancake_sort.rs b/src/sorting/pancake_sort.rs index 223e127b890..c37b646ca1a 100644 --- a/src/sorting/pancake_sort.rs +++ b/src/sorting/pancake_sort.rs @@ -17,8 +17,8 @@ where .map(|(idx, _)| idx) .unwrap(); if max_index != i { - arr[0..max_index + 1].reverse(); - arr[0..i + 1].reverse(); + arr[0..=max_index].reverse(); + arr[0..=i].reverse(); } } arr.to_vec() @@ -30,31 +30,31 @@ mod tests { #[test] fn basic() { - let res = pancake_sort(&mut vec![6, 5, -8, 3, 2, 3]); + let res = pancake_sort(&mut [6, 5, -8, 3, 2, 3]); assert_eq!(res, vec![-8, 2, 3, 3, 5, 6]); } #[test] fn already_sorted() { - let res = pancake_sort(&mut vec!["a", "b", "c"]); + let res = pancake_sort(&mut ["a", "b", "c"]); assert_eq!(res, vec!["a", "b", "c"]); } #[test] fn odd_number_of_elements() { - let res = pancake_sort(&mut vec!["d", "a", "c", "e", "b"]); + let res = pancake_sort(&mut ["d", "a", "c", "e", "b"]); assert_eq!(res, vec!["a", "b", "c", "d", "e"]); } #[test] fn one_element() { - let res = pancake_sort(&mut vec![3]); + let res = pancake_sort(&mut [3]); assert_eq!(res, vec![3]); } #[test] fn empty() { - let res = pancake_sort(&mut Vec::::new()); + let res = pancake_sort(&mut [] as &mut [u8]); assert_eq!(res, vec![]); } } diff --git a/src/sorting/patience_sort.rs b/src/sorting/patience_sort.rs new file mode 100644 index 00000000000..662d8ceefcb --- /dev/null +++ b/src/sorting/patience_sort.rs @@ -0,0 +1,86 @@ +use std::vec; + +pub fn patience_sort(arr: &mut [T]) { + if arr.is_empty() { + return; + } + + // collect piles from arr + let mut piles: Vec> = Vec::new(); + for &card in arr.iter() { + let mut left = 0usize; + let mut right = piles.len(); + + while left < right { + let mid = left + (right - left) / 2; + if piles[mid][piles[mid].len() - 1] >= card { + right = mid; + } else { + left = mid + 1; + } + } + + if left == piles.len() { + piles.push(vec![card]); + } else { + piles[left].push(card); + } + } + + // merge the piles + let mut idx = 0usize; + while let Some((min_id, pile)) = piles + .iter() + .enumerate() + .min_by_key(|(_, pile)| *pile.last().unwrap()) + { + arr[idx] = *pile.last().unwrap(); + idx += 1; + piles[min_id].pop(); + + if piles[min_id].is_empty() { + _ = piles.remove(min_id); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + + #[test] + fn basic() { + let mut array = vec![ + -2, 7, 15, -14, 0, 15, 0, 10_033, 7, -7, -4, -13, 5, 8, -14, 12, + ]; + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + + #[test] + fn empty() { + let mut array = Vec::::new(); + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + + #[test] + fn one_element() { + let mut array = vec![3]; + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut array = vec![-123_456, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let cloned = array.clone(); + patience_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } +} diff --git a/src/sorting/quick_sort.rs b/src/sorting/quick_sort.rs index ff0c087f5ed..102edc6337d 100644 --- a/src/sorting/quick_sort.rs +++ b/src/sorting/quick_sort.rs @@ -1,36 +1,137 @@ -use std::cmp::PartialOrd; - -pub fn partition(arr: &mut [T], lo: isize, hi: isize) -> isize { - let pivot = hi as usize; - let mut i = lo - 1; - let mut j = hi; +pub fn partition(arr: &mut [T], lo: usize, hi: usize) -> usize { + let pivot = hi; + let mut i = lo; + let mut j = hi - 1; loop { - i += 1; - while arr[i as usize] < arr[pivot] { + while arr[i] < arr[pivot] { i += 1; } - j -= 1; - while j >= 0 && arr[j as usize] > arr[pivot] { + while j > 0 && arr[j] > arr[pivot] { j -= 1; } - if i >= j { + if j == 0 || i >= j { break; + } else if arr[i] == arr[j] { + i += 1; + j -= 1; } else { - arr.swap(i as usize, j as usize); + arr.swap(i, j); } } - arr.swap(i as usize, pivot as usize); + arr.swap(i, pivot); i } -fn _quick_sort(arr: &mut [T], lo: isize, hi: isize) { - if lo < hi { - let p = partition(arr, lo, hi); - _quick_sort(arr, lo, p - 1); - _quick_sort(arr, p + 1, hi); + +fn _quick_sort(arr: &mut [T], mut lo: usize, mut hi: usize) { + while lo < hi { + let pivot = partition(arr, lo, hi); + + if pivot - lo < hi - pivot { + if pivot > 0 { + _quick_sort(arr, lo, pivot - 1); + } + lo = pivot + 1; + } else { + _quick_sort(arr, pivot + 1, hi); + hi = pivot - 1; + } } } + pub fn quick_sort(arr: &mut [T]) { let len = arr.len(); - _quick_sort(arr, 0, (len - 1) as isize); + if len > 1 { + _quick_sort(arr, 0, len - 1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + use crate::sorting::sort_utils; + + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let mut res = vec![1]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut res = vec![1, 2, 3, 4]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = vec![4, 3, 2, 1]; + let cloned = res.clone(); + quick_sort(&mut res); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn large_elements() { + let mut res = sort_utils::generate_random_vec(300000, 0, 1000000); + let cloned = res.clone(); + sort_utils::log_timed("large elements test", || { + quick_sort(&mut res); + }); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn nearly_ordered_elements() { + let mut res = sort_utils::generate_nearly_ordered_vec(3000, 10); + let cloned = res.clone(); + + sort_utils::log_timed("nearly ordered elements test", || { + quick_sort(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn repeated_elements() { + let mut res = sort_utils::generate_repeated_elements_vec(1000000, 3); + let cloned = res.clone(); + + sort_utils::log_timed("repeated elements test", || { + quick_sort(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } } diff --git a/src/sorting/quick_sort_3_ways.rs b/src/sorting/quick_sort_3_ways.rs new file mode 100644 index 00000000000..cf333114170 --- /dev/null +++ b/src/sorting/quick_sort_3_ways.rs @@ -0,0 +1,156 @@ +use std::cmp::{Ord, Ordering}; + +use rand::Rng; + +fn _quick_sort_3_ways(arr: &mut [T], lo: usize, hi: usize) { + if lo >= hi { + return; + } + + let mut rng = rand::rng(); + arr.swap(lo, rng.random_range(lo..=hi)); + + let mut lt = lo; // arr[lo+1, lt] < v + let mut gt = hi + 1; // arr[gt, r] > v + let mut i = lo + 1; // arr[lt + 1, i) == v + + while i < gt { + match arr[i].cmp(&arr[lo]) { + Ordering::Less => { + arr.swap(i, lt + 1); + i += 1; + lt += 1; + } + Ordering::Greater => { + arr.swap(i, gt - 1); + gt -= 1; + } + Ordering::Equal => { + i += 1; + } + } + } + + arr.swap(lo, lt); + + if lt > 1 { + _quick_sort_3_ways(arr, lo, lt - 1); + } + + _quick_sort_3_ways(arr, gt, hi); +} + +pub fn quick_sort_3_ways(arr: &mut [T]) { + let len = arr.len(); + if len > 1 { + _quick_sort_3_ways(arr, 0, len - 1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; + use crate::sorting::sort_utils; + + #[test] + fn basic() { + let mut res = vec![10, 8, 4, 3, 1, 9, 2, 7, 5, 6]; + let cloned = res.clone(); + sort_utils::log_timed("basic", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn basic_string() { + let mut res = vec!["a", "bb", "d", "cc"]; + let cloned = res.clone(); + sort_utils::log_timed("basic string", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn empty() { + let mut res = Vec::::new(); + let cloned = res.clone(); + sort_utils::log_timed("empty", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn one_element() { + let mut res = sort_utils::generate_random_vec(1, 0, 1); + let cloned = res.clone(); + sort_utils::log_timed("one element", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn pre_sorted() { + let mut res = sort_utils::generate_nearly_ordered_vec(300000, 0); + let cloned = res.clone(); + sort_utils::log_timed("pre sorted", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn reverse_sorted() { + let mut res = sort_utils::generate_reverse_ordered_vec(300000); + let cloned = res.clone(); + sort_utils::log_timed("reverse sorted", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn large_elements() { + let mut res = sort_utils::generate_random_vec(300000, 0, 1000000); + let cloned = res.clone(); + sort_utils::log_timed("large elements test", || { + quick_sort_3_ways(&mut res); + }); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn nearly_ordered_elements() { + let mut res = sort_utils::generate_nearly_ordered_vec(300000, 10); + let cloned = res.clone(); + + sort_utils::log_timed("nearly ordered elements test", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } + + #[test] + fn repeated_elements() { + let mut res = sort_utils::generate_repeated_elements_vec(1000000, 3); + let cloned = res.clone(); + + sort_utils::log_timed("repeated elements test", || { + quick_sort_3_ways(&mut res); + }); + + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); + } +} diff --git a/src/sorting/radix_sort.rs b/src/sorting/radix_sort.rs index 19c081a591b..8f14efb2058 100644 --- a/src/sorting/radix_sort.rs +++ b/src/sorting/radix_sort.rs @@ -36,27 +36,31 @@ pub fn radix_sort(arr: &mut [u64]) { #[cfg(test)] mod tests { - use super::super::is_sorted; use super::radix_sort; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn empty() { let mut a: [u64; 0] = []; + let cloned = a; radix_sort(&mut a); - assert!(is_sorted(&a)); + assert!(is_sorted(&a) && have_same_elements(&a, &cloned)); } #[test] fn descending() { let mut v = vec![201, 127, 64, 37, 24, 4, 1]; + let cloned = v.clone(); radix_sort(&mut v); - assert!(is_sorted(&v)); + assert!(is_sorted(&v) && have_same_elements(&v, &cloned)); } #[test] fn ascending() { let mut v = vec![1, 4, 24, 37, 64, 127, 201]; + let cloned = v.clone(); radix_sort(&mut v); - assert!(is_sorted(&v)); + assert!(is_sorted(&v) && have_same_elements(&v, &cloned)); } } diff --git a/src/sorting/selection_sort.rs b/src/sorting/selection_sort.rs index 59fa6e0bf48..eebaebed3aa 100644 --- a/src/sorting/selection_sort.rs +++ b/src/sorting/selection_sort.rs @@ -14,32 +14,38 @@ pub fn selection_sort(arr: &mut [T]) { #[cfg(test)] mod tests { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut res = vec!["d", "a", "c", "b"]; + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec!["a", "b", "c", "d"]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn empty() { let mut res = Vec::::new(); + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec![]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn one_element() { let mut res = vec!["a"]; + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec!["a"]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } #[test] fn pre_sorted() { let mut res = vec!["a", "b", "c"]; + let cloned = res.clone(); selection_sort(&mut res); - assert_eq!(res, vec!["a", "b", "c"]); + assert!(is_sorted(&res) && have_same_elements(&res, &cloned)); } } diff --git a/src/sorting/shell_sort.rs b/src/sorting/shell_sort.rs index 14c64b647c0..40a61b223b5 100644 --- a/src/sorting/shell_sort.rs +++ b/src/sorting/shell_sort.rs @@ -25,38 +25,38 @@ pub fn shell_sort(values: &mut [T]) { #[cfg(test)] mod test { use super::shell_sort; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut vec = vec![3, 5, 6, 3, 1, 4]; + let cloned = vec.clone(); shell_sort(&mut vec); - for i in 0..vec.len() - 1 { - assert!(vec[i] <= vec[i + 1]); - } + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn empty() { let mut vec: Vec = vec![]; + let cloned = vec.clone(); shell_sort(&mut vec); - assert_eq!(vec, vec![]); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn reverse() { let mut vec = vec![6, 5, 4, 3, 2, 1]; + let cloned = vec.clone(); shell_sort(&mut vec); - for i in 0..vec.len() - 1 { - assert!(vec[i] <= vec[i + 1]); - } + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn already_sorted() { let mut vec = vec![1, 2, 3, 4, 5, 6]; + let cloned = vec.clone(); shell_sort(&mut vec); - for i in 0..vec.len() - 1 { - assert!(vec[i] <= vec[i + 1]); - } + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } } diff --git a/src/sorting/sleep_sort.rs b/src/sorting/sleep_sort.rs new file mode 100644 index 00000000000..184eedb2539 --- /dev/null +++ b/src/sorting/sleep_sort.rs @@ -0,0 +1,70 @@ +use std::sync::mpsc; +use std::thread; +use std::time::Duration; + +pub fn sleep_sort(vec: &[usize]) -> Vec { + let len = vec.len(); + let (tx, rx) = mpsc::channel(); + + for &x in vec.iter() { + let tx: mpsc::Sender = tx.clone(); + thread::spawn(move || { + thread::sleep(Duration::from_millis((20 * x) as u64)); + tx.send(x).expect("panic"); + }); + } + let mut sorted_list: Vec = Vec::new(); + + for _ in 0..len { + sorted_list.push(rx.recv().unwrap()) + } + + sorted_list +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty() { + let res = sleep_sort(&[]); + assert_eq!(res, &[]); + } + + #[test] + fn single_element() { + let res = sleep_sort(&[1]); + assert_eq!(res, &[1]); + } + + #[test] + fn sorted_array() { + let res = sleep_sort(&[1, 2, 3, 4]); + assert_eq!(res, &[1, 2, 3, 4]); + } + + #[test] + fn unsorted_array() { + let res = sleep_sort(&[3, 4, 2, 1]); + assert_eq!(res, &[1, 2, 3, 4]); + } + + #[test] + fn odd_number_of_elements() { + let res = sleep_sort(&[3, 1, 7]); + assert_eq!(res, &[1, 3, 7]); + } + + #[test] + fn repeated_elements() { + let res = sleep_sort(&[1, 1, 1, 1]); + assert_eq!(res, &[1, 1, 1, 1]); + } + + #[test] + fn random_elements() { + let res = sleep_sort(&[5, 3, 7, 10, 1, 0, 8]); + assert_eq!(res, &[0, 1, 3, 5, 7, 8, 10]); + } +} diff --git a/src/sorting/sort_utils.rs b/src/sorting/sort_utils.rs new file mode 100644 index 00000000000..140d10a7f33 --- /dev/null +++ b/src/sorting/sort_utils.rs @@ -0,0 +1,63 @@ +use rand::Rng; +use std::time::Instant; + +#[cfg(test)] +pub fn generate_random_vec(n: u32, range_l: i32, range_r: i32) -> Vec { + let mut arr = Vec::::with_capacity(n as usize); + let mut rng = rand::rng(); + let mut count = n; + + while count > 0 { + arr.push(rng.random_range(range_l..=range_r)); + count -= 1; + } + + arr +} + +#[cfg(test)] +pub fn generate_nearly_ordered_vec(n: u32, swap_times: u32) -> Vec { + let mut arr: Vec = (0..n as i32).collect(); + let mut rng = rand::rng(); + + let mut count = swap_times; + + while count > 0 { + arr.swap( + rng.random_range(0..n as usize), + rng.random_range(0..n as usize), + ); + count -= 1; + } + + arr +} + +#[cfg(test)] +pub fn generate_ordered_vec(n: u32) -> Vec { + generate_nearly_ordered_vec(n, 0) +} + +#[cfg(test)] +pub fn generate_reverse_ordered_vec(n: u32) -> Vec { + let mut arr = generate_ordered_vec(n); + arr.reverse(); + arr +} + +#[cfg(test)] +pub fn generate_repeated_elements_vec(n: u32, unique_elements: u8) -> Vec { + let mut rng = rand::rng(); + let v = rng.random_range(0..n as i32); + generate_random_vec(n, v, v + unique_elements as i32) +} + +#[cfg(test)] +pub fn log_timed(test_name: &str, f: F) +where + F: FnOnce(), +{ + let before = Instant::now(); + f(); + println!("Elapsed time of {:?} is {:?}", test_name, before.elapsed()); +} diff --git a/src/sorting/stooge_sort.rs b/src/sorting/stooge_sort.rs index 6397f967125..bae8b5bfdac 100644 --- a/src/sorting/stooge_sort.rs +++ b/src/sorting/stooge_sort.rs @@ -26,38 +26,38 @@ pub fn stooge_sort(arr: &mut [T]) { #[cfg(test)] mod test { use super::*; + use crate::sorting::have_same_elements; + use crate::sorting::is_sorted; #[test] fn basic() { let mut vec = vec![3, 5, 6, 3, 1, 4]; + let cloned = vec.clone(); stooge_sort(&mut vec); - for i in 0..vec.len() - 1 { - assert!(vec[i] <= vec[i + 1]); - } + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn empty() { let mut vec: Vec = vec![]; + let cloned = vec.clone(); stooge_sort(&mut vec); - assert_eq!(vec, vec![]); + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn reverse() { let mut vec = vec![6, 5, 4, 3, 2, 1]; + let cloned = vec.clone(); stooge_sort(&mut vec); - for i in 0..vec.len() - 1 { - assert!(vec[i] <= vec[i + 1]); - } + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } #[test] fn already_sorted() { let mut vec = vec![1, 2, 3, 4, 5, 6]; + let cloned = vec.clone(); stooge_sort(&mut vec); - for i in 0..vec.len() - 1 { - assert!(vec[i] <= vec[i + 1]); - } + assert!(is_sorted(&vec) && have_same_elements(&vec, &cloned)); } } diff --git a/src/sorting/tim_sort.rs b/src/sorting/tim_sort.rs index b81cefc83c9..04398a6aba9 100644 --- a/src/sorting/tim_sort.rs +++ b/src/sorting/tim_sort.rs @@ -1,80 +1,98 @@ +//! Implements Tim sort algorithm. +//! +//! Tim sort is a hybrid sorting algorithm derived from merge sort and insertion sort. +//! It is designed to perform well on many kinds of real-world data. + +use crate::sorting::insertion_sort; use std::cmp; static MIN_MERGE: usize = 32; -fn min_run_length(mut n: usize) -> usize { - let mut r = 0; - while n >= MIN_MERGE { - r |= n & 1; - n >>= 1; +/// Calculates the minimum run length for Tim sort based on the length of the array. +/// +/// The minimum run length is determined using a heuristic that ensures good performance. +/// +/// # Arguments +/// +/// * `array_length` - The length of the array. +/// +/// # Returns +/// +/// The minimum run length. +fn compute_min_run_length(array_length: usize) -> usize { + let mut remaining_length = array_length; + let mut result = 0; + + while remaining_length >= MIN_MERGE { + result |= remaining_length & 1; + remaining_length >>= 1; } - n + r -} -fn insertion_sort(arr: &mut Vec, left: usize, right: usize) -> &Vec { - for i in (left + 1)..(right + 1) { - let temp = arr[i]; - let mut j = (i - 1) as i32; - - while j >= (left as i32) && arr[j as usize] > temp { - arr[(j + 1) as usize] = arr[j as usize]; - j -= 1; - } - arr[(j + 1) as usize] = temp; - } - arr + remaining_length + result } -fn merge(arr: &mut Vec, l: usize, m: usize, r: usize) -> &Vec { - let len1 = m - l + 1; - let len2 = r - m; - let mut left = vec![0; len1 as usize]; - let mut right = vec![0; len2 as usize]; - - left[..len1].clone_from_slice(&arr[l..(len1 + l)]); - - for x in 0..len2 { - right[x] = arr[m + 1 + x]; - } - +/// Merges two sorted subarrays into a single sorted subarray. +/// +/// This function merges two sorted subarrays of the provided slice into a single sorted subarray. +/// +/// # Arguments +/// +/// * `arr` - The slice containing the subarrays to be merged. +/// * `left` - The starting index of the first subarray. +/// * `mid` - The ending index of the first subarray. +/// * `right` - The ending index of the second subarray. +fn merge(arr: &mut [T], left: usize, mid: usize, right: usize) { + let left_slice = arr[left..=mid].to_vec(); + let right_slice = arr[mid + 1..=right].to_vec(); let mut i = 0; let mut j = 0; - let mut k = l; + let mut k = left; - while i < len1 && j < len2 { - if left[i] <= right[j] { - arr[k] = left[i]; + while i < left_slice.len() && j < right_slice.len() { + if left_slice[i] <= right_slice[j] { + arr[k] = left_slice[i]; i += 1; } else { - arr[k] = right[j]; + arr[k] = right_slice[j]; j += 1; } k += 1; } - while i < len1 { - arr[k] = left[i]; + // Copy any remaining elements from the left subarray + while i < left_slice.len() { + arr[k] = left_slice[i]; k += 1; i += 1; } - while j < len2 { - arr[k] = right[j]; + // Copy any remaining elements from the right subarray + while j < right_slice.len() { + arr[k] = right_slice[j]; k += 1; j += 1; } - arr } -pub fn tim_sort(arr: &mut Vec, n: usize) { - let min_run = min_run_length(MIN_MERGE) as usize; - +/// Sorts a slice using Tim sort algorithm. +/// +/// This function sorts the provided slice in-place using the Tim sort algorithm. +/// +/// # Arguments +/// +/// * `arr` - The slice to be sorted. +pub fn tim_sort(arr: &mut [T]) { + let n = arr.len(); + let min_run = compute_min_run_length(MIN_MERGE); + + // Perform insertion sort on small subarrays let mut i = 0; while i < n { - insertion_sort(arr, i, cmp::min(i + MIN_MERGE - 1, n - 1)); + insertion_sort(&mut arr[i..cmp::min(i + MIN_MERGE, n)]); i += min_run; } + // Merge sorted subarrays let mut size = min_run; while size < n { let mut left = 0; @@ -94,38 +112,56 @@ pub fn tim_sort(arr: &mut Vec, n: usize) { #[cfg(test)] mod tests { use super::*; + use crate::sorting::{have_same_elements, is_sorted}; #[test] - fn basic() { - let mut array = vec![-2, 7, 15, -14, 0, 15, 0, 7, -7, -4, -13, 5, 8, -14, 12]; - let arr_len = array.len(); - tim_sort(&mut array, arr_len); - for i in 0..array.len() - 1 { - assert!(array[i] <= array[i + 1]); + fn min_run_length_returns_correct_value() { + assert_eq!(compute_min_run_length(0), 0); + assert_eq!(compute_min_run_length(10), 10); + assert_eq!(compute_min_run_length(33), 17); + assert_eq!(compute_min_run_length(64), 16); + } + + macro_rules! test_merge { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input_arr, l, m, r, expected) = $inputs; + let mut arr = input_arr.clone(); + merge(&mut arr, l, m, r); + assert_eq!(arr, expected); + } + )* } } - #[test] - fn empty() { - let mut array = Vec::::new(); - let arr_len = array.len(); - tim_sort(&mut array, arr_len); - assert_eq!(array, vec![]); + test_merge! { + left_and_right_subarrays_into_array: (vec![0, 2, 4, 1, 3, 5], 0, 2, 5, vec![0, 1, 2, 3, 4, 5]), + with_empty_left_subarray: (vec![1, 2, 3], 0, 0, 2, vec![1, 2, 3]), + with_empty_right_subarray: (vec![1, 2, 3], 0, 2, 2, vec![1, 2, 3]), + with_empty_left_and_right_subarrays: (vec![1, 2, 3], 1, 0, 0, vec![1, 2, 3]), } - #[test] - fn one_element() { - let mut array = vec![3]; - let arr_len = array.len(); - tim_sort(&mut array, arr_len); - assert_eq!(array, vec![3]); + macro_rules! test_tim_sort { + ($($name:ident: $input:expr,)*) => { + $( + #[test] + fn $name() { + let mut array = $input; + let cloned = array.clone(); + tim_sort(&mut array); + assert!(is_sorted(&array) && have_same_elements(&array, &cloned)); + } + )* + } } - #[test] - fn pre_sorted() { - let mut array = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; - let arr_len = array.len(); - tim_sort(&mut array, arr_len); - assert_eq!(array, vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); + test_tim_sort! { + sorts_basic_array_correctly: vec![-2, 7, 15, -14, 0, 15, 0, 7, -7, -4, -13, 5, 8, -14, 12], + sorts_long_array_correctly: vec![-2, 7, 15, -14, 0, 15, 0, 7, -7, -4, -13, 5, 8, -14, 12, 5, 3, 9, 22, 1, 1, 2, 3, 9, 6, 5, 4, 5, 6, 7, 8, 9, 1], + handles_empty_array: Vec::::new(), + handles_single_element_array: vec![3], + handles_pre_sorted_array: vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9], } } diff --git a/src/sorting/tree_sort.rs b/src/sorting/tree_sort.rs new file mode 100644 index 00000000000..a04067a5835 --- /dev/null +++ b/src/sorting/tree_sort.rs @@ -0,0 +1,114 @@ +// Author : cyrixninja +// Tree Sort Algorithm +// https://en.wikipedia.org/wiki/Tree_sort +// Wikipedia :A tree sort is a sort algorithm that builds a binary search tree from the elements to be sorted, and then traverses the tree (in-order) so that the elements come out in sorted order. +// Its typical use is sorting elements online: after each insertion, the set of elements seen so far is available in sorted order. + +struct TreeNode { + value: T, + left: Option>>, + right: Option>>, +} + +impl TreeNode { + fn new(value: T) -> Self { + TreeNode { + value, + left: None, + right: None, + } + } +} + +struct BinarySearchTree { + root: Option>>, +} + +impl BinarySearchTree { + fn new() -> Self { + BinarySearchTree { root: None } + } + + fn insert(&mut self, value: T) { + self.root = Some(Self::insert_recursive(self.root.take(), value)); + } + + fn insert_recursive(root: Option>>, value: T) -> Box> { + match root { + None => Box::new(TreeNode::new(value)), + Some(mut node) => { + if value <= node.value { + node.left = Some(Self::insert_recursive(node.left.take(), value)); + } else { + node.right = Some(Self::insert_recursive(node.right.take(), value)); + } + node + } + } + } + + fn in_order_traversal(&self, result: &mut Vec) { + Self::in_order_recursive(&self.root, result); + } + + fn in_order_recursive(root: &Option>>, result: &mut Vec) { + if let Some(node) = root { + Self::in_order_recursive(&node.left, result); + result.push(node.value.clone()); + Self::in_order_recursive(&node.right, result); + } + } +} + +pub fn tree_sort(arr: &mut Vec) { + let mut tree = BinarySearchTree::new(); + + for elem in arr.iter().cloned() { + tree.insert(elem); + } + + let mut result = Vec::new(); + tree.in_order_traversal(&mut result); + + *arr = result; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_array() { + let mut arr: Vec = vec![]; + tree_sort(&mut arr); + assert_eq!(arr, vec![]); + } + + #[test] + fn test_single_element() { + let mut arr = vec![8]; + tree_sort(&mut arr); + assert_eq!(arr, vec![8]); + } + + #[test] + fn test_already_sorted() { + let mut arr = vec![1, 2, 3, 4, 5]; + tree_sort(&mut arr); + assert_eq!(arr, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_reverse_sorted() { + let mut arr = vec![5, 4, 3, 2, 1]; + tree_sort(&mut arr); + assert_eq!(arr, vec![1, 2, 3, 4, 5]); + } + + #[test] + fn test_random() { + let mut arr = vec![9, 6, 10, 11, 2, 19]; + tree_sort(&mut arr); + assert_eq!(arr, vec![2, 6, 9, 10, 11, 19]); + } +} diff --git a/src/sorting/wave_sort.rs b/src/sorting/wave_sort.rs new file mode 100644 index 00000000000..06e2e3dec97 --- /dev/null +++ b/src/sorting/wave_sort.rs @@ -0,0 +1,70 @@ +/// Wave Sort Algorithm +/// +/// Wave Sort is a sorting algorithm that works in O(n log n) time assuming +/// the sort function used works in O(n log n) time. +/// It arranges elements in an array into a sequence where every alternate +/// element is either greater or smaller than its adjacent elements. +/// +/// Reference: +/// [Wave Sort Algorithm - GeeksforGeeks](https://www.geeksforgeeks.org/sort-array-wave-form-2/) +/// +/// # Examples +/// +/// use the_algorithms_rust::sorting::wave_sort; +/// let array = vec![10, 90, 49, 2, 1, 5, 23]; +/// let result = wave_sort(array); +/// // Result: [2, 1, 10, 5, 49, 23, 90] +/// +pub fn wave_sort(arr: &mut [T]) { + let n = arr.len(); + arr.sort(); + + for i in (0..n - 1).step_by(2) { + arr.swap(i, i + 1); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_case_1() { + let mut array = vec![10, 90, 49, 2, 1, 5, 23]; + wave_sort(&mut array); + let expected = vec![2, 1, 10, 5, 49, 23, 90]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_2() { + let mut array = vec![1, 3, 4, 2, 7, 8]; + wave_sort(&mut array); + let expected = vec![2, 1, 4, 3, 8, 7]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_3() { + let mut array = vec![3, 3, 3, 3]; + wave_sort(&mut array); + let expected = vec![3, 3, 3, 3]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_4() { + let mut array = vec![9, 4, 6, 8, 14, 3]; + wave_sort(&mut array); + let expected = vec![4, 3, 8, 6, 14, 9]; + assert_eq!(&array, &expected); + } + + #[test] + fn test_case_5() { + let mut array = vec![5, 10, 15, 20, 25]; + wave_sort(&mut array); + let expected = vec![10, 5, 20, 15, 25]; + assert_eq!(&array, &expected); + } +} diff --git a/src/sorting/wiggle_sort.rs b/src/sorting/wiggle_sort.rs new file mode 100644 index 00000000000..7f1bf1bf921 --- /dev/null +++ b/src/sorting/wiggle_sort.rs @@ -0,0 +1,80 @@ +//Wiggle Sort. +//Given an unsorted array nums, reorder it such +//that nums[0] < nums[1] > nums[2] < nums[3].... +//For example: +//if input numbers = [3, 5, 2, 1, 6, 4] +//one possible Wiggle Sorted answer is [3, 5, 1, 6, 2, 4]. + +pub fn wiggle_sort(nums: &mut Vec) -> &mut Vec { + //Rust implementation of wiggle. + // Example: + // >>> wiggle_sort([0, 5, 3, 2, 2]) + // [0, 5, 2, 3, 2] + // >>> wiggle_sort([]) + // [] + // >>> wiggle_sort([-2, -5, -45]) + // [-45, -2, -5] + + let len = nums.len(); + for i in 1..len { + let num_x = nums[i - 1]; + let num_y = nums[i]; + if (i % 2 == 1) == (num_x > num_y) { + nums[i - 1] = num_y; + nums[i] = num_x; + } + } + nums +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sorting::have_same_elements; + + fn is_wiggle_sorted(nums: &[i32]) -> bool { + if nums.is_empty() { + return true; + } + let mut previous = nums[0]; + let mut result = true; + nums.iter().enumerate().skip(1).for_each(|(i, &item)| { + if i != 0 { + result = + result && ((i % 2 == 1 && previous < item) || (i % 2 == 0 && previous > item)); + } + + previous = item; + }); + result + } + + #[test] + fn wingle_elements() { + let arr = vec![3, 5, 2, 1, 6, 4]; + let mut cloned = arr.clone(); + let res = wiggle_sort(&mut cloned); + assert!(is_wiggle_sorted(res)); + assert!(have_same_elements(res, &arr)); + } + + #[test] + fn odd_number_of_elements() { + let arr = vec![4, 1, 3, 5, 2]; + let mut cloned = arr.clone(); + let res = wiggle_sort(&mut cloned); + assert!(is_wiggle_sorted(res)); + assert!(have_same_elements(res, &arr)); + } + + #[test] + fn repeated_elements() { + let arr = vec![5, 5, 5, 5]; + let mut cloned = arr.clone(); + let res = wiggle_sort(&mut cloned); + + // Negative test, can't be wiggle sorted + assert!(!is_wiggle_sorted(res)); + assert!(have_same_elements(res, &arr)); + } +} diff --git a/src/string/README.md b/src/string/README.md index 169c167cd29..85addca949f 100644 --- a/src/string/README.md +++ b/src/string/README.md @@ -47,4 +47,9 @@ to find an exact match of a pattern string in a text. ### [Hamming Distance](./hamming_distance.rs) From [Wikipedia][hamming-distance-wiki]: In information theory, the Hamming distance between two strings of equal length is the number of positions at which the corresponding symbols are different. In other words, it measures the minimum number of substitutions required to change one string into the other, or the minimum number of errors that could have transformed one string into the other. In a more general context, the Hamming distance is one of several string metrics for measuring the edit distance between two sequences. It is named after the American mathematician Richard Hamming. +[run-length-encoding-wiki]: https://en.wikipedia.org/wiki/Run-length_encoding + +### [Run Length Encoding](./run_length_encoding.rs) +From [Wikipedia][run-length-encoding-wiki]: a form of lossless data compression in which runs of data (sequences in which the same data value occurs in many consecutive data elements) are stored as a single data value and count, rather than as the original run. + [hamming-distance-wiki]: https://en.wikipedia.org/wiki/Hamming_distance diff --git a/src/string/aho_corasick.rs b/src/string/aho_corasick.rs index e1d5759c491..02c6f7cdccc 100644 --- a/src/string/aho_corasick.rs +++ b/src/string/aho_corasick.rs @@ -51,9 +51,8 @@ impl AhoCorasick { child.lengths.extend(node.borrow().lengths.clone()); child.suffix = Rc::downgrade(node); break; - } else { - suffix = suffix.unwrap().borrow().suffix.upgrade(); } + suffix = suffix.unwrap().borrow().suffix.upgrade(); } } } @@ -64,7 +63,8 @@ impl AhoCorasick { pub fn search<'a>(&self, s: &'a str) -> Vec<&'a str> { let mut ans = vec![]; let mut cur = Rc::clone(&self.root); - for (i, c) in s.chars().enumerate() { + let mut position: usize = 0; + for c in s.chars() { loop { if let Some(child) = Rc::clone(&cur).borrow().trans.get(&c) { cur = Rc::clone(child); @@ -76,8 +76,9 @@ impl AhoCorasick { None => break, } } + position += c.len_utf8(); for &len in &cur.borrow().lengths { - ans.push(&s[i + 1 - len..=i]); + ans.push(&s[position - len..position]); } } ans @@ -95,4 +96,37 @@ mod tests { let res = ac.search("ababcxyzacxy12678acxy6543"); assert_eq!(res, ["abc", "xyz", "acxy", "678", "acxy", "6543",]); } + + #[test] + fn test_aho_corasick_with_utf8() { + let dict = [ + "abc", + "中文", + "abc中", + "abcd", + "xyz", + "acxy", + "efg", + "123", + "678", + "6543", + "ハンバーガー", + ]; + let ac = AhoCorasick::new(&dict); + let res = ac.search("ababc中xyzacxy12678acxyハンバーガー6543中文"); + assert_eq!( + res, + [ + "abc", + "abc中", + "xyz", + "acxy", + "678", + "acxy", + "ハンバーガー", + "6543", + "中文" + ] + ); + } } diff --git a/src/string/anagram.rs b/src/string/anagram.rs new file mode 100644 index 00000000000..9ea37dc4f6f --- /dev/null +++ b/src/string/anagram.rs @@ -0,0 +1,111 @@ +use std::collections::HashMap; + +/// Custom error type representing an invalid character found in the input. +#[derive(Debug, PartialEq)] +pub enum AnagramError { + NonAlphabeticCharacter, +} + +/// Checks if two strings are anagrams, ignoring spaces and case sensitivity. +/// +/// # Arguments +/// +/// * `s` - First input string. +/// * `t` - Second input string. +/// +/// # Returns +/// +/// * `Ok(true)` if the strings are anagrams. +/// * `Ok(false)` if the strings are not anagrams. +/// * `Err(AnagramError)` if either string contains non-alphabetic characters. +pub fn check_anagram(s: &str, t: &str) -> Result { + let s_cleaned = clean_string(s)?; + let t_cleaned = clean_string(t)?; + + Ok(char_count(&s_cleaned) == char_count(&t_cleaned)) +} + +/// Cleans the input string by removing spaces and converting to lowercase. +/// Returns an error if any non-alphabetic character is found. +/// +/// # Arguments +/// +/// * `s` - Input string to clean. +/// +/// # Returns +/// +/// * `Ok(String)` containing the cleaned string (no spaces, lowercase). +/// * `Err(AnagramError)` if the string contains non-alphabetic characters. +fn clean_string(s: &str) -> Result { + s.chars() + .filter(|c| !c.is_whitespace()) + .map(|c| { + if c.is_alphabetic() { + Ok(c.to_ascii_lowercase()) + } else { + Err(AnagramError::NonAlphabeticCharacter) + } + }) + .collect() +} + +/// Computes the histogram of characters in a string. +/// +/// # Arguments +/// +/// * `s` - Input string. +/// +/// # Returns +/// +/// * A `HashMap` where the keys are characters and values are their count. +fn char_count(s: &str) -> HashMap { + let mut res = HashMap::new(); + for c in s.chars() { + *res.entry(c).or_insert(0) += 1; + } + res +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $test_case; + assert_eq!(check_anagram(s, t), expected); + assert_eq!(check_anagram(t, s), expected); + } + )* + } + } + + test_cases! { + empty_strings: ("", "", Ok(true)), + empty_and_non_empty: ("", "Ted Morgan", Ok(false)), + single_char_same: ("z", "Z", Ok(true)), + single_char_diff: ("g", "h", Ok(false)), + valid_anagram_lowercase: ("cheater", "teacher", Ok(true)), + valid_anagram_with_spaces: ("madam curie", "radium came", Ok(true)), + valid_anagram_mixed_cases: ("Satan", "Santa", Ok(true)), + valid_anagram_with_spaces_and_mixed_cases: ("Anna Madrigal", "A man and a girl", Ok(true)), + new_york_times: ("New York Times", "monkeys write", Ok(true)), + church_of_scientology: ("Church of Scientology", "rich chosen goofy cult", Ok(true)), + mcdonalds_restaurants: ("McDonald's restaurants", "Uncle Sam's standard rot", Err(AnagramError::NonAlphabeticCharacter)), + coronavirus: ("coronavirus", "carnivorous", Ok(true)), + synonym_evil: ("evil", "vile", Ok(true)), + synonym_gentleman: ("a gentleman", "elegant man", Ok(true)), + antigram: ("restful", "fluster", Ok(true)), + sentences: ("William Shakespeare", "I am a weakish speller", Ok(true)), + part_of_speech_adj_to_verb: ("silent", "listen", Ok(true)), + anagrammatized: ("Anagrams", "Ars magna", Ok(true)), + non_anagram: ("rat", "car", Ok(false)), + invalid_anagram_with_special_char: ("hello!", "world", Err(AnagramError::NonAlphabeticCharacter)), + invalid_anagram_with_numeric_chars: ("test123", "321test", Err(AnagramError::NonAlphabeticCharacter)), + invalid_anagram_with_symbols: ("check@anagram", "check@nagaram", Err(AnagramError::NonAlphabeticCharacter)), + non_anagram_length_mismatch: ("abc", "abcd", Ok(false)), + } +} diff --git a/src/string/autocomplete_using_trie.rs b/src/string/autocomplete_using_trie.rs new file mode 100644 index 00000000000..630b6e1dd79 --- /dev/null +++ b/src/string/autocomplete_using_trie.rs @@ -0,0 +1,125 @@ +/* + It autocomplete by prefix using added words. + + word List => ["apple", "orange", "oregano"] + prefix => "or" + matches => ["orange", "oregano"] +*/ + +use std::collections::HashMap; + +const END: char = '#'; + +#[derive(Debug)] +struct Trie(HashMap>); + +impl Trie { + fn new() -> Self { + Trie(HashMap::new()) + } + + fn insert(&mut self, text: &str) { + let mut trie = self; + + for c in text.chars() { + trie = trie.0.entry(c).or_insert_with(|| Box::new(Trie::new())); + } + + trie.0.insert(END, Box::new(Trie::new())); + } + + fn find(&self, prefix: &str) -> Vec { + let mut trie = self; + + for c in prefix.chars() { + let char_trie = trie.0.get(&c); + if let Some(char_trie) = char_trie { + trie = char_trie; + } else { + return vec![]; + } + } + + Self::_elements(trie) + .iter() + .map(|s| prefix.to_owned() + s) + .collect() + } + + fn _elements(map: &Trie) -> Vec { + let mut results = vec![]; + + for (c, v) in map.0.iter() { + let mut sub_result = vec![]; + if c == &END { + sub_result.push("".to_owned()) + } else { + Self::_elements(v) + .iter() + .map(|s| sub_result.push(c.to_string() + s)) + .collect() + } + + results.extend(sub_result) + } + + results + } +} + +pub struct Autocomplete { + trie: Trie, +} + +impl Autocomplete { + fn new() -> Self { + Self { trie: Trie::new() } + } + + pub fn insert_words>(&mut self, words: &[T]) { + for word in words { + self.trie.insert(word.as_ref()); + } + } + + pub fn find_words(&self, prefix: &str) -> Vec { + self.trie.find(prefix) + } +} + +impl Default for Autocomplete { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::Autocomplete; + + #[test] + fn test_autocomplete() { + let words = vec!["apple", "orange", "oregano"]; + + let mut auto_complete = Autocomplete::new(); + auto_complete.insert_words(&words); + + let prefix = "app"; + let mut auto_completed_words = auto_complete.find_words(prefix); + + let mut apple = vec!["apple"]; + apple.sort(); + + auto_completed_words.sort(); + assert_eq!(auto_completed_words, apple); + + let prefix = "or"; + let mut auto_completed_words = auto_complete.find_words(prefix); + + let mut prefix_or = vec!["orange", "oregano"]; + prefix_or.sort(); + + auto_completed_words.sort(); + assert_eq!(auto_completed_words, prefix_or); + } +} diff --git a/src/string/boyer_moore_search.rs b/src/string/boyer_moore_search.rs new file mode 100644 index 00000000000..e9c46a8c980 --- /dev/null +++ b/src/string/boyer_moore_search.rs @@ -0,0 +1,161 @@ +//! This module implements the Boyer-Moore string search algorithm, an efficient method +//! for finding all occurrences of a pattern within a given text. The algorithm skips +//! sections of the text by leveraging two key rules: the bad character rule and the +//! good suffix rule (only the bad character rule is implemented here for simplicity). + +use std::collections::HashMap; + +/// Builds the bad character table for the Boyer-Moore algorithm. +/// This table stores the last occurrence of each character in the pattern. +/// +/// # Arguments +/// * `pat` - The pattern as a slice of characters. +/// +/// # Returns +/// A `HashMap` where the keys are characters from the pattern and the values are their +/// last known positions within the pattern. +fn build_bad_char_table(pat: &[char]) -> HashMap { + let mut bad_char_table = HashMap::new(); + for (i, &ch) in pat.iter().enumerate() { + bad_char_table.insert(ch, i as isize); + } + bad_char_table +} + +/// Calculates the shift when a full match occurs in the Boyer-Moore algorithm. +/// It uses the bad character table to determine how much to shift the pattern. +/// +/// # Arguments +/// * `shift` - The current shift of the pattern on the text. +/// * `pat_len` - The length of the pattern. +/// * `text_len` - The length of the text. +/// * `bad_char_table` - The bad character table built for the pattern. +/// * `text` - The text as a slice of characters. +/// +/// # Returns +/// The number of positions to shift the pattern after a match. +fn calc_match_shift( + shift: isize, + pat_len: isize, + text_len: isize, + bad_char_table: &HashMap, + text: &[char], +) -> isize { + if shift + pat_len >= text_len { + return 1; + } + let next_ch = text[(shift + pat_len) as usize]; + pat_len - bad_char_table.get(&next_ch).unwrap_or(&-1) +} + +/// Calculates the shift when a mismatch occurs in the Boyer-Moore algorithm. +/// The bad character rule is used to determine how far to shift the pattern. +/// +/// # Arguments +/// * `mis_idx` - The mismatch index in the pattern. +/// * `shift` - The current shift of the pattern on the text. +/// * `text` - The text as a slice of characters. +/// * `bad_char_table` - The bad character table built for the pattern. +/// +/// # Returns +/// The number of positions to shift the pattern after a mismatch. +fn calc_mismatch_shift( + mis_idx: isize, + shift: isize, + text: &[char], + bad_char_table: &HashMap, +) -> isize { + let mis_ch = text[(shift + mis_idx) as usize]; + let bad_char_shift = bad_char_table.get(&mis_ch).unwrap_or(&-1); + std::cmp::max(1, mis_idx - bad_char_shift) +} + +/// Performs the Boyer-Moore string search algorithm, which searches for all +/// occurrences of a pattern within a text. +/// +/// The Boyer-Moore algorithm is efficient for large texts and patterns, as it +/// skips sections of the text based on the bad character rule and other optimizations. +/// +/// # Arguments +/// * `text` - The text to search within as a string slice. +/// * `pat` - The pattern to search for as a string slice. +/// +/// # Returns +/// A vector of starting indices where the pattern occurs in the text. +pub fn boyer_moore_search(text: &str, pat: &str) -> Vec { + let mut positions = Vec::new(); + + let text_len = text.len() as isize; + let pat_len = pat.len() as isize; + + // Handle edge cases where the text or pattern is empty, or the pattern is longer than the text + if text_len == 0 || pat_len == 0 || pat_len > text_len { + return positions; + } + + // Convert text and pattern to character vectors for easier indexing + let pat: Vec = pat.chars().collect(); + let text: Vec = text.chars().collect(); + + // Build the bad character table for the pattern + let bad_char_table = build_bad_char_table(&pat); + + let mut shift = 0; + + // Main loop: shift the pattern over the text + while shift <= text_len - pat_len { + let mut j = pat_len - 1; + + // Compare pattern from right to left + while j >= 0 && pat[j as usize] == text[(shift + j) as usize] { + j -= 1; + } + + // If we found a match (j < 0), record the position + if j < 0 { + positions.push(shift as usize); + shift += calc_match_shift(shift, pat_len, text_len, &bad_char_table, &text); + } else { + // If mismatch, calculate how far to shift based on the bad character rule + shift += calc_mismatch_shift(j, shift, &text, &bad_char_table); + } + } + + positions +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! boyer_moore_tests { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (text, pattern, expected) = $tc; + assert_eq!(boyer_moore_search(text, pattern), expected); + } + )* + }; + } + + boyer_moore_tests! { + test_simple_match: ("AABCAB12AFAABCABFFEGABCAB", "ABCAB", vec![1, 11, 20]), + test_no_match: ("AABCAB12AFAABCABFFEGABCAB", "FFF", vec![]), + test_partial_match: ("AABCAB12AFAABCABFFEGABCAB", "CAB", vec![3, 13, 22]), + test_empty_text: ("", "A", vec![]), + test_empty_pattern: ("ABC", "", vec![]), + test_both_empty: ("", "", vec![]), + test_pattern_longer_than_text: ("ABC", "ABCDEFG", vec![]), + test_single_character_text: ("A", "A", vec![0]), + test_single_character_pattern: ("AAAA", "A", vec![0, 1, 2, 3]), + test_case_sensitivity: ("ABCabcABC", "abc", vec![3]), + test_overlapping_patterns: ("AAAAA", "AAA", vec![0, 1, 2]), + test_special_characters: ("@!#$$%^&*", "$$", vec![3]), + test_numerical_pattern: ("123456789123456", "456", vec![3, 12]), + test_partial_overlap_no_match: ("ABCD", "ABCDE", vec![]), + test_single_occurrence: ("XXXXXXXXXXXXXXXXXXPATTERNXXXXXXXXXXXXXXXXXX", "PATTERN", vec![18]), + test_single_occurrence_with_noise: ("PATPATPATPATTERNPAT", "PATTERN", vec![9]), + } +} diff --git a/src/string/burrows_wheeler_transform.rs b/src/string/burrows_wheeler_transform.rs index e89e8611166..3ecef7b9ab3 100644 --- a/src/string/burrows_wheeler_transform.rs +++ b/src/string/burrows_wheeler_transform.rs @@ -1,4 +1,4 @@ -pub fn burrows_wheeler_transform(input: String) -> (String, usize) { +pub fn burrows_wheeler_transform(input: &str) -> (String, usize) { let len = input.len(); let mut table = Vec::::with_capacity(len); @@ -19,11 +19,11 @@ pub fn burrows_wheeler_transform(input: String) -> (String, usize) { (encoded, index) } -pub fn inv_burrows_wheeler_transform(input: (String, usize)) -> String { - let len = input.0.len(); +pub fn inv_burrows_wheeler_transform>(input: (T, usize)) -> String { + let len = input.0.as_ref().len(); let mut table = Vec::<(usize, char)>::with_capacity(len); for i in 0..len { - table.push((i, input.0.chars().nth(i).unwrap())); + table.push((i, input.0.as_ref().chars().nth(i).unwrap())); } table.sort_by(|a, b| a.1.cmp(&b.1)); @@ -43,25 +43,50 @@ mod tests { use super::*; #[test] - fn basic() { + //Ensure function stand-alone legitimacy + fn stand_alone_function() { assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("CARROT".to_string())), + burrows_wheeler_transform("CARROT"), + ("CTRRAO".to_owned(), 1usize) + ); + assert_eq!(inv_burrows_wheeler_transform(("CTRRAO", 1usize)), "CARROT"); + assert_eq!( + burrows_wheeler_transform("THEALGORITHMS"), + ("EHLTTRAHGOMSI".to_owned(), 11usize) + ); + assert_eq!( + inv_burrows_wheeler_transform(("EHLTTRAHGOMSI".to_string(), 11usize)), + "THEALGORITHMS" + ); + assert_eq!( + burrows_wheeler_transform("!.!.!??.=::"), + (":..!!?:=.?!".to_owned(), 0usize) + ); + assert_eq!( + inv_burrows_wheeler_transform((":..!!?:=.?!", 0usize)), + "!.!.!??.=::" + ); + } + #[test] + fn basic_characters() { + assert_eq!( + inv_burrows_wheeler_transform(burrows_wheeler_transform("CARROT")), "CARROT" ); assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("TOMATO".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("TOMATO")), "TOMATO" ); assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("THISISATEST".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("THISISATEST")), "THISISATEST" ); assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("THEALGORITHMS".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("THEALGORITHMS")), "THEALGORITHMS" ); assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("RUST".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("RUST")), "RUST" ); } @@ -69,17 +94,15 @@ mod tests { #[test] fn special_characters() { assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("!.!.!??.=::".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("!.!.!??.=::")), "!.!.!??.=::" ); assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform( - "!{}{}(((&&%%!??.=::".to_string() - )), + inv_burrows_wheeler_transform(burrows_wheeler_transform("!{}{}(((&&%%!??.=::")), "!{}{}(((&&%%!??.=::" ); assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("//&$[]".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("//&$[]")), "//&$[]" ); } @@ -87,7 +110,7 @@ mod tests { #[test] fn empty() { assert_eq!( - inv_burrows_wheeler_transform(burrows_wheeler_transform("".to_string())), + inv_burrows_wheeler_transform(burrows_wheeler_transform("")), "" ); } diff --git a/src/string/duval_algorithm.rs b/src/string/duval_algorithm.rs new file mode 100644 index 00000000000..69e9dbff2a9 --- /dev/null +++ b/src/string/duval_algorithm.rs @@ -0,0 +1,97 @@ +//! Implementation of Duval's Algorithm to compute the standard factorization of a string +//! into Lyndon words. A Lyndon word is defined as a string that is strictly smaller +//! (lexicographically) than any of its nontrivial suffixes. This implementation operates +//! in linear time and space. + +/// Performs Duval's algorithm to factorize a given string into its Lyndon words. +/// +/// # Arguments +/// +/// * `s` - A slice of characters representing the input string. +/// +/// # Returns +/// +/// A vector of strings, where each string is a Lyndon word, representing the factorization +/// of the input string. +/// +/// # Time Complexity +/// +/// The algorithm runs in O(n) time, where `n` is the length of the input string. +pub fn duval_algorithm(s: &str) -> Vec { + factorize_duval(&s.chars().collect::>()) +} + +/// Helper function that takes a string slice, converts it to a vector of characters, +/// and then applies the Duval factorization algorithm to find the Lyndon words. +/// +/// # Arguments +/// +/// * `s` - A string slice representing the input text. +/// +/// # Returns +/// +/// A vector of strings, each representing a Lyndon word in the factorization. +fn factorize_duval(s: &[char]) -> Vec { + let mut start = 0; + let mut factors: Vec = Vec::new(); + + while start < s.len() { + let mut end = start + 1; + let mut repeat = start; + + while end < s.len() && s[repeat] <= s[end] { + if s[repeat] < s[end] { + repeat = start; + } else { + repeat += 1; + } + end += 1; + } + + while start <= repeat { + factors.push(s[start..start + end - repeat].iter().collect::()); + start += end - repeat; + } + } + + factors +} + +#[cfg(test)] +mod test { + use super::*; + + macro_rules! test_duval_algorithm { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (text, expected) = $inputs; + assert_eq!(duval_algorithm(text), expected); + } + )* + } + } + + test_duval_algorithm! { + repeating_with_suffix: ("abcdabcdababc", ["abcd", "abcd", "ababc"]), + single_repeating_char: ("aaa", ["a", "a", "a"]), + single: ("ababb", ["ababb"]), + unicode: ("അഅഅ", ["അ", "അ", "അ"]), + empty_string: ("", Vec::::new()), + single_char: ("x", ["x"]), + palindrome: ("racecar", ["r", "acecar"]), + long_repeating: ("aaaaaa", ["a", "a", "a", "a", "a", "a"]), + mixed_repeating: ("ababcbabc", ["ababcbabc"]), + non_repeating_sorted: ("abcdefg", ["abcdefg"]), + alternating_increasing: ("abababab", ["ab", "ab", "ab", "ab"]), + long_repeating_lyndon: ("abcabcabcabc", ["abc", "abc", "abc", "abc"]), + decreasing_order: ("zyxwvutsrqponm", ["z", "y", "x", "w", "v", "u", "t", "s", "r", "q", "p", "o", "n", "m"]), + alphanumeric_mixed: ("a1b2c3a1", ["a", "1b2c3a", "1"]), + special_characters: ("a@b#c$d", ["a", "@b", "#c$d"]), + unicode_complex: ("αβγδ", ["αβγδ"]), + long_string_performance: (&"a".repeat(1_000_000), vec!["a"; 1_000_000]), + palindrome_repeating_prefix: ("abccba", ["abccb", "a"]), + interrupted_lyndon: ("abcxabc", ["abcx", "abc"]), + } +} diff --git a/src/string/hamming_distance.rs b/src/string/hamming_distance.rs index 6f6b4f354c2..3137d5abc7c 100644 --- a/src/string/hamming_distance.rs +++ b/src/string/hamming_distance.rs @@ -1,47 +1,57 @@ -pub fn hamming_distance(string1: &str, string2: &str) -> usize { - let mut distance = 0; - let mut string1 = string1.chars(); - let mut string2 = string2.chars(); +/// Error type for Hamming distance calculation. +#[derive(Debug, PartialEq)] +pub enum HammingDistanceError { + InputStringsHaveDifferentLength, +} - loop { - match (string1.next(), string2.next()) { - (Some(char1), Some(char2)) if char1 != char2 => distance += 1, - (Some(char1), Some(char2)) if char1 == char2 => continue, - (None, Some(_)) | (Some(_), None) => panic!("Strings must have the same length"), - (None, None) => break, - _ => unreachable!(), - } +/// Calculates the Hamming distance between two strings. +/// +/// The Hamming distance is defined as the number of positions at which the corresponding characters of the two strings are different. +pub fn hamming_distance(string_a: &str, string_b: &str) -> Result { + if string_a.len() != string_b.len() { + return Err(HammingDistanceError::InputStringsHaveDifferentLength); } - distance + + let distance = string_a + .chars() + .zip(string_b.chars()) + .filter(|(a, b)| a != b) + .count(); + + Ok(distance) } #[cfg(test)] mod tests { use super::*; - #[test] - fn empty_strings() { - let result = hamming_distance("", ""); - assert_eq!(result, 0); - } - #[test] - fn distance_zero() { - let result = hamming_distance("rust", "rust"); - assert_eq!(result, 0); - } - #[test] - fn distance_three() { - let result = hamming_distance("karolin", "kathrin"); - assert_eq!(result, 3); - } - #[test] - fn distance_four() { - let result = hamming_distance("kathrin", "kerstin"); - assert_eq!(result, 4); + macro_rules! test_hamming_distance { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (str_a, str_b, expected) = $tc; + assert_eq!(hamming_distance(str_a, str_b), expected); + assert_eq!(hamming_distance(str_b, str_a), expected); + } + )* + } } - #[test] - fn distance_five() { - let result = hamming_distance("00000", "11111"); - assert_eq!(result, 5); + + test_hamming_distance! { + empty_inputs: ("", "", Ok(0)), + different_length: ("0", "", Err(HammingDistanceError::InputStringsHaveDifferentLength)), + length_1_inputs_identical: ("a", "a", Ok(0)), + length_1_inputs_different: ("a", "b", Ok(1)), + same_strings: ("rust", "rust", Ok(0)), + regular_input_0: ("karolin", "kathrin", Ok(3)), + regular_input_1: ("kathrin", "kerstin", Ok(4)), + regular_input_2: ("00000", "11111", Ok(5)), + different_case: ("x", "X", Ok(1)), + strings_with_no_common_chars: ("abcd", "wxyz", Ok(4)), + long_strings_one_diff: (&"a".repeat(1000), &("a".repeat(999) + "b"), Ok(1)), + long_strings_many_diffs: (&("a".repeat(500) + &"b".repeat(500)), &("b".repeat(500) + &"a".repeat(500)), Ok(1000)), + strings_with_special_chars_identical: ("!@#$%^", "!@#$%^", Ok(0)), + strings_with_special_chars_diff: ("!@#$%^", "&*()_+", Ok(6)), } } diff --git a/src/string/isogram.rs b/src/string/isogram.rs new file mode 100644 index 00000000000..30b8d66bdff --- /dev/null +++ b/src/string/isogram.rs @@ -0,0 +1,104 @@ +//! This module provides functionality to check if a given string is an isogram. +//! An isogram is a word or phrase in which no letter occurs more than once. + +use std::collections::HashMap; + +/// Enum representing possible errors that can occur while checking for isograms. +#[derive(Debug, PartialEq, Eq)] +pub enum IsogramError { + /// Indicates that the input contains a non-alphabetic character. + NonAlphabeticCharacter, +} + +/// Counts the occurrences of each alphabetic character in a given string. +/// +/// This function takes a string slice as input. It counts how many times each alphabetic character +/// appears in the input string and returns a hashmap where the keys are characters and the values +/// are their respective counts. +/// +/// # Arguments +/// +/// * `s` - A string slice that contains the input to count characters from. +/// +/// # Errors +/// +/// Returns an error if the input contains non-alphabetic characters (excluding spaces). +/// +/// # Note +/// +/// This function treats uppercase and lowercase letters as equivalent (case-insensitive). +/// Spaces are ignored and do not affect the character count. +fn count_letters(s: &str) -> Result, IsogramError> { + let mut letter_counts = HashMap::new(); + + for ch in s.to_ascii_lowercase().chars() { + if !ch.is_ascii_alphabetic() && !ch.is_whitespace() { + return Err(IsogramError::NonAlphabeticCharacter); + } + + if ch.is_ascii_alphabetic() { + *letter_counts.entry(ch).or_insert(0) += 1; + } + } + + Ok(letter_counts) +} + +/// Checks if the given input string is an isogram. +/// +/// This function takes a string slice as input. It counts the occurrences of each +/// alphabetic character (ignoring case and spaces). +/// +/// # Arguments +/// +/// * `input` - A string slice that contains the input to check for isogram properties. +/// +/// # Return +/// +/// - `Ok(true)` if all characters appear only once, or `Ok(false)` if any character appears more than once. +/// - `Err(IsogramError::NonAlphabeticCharacter) if the input contains any non-alphabetic characters. +pub fn is_isogram(s: &str) -> Result { + let letter_counts = count_letters(s)?; + Ok(letter_counts.values().all(|&count| count == 1)) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! isogram_tests { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $tc; + assert_eq!(is_isogram(input), expected); + } + )* + }; + } + + isogram_tests! { + isogram_simple: ("isogram", Ok(true)), + isogram_case_insensitive: ("Isogram", Ok(true)), + isogram_with_spaces: ("a b c d e", Ok(true)), + isogram_mixed: ("Dermatoglyphics", Ok(true)), + isogram_long: ("Subdermatoglyphic", Ok(true)), + isogram_german_city: ("Malitzschkendorf", Ok(true)), + perfect_pangram: ("Cwm fjord bank glyphs vext quiz", Ok(true)), + isogram_sentences: ("The big dwarf only jumps", Ok(true)), + isogram_french: ("Lampez un fort whisky", Ok(true)), + isogram_portuguese: ("Velho traduz sim", Ok(true)), + isogram_spanis: ("Centrifugadlos", Ok(true)), + invalid_isogram_with_repeated_char: ("hello", Ok(false)), + invalid_isogram_with_numbers: ("abc123", Err(IsogramError::NonAlphabeticCharacter)), + invalid_isogram_with_special_char: ("abc!", Err(IsogramError::NonAlphabeticCharacter)), + invalid_isogram_with_comma: ("Velho, traduz sim", Err(IsogramError::NonAlphabeticCharacter)), + invalid_isogram_with_spaces: ("a b c d a", Ok(false)), + invalid_isogram_with_repeated_phrase: ("abcabc", Ok(false)), + isogram_empty_string: ("", Ok(true)), + isogram_single_character: ("a", Ok(true)), + invalid_isogram_multiple_same_characters: ("aaaa", Ok(false)), + invalid_isogram_with_symbols: ("abc@#$%", Err(IsogramError::NonAlphabeticCharacter)), + } +} diff --git a/src/string/isomorphism.rs b/src/string/isomorphism.rs new file mode 100644 index 00000000000..8583ece1e1c --- /dev/null +++ b/src/string/isomorphism.rs @@ -0,0 +1,83 @@ +//! This module provides functionality to determine whether two strings are isomorphic. +//! +//! Two strings are considered isomorphic if the characters in one string can be replaced +//! by some mapping relation to obtain the other string. +use std::collections::HashMap; + +/// Determines whether two strings are isomorphic. +/// +/// # Arguments +/// +/// * `s` - The first string. +/// * `t` - The second string. +/// +/// # Returns +/// +/// `true` if the strings are isomorphic, `false` otherwise. +pub fn is_isomorphic(s: &str, t: &str) -> bool { + let s_chars: Vec = s.chars().collect(); + let t_chars: Vec = t.chars().collect(); + if s_chars.len() != t_chars.len() { + return false; + } + let mut s_to_t_map = HashMap::new(); + let mut t_to_s_map = HashMap::new(); + for (s_char, t_char) in s_chars.into_iter().zip(t_chars) { + if !check_mapping(&mut s_to_t_map, s_char, t_char) + || !check_mapping(&mut t_to_s_map, t_char, s_char) + { + return false; + } + } + true +} + +/// Checks the mapping between two characters and updates the map. +/// +/// # Arguments +/// +/// * `map` - The HashMap to store the mapping. +/// * `key` - The key character. +/// * `value` - The value character. +/// +/// # Returns +/// +/// `true` if the mapping is consistent, `false` otherwise. +fn check_mapping(map: &mut HashMap, key: char, value: char) -> bool { + match map.get(&key) { + Some(&mapped_char) => mapped_char == value, + None => { + map.insert(key, value); + true + } + } +} + +#[cfg(test)] +mod tests { + use super::is_isomorphic; + macro_rules! test_is_isomorphic { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $inputs; + assert_eq!(is_isomorphic(s, t), expected); + assert_eq!(is_isomorphic(t, s), expected); + assert!(is_isomorphic(s, s)); + assert!(is_isomorphic(t, t)); + } + )* + } + } + test_is_isomorphic! { + isomorphic: ("egg", "add", true), + isomorphic_long: ("abcdaabdcdbbabababacdadad", "AbCdAAbdCdbbAbAbAbACdAdAd", true), + not_isomorphic: ("egg", "adc", false), + non_isomorphic_long: ("abcdaabdcdbbabababacdadad", "AACdAAbdCdbbAbAbAbACdAdAd", false), + isomorphic_unicode: ("天苍苍", "野茫茫", true), + isomorphic_unicode_different_byte_size: ("abb", "野茫茫", true), + empty: ("", "", true), + different_length: ("abc", "abcd", false), + } +} diff --git a/src/string/jaro_winkler_distance.rs b/src/string/jaro_winkler_distance.rs new file mode 100644 index 00000000000..e00e526e676 --- /dev/null +++ b/src/string/jaro_winkler_distance.rs @@ -0,0 +1,84 @@ +// In computer science and statistics, +// the Jaro–Winkler distance is a string metric measuring an edit distance +// between two sequences. +// It is a variant proposed in 1990 by William E. Winkler +// of the Jaro distance metric (1989, Matthew A. Jaro). + +pub fn jaro_winkler_distance(str1: &str, str2: &str) -> f64 { + if str1.is_empty() || str2.is_empty() { + return 0.0; + } + fn get_matched_characters(s1: &str, s2: &str) -> String { + let mut s2 = s2.to_string(); + let mut matched: Vec = Vec::new(); + let limit = std::cmp::min(s1.len(), s2.len()) / 2; + for (i, l) in s1.chars().enumerate() { + let left = std::cmp::max(0, i as i32 - limit as i32) as usize; + let right = std::cmp::min(i + limit + 1, s2.len()); + if s2[left..right].contains(l) { + matched.push(l); + let a = &s2[0..s2.find(l).expect("this exists")]; + let b = &s2[(s2.find(l).expect("this exists") + 1)..]; + s2 = format!("{a} {b}"); + } + } + matched.iter().collect::() + } + + let matching_1 = get_matched_characters(str1, str2); + let matching_2 = get_matched_characters(str2, str1); + let match_count = matching_1.len(); + + // transposition + let transpositions = { + let mut count = 0; + for (c1, c2) in matching_1.chars().zip(matching_2.chars()) { + if c1 != c2 { + count += 1; + } + } + count / 2 + }; + + let jaro: f64 = { + if match_count == 0 { + return 0.0; + } + (1_f64 / 3_f64) + * (match_count as f64 / str1.len() as f64 + + match_count as f64 / str2.len() as f64 + + (match_count - transpositions) as f64 / match_count as f64) + }; + + let mut prefix_len = 0.0; + let bound = std::cmp::min(std::cmp::min(str1.len(), str2.len()), 4); + for (c1, c2) in str1[..bound].chars().zip(str2[..bound].chars()) { + if c1 == c2 { + prefix_len += 1.0; + } else { + break; + } + } + jaro + (0.1 * prefix_len * (1.0 - jaro)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_jaro_winkler_distance() { + let a = jaro_winkler_distance("hello", "world"); + assert_eq!(a, 0.4666666666666666); + let a = jaro_winkler_distance("martha", "marhta"); + assert_eq!(a, 0.9611111111111111); + let a = jaro_winkler_distance("martha", "marhat"); + assert_eq!(a, 0.9611111111111111); + let a = jaro_winkler_distance("test", "test"); + assert_eq!(a, 1.0); + let a = jaro_winkler_distance("test", ""); + assert_eq!(a, 0.0); + let a = jaro_winkler_distance("hello world", "HeLLo W0rlD"); + assert_eq!(a, 0.6363636363636364); + } +} diff --git a/src/string/knuth_morris_pratt.rs b/src/string/knuth_morris_pratt.rs index 1bd183ab3f1..ec3fba0c9f1 100644 --- a/src/string/knuth_morris_pratt.rs +++ b/src/string/knuth_morris_pratt.rs @@ -1,97 +1,146 @@ -pub fn knuth_morris_pratt(st: String, pat: String) -> Vec { - if st.is_empty() || pat.is_empty() { +//! Knuth-Morris-Pratt string matching algorithm implementation in Rust. +//! +//! This module contains the implementation of the KMP algorithm, which is used for finding +//! occurrences of a pattern string within a text string efficiently. The algorithm preprocesses +//! the pattern to create a partial match table, which allows for efficient searching. + +/// Finds all occurrences of the pattern in the given string using the Knuth-Morris-Pratt algorithm. +/// +/// # Arguments +/// +/// * `string` - The string to search within. +/// * `pattern` - The pattern string to search for. +/// +/// # Returns +/// +/// A vector of starting indices where the pattern is found in the string. If the pattern or the +/// string is empty, an empty vector is returned. +pub fn knuth_morris_pratt(string: &str, pattern: &str) -> Vec { + if string.is_empty() || pattern.is_empty() { return vec![]; } - let string = st.into_bytes(); - let pattern = pat.into_bytes(); + let text_chars = string.chars().collect::>(); + let pattern_chars = pattern.chars().collect::>(); + let partial_match_table = build_partial_match_table(&pattern_chars); + find_pattern(&text_chars, &pattern_chars, &partial_match_table) +} - // build the partial match table - let mut partial = vec![0]; - for i in 1..pattern.len() { - let mut j = partial[i - 1]; - while j > 0 && pattern[j] != pattern[i] { - j = partial[j - 1]; - } - partial.push(if pattern[j] == pattern[i] { j + 1 } else { j }); - } +/// Builds the partial match table (also known as "prefix table") for the given pattern. +/// +/// The partial match table is used to skip characters while matching the pattern in the text. +/// Each entry at index `i` in the table indicates the length of the longest proper prefix of +/// the substring `pattern[0..i]` which is also a suffix of this substring. +/// +/// # Arguments +/// +/// * `pattern_chars` - The pattern string as a slice of characters. +/// +/// # Returns +/// +/// A vector representing the partial match table. +fn build_partial_match_table(pattern_chars: &[char]) -> Vec { + let mut partial_match_table = vec![0]; + pattern_chars + .iter() + .enumerate() + .skip(1) + .for_each(|(index, &char)| { + let mut length = partial_match_table[index - 1]; + while length > 0 && pattern_chars[length] != char { + length = partial_match_table[length - 1]; + } + partial_match_table.push(if pattern_chars[length] == char { + length + 1 + } else { + length + }); + }); + partial_match_table +} - // and read 'string' to find 'pattern' - let mut ret = vec![]; - let mut j = 0; +/// Finds all occurrences of the pattern in the given string using the precomputed partial match table. +/// +/// This function iterates through the string and uses the partial match table to efficiently find +/// all starting indices of the pattern in the string. +/// +/// # Arguments +/// +/// * `text_chars` - The string to search within as a slice of characters. +/// * `pattern_chars` - The pattern string to search for as a slice of characters. +/// * `partial_match_table` - The precomputed partial match table for the pattern. +/// +/// # Returns +/// +/// A vector of starting indices where the pattern is found in the string. +fn find_pattern( + text_chars: &[char], + pattern_chars: &[char], + partial_match_table: &[usize], +) -> Vec { + let mut result_indices = vec![]; + let mut match_length = 0; - for (i, &c) in string.iter().enumerate() { - while j > 0 && c != pattern[j] { - j = partial[j - 1]; - } - if c == pattern[j] { - j += 1; - } - if j == pattern.len() { - ret.push(i + 1 - j); - j = partial[j - 1]; - } - } + text_chars + .iter() + .enumerate() + .for_each(|(text_index, &text_char)| { + while match_length > 0 && text_char != pattern_chars[match_length] { + match_length = partial_match_table[match_length - 1]; + } + if text_char == pattern_chars[match_length] { + match_length += 1; + } + if match_length == pattern_chars.len() { + result_indices.push(text_index + 1 - match_length); + match_length = partial_match_table[match_length - 1]; + } + }); - ret + result_indices } #[cfg(test)] mod tests { use super::*; - #[test] - fn each_letter_matches() { - let index = knuth_morris_pratt("aaa".to_string(), "a".to_string()); - assert_eq!(index, vec![0, 1, 2]); - } - - #[test] - fn a_few_separate_matches() { - let index = knuth_morris_pratt("abababa".to_string(), "ab".to_string()); - assert_eq!(index, vec![0, 2, 4]); - } - - #[test] - fn one_match() { - let index = - knuth_morris_pratt("ABC ABCDAB ABCDABCDABDE".to_string(), "ABCDABD".to_string()); - assert_eq!(index, vec![15]); - } - - #[test] - fn lots_of_matches() { - let index = knuth_morris_pratt("aaabaabaaaaa".to_string(), "aa".to_string()); - assert_eq!(index, vec![0, 1, 4, 7, 8, 9, 10]); - } - - #[test] - fn lots_of_intricate_matches() { - let index = knuth_morris_pratt("ababababa".to_string(), "aba".to_string()); - assert_eq!(index, vec![0, 2, 4, 6]); - } - - #[test] - fn not_found0() { - let index = knuth_morris_pratt("abcde".to_string(), "f".to_string()); - assert_eq!(index, vec![]); - } - - #[test] - fn not_found1() { - let index = knuth_morris_pratt("abcde".to_string(), "ac".to_string()); - assert_eq!(index, vec![]); - } - - #[test] - fn not_found2() { - let index = knuth_morris_pratt("ababab".to_string(), "bababa".to_string()); - assert_eq!(index, vec![]); + macro_rules! test_knuth_morris_pratt { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, pattern, expected) = $inputs; + assert_eq!(knuth_morris_pratt(input, pattern), expected); + } + )* + } } - #[test] - fn empty_string() { - let index = knuth_morris_pratt("".to_string(), "abcdef".to_string()); - assert_eq!(index, vec![]); + test_knuth_morris_pratt! { + each_letter_matches: ("aaa", "a", vec![0, 1, 2]), + a_few_seperate_matches: ("abababa", "ab", vec![0, 2, 4]), + unicode: ("അഅഅ", "അ", vec![0, 1, 2]), + unicode_no_match_but_similar_bytes: ( + &String::from_utf8(vec![224, 180, 133]).unwrap(), + &String::from_utf8(vec![224, 180, 132]).unwrap(), + vec![] + ), + one_match: ("ABC ABCDAB ABCDABCDABDE", "ABCDABD", vec![15]), + lots_of_matches: ("aaabaabaaaaa", "aa", vec![0, 1, 4, 7, 8, 9, 10]), + lots_of_intricate_matches: ("ababababa", "aba", vec![0, 2, 4, 6]), + not_found0: ("abcde", "f", vec![]), + not_found1: ("abcde", "ac", vec![]), + not_found2: ("ababab", "bababa", vec![]), + empty_string: ("", "abcdef", vec![]), + empty_pattern: ("abcdef", "", vec![]), + single_character_string: ("a", "a", vec![0]), + single_character_pattern: ("abcdef", "d", vec![3]), + pattern_at_start: ("abcdef", "abc", vec![0]), + pattern_at_end: ("abcdef", "def", vec![3]), + pattern_in_middle: ("abcdef", "cd", vec![2]), + no_match_with_repeated_characters: ("aaaaaa", "b", vec![]), + pattern_longer_than_string: ("abc", "abcd", vec![]), + very_long_string: (&"a".repeat(10000), "a", (0..10000).collect::>()), + very_long_pattern: (&"a".repeat(10000), &"a".repeat(9999), (0..2).collect::>()), } } diff --git a/src/string/levenshtein_distance.rs b/src/string/levenshtein_distance.rs new file mode 100644 index 00000000000..1a1ccefaee4 --- /dev/null +++ b/src/string/levenshtein_distance.rs @@ -0,0 +1,164 @@ +//! Provides functions to calculate the Levenshtein distance between two strings. +//! +//! The Levenshtein distance is a measure of the similarity between two strings by calculating the minimum number of single-character +//! edits (insertions, deletions, or substitutions) required to change one string into the other. + +use std::cmp::min; + +/// Calculates the Levenshtein distance between two strings using a naive dynamic programming approach. +/// +/// The Levenshtein distance is a measure of the similarity between two strings by calculating the minimum number of single-character +/// edits (insertions, deletions, or substitutions) required to change one string into the other. +/// +/// # Arguments +/// +/// * `string1` - A reference to the first string. +/// * `string2` - A reference to the second string. +/// +/// # Returns +/// +/// The Levenshtein distance between the two input strings. +/// +/// This function computes the Levenshtein distance by constructing a dynamic programming matrix and iteratively filling it in. +/// It follows the standard top-to-bottom, left-to-right approach for filling in the matrix. +/// +/// # Complexity +/// +/// - Time complexity: O(nm), +/// - Space complexity: O(nm), +/// +/// where n and m are lengths of `string1` and `string2`. +/// +/// Note that this implementation uses a straightforward dynamic programming approach without any space optimization. +/// It may consume more memory for larger input strings compared to the optimized version. +pub fn naive_levenshtein_distance(string1: &str, string2: &str) -> usize { + let distance_matrix: Vec> = (0..=string1.len()) + .map(|i| { + (0..=string2.len()) + .map(|j| { + if i == 0 { + j + } else if j == 0 { + i + } else { + 0 + } + }) + .collect() + }) + .collect(); + + let updated_matrix = (1..=string1.len()).fold(distance_matrix, |matrix, i| { + (1..=string2.len()).fold(matrix, |mut inner_matrix, j| { + let cost = usize::from(string1.as_bytes()[i - 1] != string2.as_bytes()[j - 1]); + inner_matrix[i][j] = (inner_matrix[i - 1][j - 1] + cost) + .min(inner_matrix[i][j - 1] + 1) + .min(inner_matrix[i - 1][j] + 1); + inner_matrix + }) + }); + + updated_matrix[string1.len()][string2.len()] +} + +/// Calculates the Levenshtein distance between two strings using an optimized dynamic programming approach. +/// +/// This edit distance is defined as 1 point per insertion, substitution, or deletion required to make the strings equal. +/// +/// # Arguments +/// +/// * `string1` - The first string. +/// * `string2` - The second string. +/// +/// # Returns +/// +/// The Levenshtein distance between the two input strings. +/// For a detailed explanation, check the example on [Wikipedia](https://en.wikipedia.org/wiki/Levenshtein_distance). +/// This function iterates over the bytes in the string, so it may not behave entirely as expected for non-ASCII strings. +/// +/// Note that this implementation utilizes an optimized dynamic programming approach, significantly reducing the space complexity from O(nm) to O(n), where n and m are the lengths of `string1` and `string2`. +/// +/// Additionally, it minimizes space usage by leveraging the shortest string horizontally and the longest string vertically in the computation matrix. +/// +/// # Complexity +/// +/// - Time complexity: O(nm), +/// - Space complexity: O(n), +/// +/// where n and m are lengths of `string1` and `string2`. +pub fn optimized_levenshtein_distance(string1: &str, string2: &str) -> usize { + if string1.is_empty() { + return string2.len(); + } + let l1 = string1.len(); + let mut prev_dist: Vec = (0..=l1).collect(); + + for (row, c2) in string2.chars().enumerate() { + // we'll keep a reference to matrix[i-1][j-1] (top-left cell) + let mut prev_substitution_cost = prev_dist[0]; + // diff with empty string, since `row` starts at 0, it's `row + 1` + prev_dist[0] = row + 1; + + for (col, c1) in string1.chars().enumerate() { + // "on the left" in the matrix (i.e. the value we just computed) + let deletion_cost = prev_dist[col] + 1; + // "on the top" in the matrix (means previous) + let insertion_cost = prev_dist[col + 1] + 1; + let substitution_cost = if c1 == c2 { + // last char is the same on both ends, so the min_distance is left unchanged from matrix[i-1][i+1] + prev_substitution_cost + } else { + // substitute the last character + prev_substitution_cost + 1 + }; + // save the old value at (i-1, j-1) + prev_substitution_cost = prev_dist[col + 1]; + prev_dist[col + 1] = _min3(deletion_cost, insertion_cost, substitution_cost); + } + } + prev_dist[l1] +} + +#[inline] +fn _min3(a: T, b: T, c: T) -> T { + min(a, min(b, c)) +} + +#[cfg(test)] +mod tests { + const LEVENSHTEIN_DISTANCE_TEST_CASES: &[(&str, &str, usize)] = &[ + ("", "", 0), + ("Hello, World!", "Hello, World!", 0), + ("", "Rust", 4), + ("horse", "ros", 3), + ("tan", "elephant", 6), + ("execute", "intention", 8), + ]; + + macro_rules! levenshtein_distance_tests { + ($function:ident) => { + mod $function { + use super::*; + + fn run_test_case(string1: &str, string2: &str, expected_distance: usize) { + assert_eq!(super::super::$function(string1, string2), expected_distance); + assert_eq!(super::super::$function(string2, string1), expected_distance); + assert_eq!(super::super::$function(string1, string1), 0); + assert_eq!(super::super::$function(string2, string2), 0); + } + + #[test] + fn test_levenshtein_distance() { + for &(string1, string2, expected_distance) in + LEVENSHTEIN_DISTANCE_TEST_CASES.iter() + { + run_test_case(string1, string2, expected_distance); + } + } + } + }; + } + + levenshtein_distance_tests!(naive_levenshtein_distance); + levenshtein_distance_tests!(optimized_levenshtein_distance); +} diff --git a/src/string/lipogram.rs b/src/string/lipogram.rs new file mode 100644 index 00000000000..9a486c2a62d --- /dev/null +++ b/src/string/lipogram.rs @@ -0,0 +1,112 @@ +use std::collections::HashSet; + +/// Represents possible errors that can occur when checking for lipograms. +#[derive(Debug, PartialEq, Eq)] +pub enum LipogramError { + /// Indicates that a non-alphabetic character was found in the input. + NonAlphabeticCharacter, + /// Indicates that a missing character is not in lowercase. + NonLowercaseMissingChar, +} + +/// Computes the set of missing alphabetic characters from the input string. +/// +/// # Arguments +/// +/// * `in_str` - A string slice that contains the input text. +/// +/// # Returns +/// +/// Returns a `HashSet` containing the lowercase alphabetic characters that are not present in `in_str`. +fn compute_missing(in_str: &str) -> HashSet { + let alphabet: HashSet = ('a'..='z').collect(); + + let letters_used: HashSet = in_str + .to_lowercase() + .chars() + .filter(|c| c.is_ascii_alphabetic()) + .collect(); + + alphabet.difference(&letters_used).cloned().collect() +} + +/// Checks if the provided string is a lipogram, meaning it is missing specific characters. +/// +/// # Arguments +/// +/// * `lipogram_str` - A string slice that contains the text to be checked for being a lipogram. +/// * `missing_chars` - A reference to a `HashSet` containing the expected missing characters. +/// +/// # Returns +/// +/// Returns `Ok(true)` if the string is a lipogram that matches the provided missing characters, +/// `Ok(false)` if it does not match, or a `LipogramError` if the input contains invalid characters. +pub fn is_lipogram( + lipogram_str: &str, + missing_chars: &HashSet, +) -> Result { + for &c in missing_chars { + if !c.is_lowercase() { + return Err(LipogramError::NonLowercaseMissingChar); + } + } + + for c in lipogram_str.chars() { + if !c.is_ascii_alphabetic() && !c.is_whitespace() { + return Err(LipogramError::NonAlphabeticCharacter); + } + } + + let missing = compute_missing(lipogram_str); + Ok(missing == *missing_chars) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_lipogram { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (input, missing_chars, expected) = $tc; + assert_eq!(is_lipogram(input, &missing_chars), expected); + } + )* + } + } + + test_lipogram! { + perfect_pangram: ( + "The quick brown fox jumps over the lazy dog", + HashSet::from([]), + Ok(true) + ), + lipogram_single_missing: ( + "The quick brown fox jumped over the lazy dog", + HashSet::from(['s']), + Ok(true) + ), + lipogram_multiple_missing: ( + "The brown fox jumped over the lazy dog", + HashSet::from(['q', 'i', 'c', 'k', 's']), + Ok(true) + ), + long_lipogram_single_missing: ( + "A jovial swain should not complain of any buxom fair who mocks his pain and thinks it gain to quiz his awkward air", + HashSet::from(['e']), + Ok(true) + ), + invalid_non_lowercase_chars: ( + "The quick brown fox jumped over the lazy dog", + HashSet::from(['X']), + Err(LipogramError::NonLowercaseMissingChar) + ), + invalid_non_alphabetic_input: ( + "The quick brown fox jumps over the lazy dog 123@!", + HashSet::from([]), + Err(LipogramError::NonAlphabeticCharacter) + ), + } +} diff --git a/src/string/manacher.rs b/src/string/manacher.rs index 98ea95aa90c..e45a3f15612 100644 --- a/src/string/manacher.rs +++ b/src/string/manacher.rs @@ -69,7 +69,7 @@ pub fn manacher(s: String) -> String { .map(|(idx, _)| idx) .unwrap(); let radius_of_max = (length_of_palindrome[center_of_max] - 1) / 2; - let answer = &chars[(center_of_max - radius_of_max)..(center_of_max + radius_of_max + 1)] + let answer = &chars[(center_of_max - radius_of_max)..=(center_of_max + radius_of_max)] .iter() .collect::(); answer.replace('#', "") diff --git a/src/string/mod.rs b/src/string/mod.rs index d298a2474ff..6ba37f39f29 100644 --- a/src/string/mod.rs +++ b/src/string/mod.rs @@ -1,20 +1,53 @@ mod aho_corasick; +mod anagram; +mod autocomplete_using_trie; +mod boyer_moore_search; mod burrows_wheeler_transform; +mod duval_algorithm; mod hamming_distance; +mod isogram; +mod isomorphism; +mod jaro_winkler_distance; mod knuth_morris_pratt; +mod levenshtein_distance; +mod lipogram; mod manacher; +mod palindrome; +mod pangram; mod rabin_karp; mod reverse; +mod run_length_encoding; +mod shortest_palindrome; +mod suffix_array; +mod suffix_array_manber_myers; +mod suffix_tree; mod z_algorithm; pub use self::aho_corasick::AhoCorasick; +pub use self::anagram::check_anagram; +pub use self::autocomplete_using_trie::Autocomplete; +pub use self::boyer_moore_search::boyer_moore_search; pub use self::burrows_wheeler_transform::{ burrows_wheeler_transform, inv_burrows_wheeler_transform, }; +pub use self::duval_algorithm::duval_algorithm; pub use self::hamming_distance::hamming_distance; +pub use self::isogram::is_isogram; +pub use self::isomorphism::is_isomorphic; +pub use self::jaro_winkler_distance::jaro_winkler_distance; pub use self::knuth_morris_pratt::knuth_morris_pratt; +pub use self::levenshtein_distance::{naive_levenshtein_distance, optimized_levenshtein_distance}; +pub use self::lipogram::is_lipogram; pub use self::manacher::manacher; +pub use self::palindrome::is_palindrome; +pub use self::pangram::is_pangram; +pub use self::pangram::PangramStatus; pub use self::rabin_karp::rabin_karp; pub use self::reverse::reverse; +pub use self::run_length_encoding::{run_length_decoding, run_length_encoding}; +pub use self::shortest_palindrome::shortest_palindrome; +pub use self::suffix_array::generate_suffix_array; +pub use self::suffix_array_manber_myers::generate_suffix_array_manber_myers; +pub use self::suffix_tree::{Node, SuffixTree}; pub use self::z_algorithm::match_pattern; pub use self::z_algorithm::z_array; diff --git a/src/string/palindrome.rs b/src/string/palindrome.rs new file mode 100644 index 00000000000..6ee2d0be7ca --- /dev/null +++ b/src/string/palindrome.rs @@ -0,0 +1,74 @@ +//! A module for checking if a given string is a palindrome. + +/// Checks if the given string is a palindrome. +/// +/// A palindrome is a sequence that reads the same backward as forward. +/// This function ignores non-alphanumeric characters and is case-insensitive. +/// +/// # Arguments +/// +/// * `s` - A string slice that represents the input to be checked. +/// +/// # Returns +/// +/// * `true` if the string is a palindrome; otherwise, `false`. +pub fn is_palindrome(s: &str) -> bool { + let mut chars = s + .chars() + .filter(|c| c.is_alphanumeric()) + .map(|c| c.to_ascii_lowercase()); + + while let (Some(c1), Some(c2)) = (chars.next(), chars.next_back()) { + if c1 != c2 { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! palindrome_tests { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $inputs; + assert_eq!(is_palindrome(input), expected); + } + )* + } + } + + palindrome_tests! { + odd_palindrome: ("madam", true), + even_palindrome: ("deified", true), + single_character_palindrome: ("x", true), + single_word_palindrome: ("eye", true), + case_insensitive_palindrome: ("RaceCar", true), + mixed_case_and_punctuation_palindrome: ("A man, a plan, a canal, Panama!", true), + mixed_case_and_space_palindrome: ("No 'x' in Nixon", true), + empty_string: ("", true), + pompeii_palindrome: ("Roma-Olima-Milo-Amor", true), + napoleon_palindrome: ("Able was I ere I saw Elba", true), + john_taylor_palindrome: ("Lewd did I live, & evil I did dwel", true), + well_know_english_palindrome: ("Never odd or even", true), + palindromic_phrase: ("Rats live on no evil star", true), + names_palindrome: ("Hannah", true), + prime_minister_of_cambodia: ("Lon Nol", true), + japanese_novelist_and_manga_writer: ("Nisio Isin", true), + actor: ("Robert Trebor", true), + rock_vocalist: ("Ola Salo", true), + pokemon_species: ("Girafarig", true), + lychrel_num_56: ("121", true), + universal_palindrome_date: ("02/02/2020", true), + french_palindrome: ("une Slave valse nu", true), + finnish_palindrome: ("saippuakivikauppias", true), + non_palindrome_simple: ("hello", false), + non_palindrome_with_punctuation: ("hello!", false), + non_palindrome_mixed_case: ("Hello, World", false), + } +} diff --git a/src/string/pangram.rs b/src/string/pangram.rs new file mode 100644 index 00000000000..19ccad4a688 --- /dev/null +++ b/src/string/pangram.rs @@ -0,0 +1,92 @@ +//! This module provides functionality to check if a given string is a pangram. +//! +//! A pangram is a sentence that contains every letter of the alphabet at least once. +//! This module can distinguish between a non-pangram, a regular pangram, and a +//! perfect pangram, where each letter appears exactly once. + +use std::collections::HashSet; + +/// Represents the status of a string in relation to the pangram classification. +#[derive(PartialEq, Debug)] +pub enum PangramStatus { + NotPangram, + Pangram, + PerfectPangram, +} + +fn compute_letter_counts(pangram_str: &str) -> std::collections::HashMap { + let mut letter_counts = std::collections::HashMap::new(); + + for ch in pangram_str + .to_lowercase() + .chars() + .filter(|c| c.is_ascii_alphabetic()) + { + *letter_counts.entry(ch).or_insert(0) += 1; + } + + letter_counts +} + +/// Determines if the input string is a pangram, and classifies it as either a regular or perfect pangram. +/// +/// # Arguments +/// +/// * `pangram_str` - A reference to the string slice to be checked for pangram status. +/// +/// # Returns +/// +/// A `PangramStatus` enum indicating whether the string is a pangram, and if so, whether it is a perfect pangram. +pub fn is_pangram(pangram_str: &str) -> PangramStatus { + let letter_counts = compute_letter_counts(pangram_str); + + let alphabet: HashSet = ('a'..='z').collect(); + let used_letters: HashSet<_> = letter_counts.keys().cloned().collect(); + + if used_letters != alphabet { + return PangramStatus::NotPangram; + } + + if letter_counts.values().all(|&count| count == 1) { + PangramStatus::PerfectPangram + } else { + PangramStatus::Pangram + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! pangram_tests { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $tc; + assert_eq!(is_pangram(input), expected); + } + )* + }; + } + + pangram_tests! { + test_not_pangram_simple: ("This is not a pangram", PangramStatus::NotPangram), + test_not_pangram_day: ("today is a good day", PangramStatus::NotPangram), + test_not_pangram_almost: ("this is almost a pangram but it does not have bcfghjkqwxy and the last letter", PangramStatus::NotPangram), + test_pangram_standard: ("The quick brown fox jumps over the lazy dog", PangramStatus::Pangram), + test_pangram_boxer: ("A mad boxer shot a quick, gloved jab to the jaw of his dizzy opponent", PangramStatus::Pangram), + test_pangram_discotheques: ("Amazingly few discotheques provide jukeboxes", PangramStatus::Pangram), + test_pangram_zebras: ("How vexingly quick daft zebras jump", PangramStatus::Pangram), + test_perfect_pangram_jock: ("Mr. Jock, TV quiz PhD, bags few lynx", PangramStatus::PerfectPangram), + test_empty_string: ("", PangramStatus::NotPangram), + test_repeated_letter: ("aaaaa", PangramStatus::NotPangram), + test_non_alphabetic: ("12345!@#$%", PangramStatus::NotPangram), + test_mixed_case_pangram: ("ThE QuiCk BroWn FoX JumPs OveR tHe LaZy DoG", PangramStatus::Pangram), + test_perfect_pangram_with_symbols: ("Mr. Jock, TV quiz PhD, bags few lynx!", PangramStatus::PerfectPangram), + test_long_non_pangram: (&"a".repeat(1000), PangramStatus::NotPangram), + test_near_pangram_missing_one_letter: ("The quick brown fox jumps over the lazy do", PangramStatus::NotPangram), + test_near_pangram_missing_two_letters: ("The quick brwn f jumps ver the lazy dg", PangramStatus::NotPangram), + test_near_pangram_with_special_characters: ("Th3 qu!ck brown f0x jumps 0v3r th3 l@zy d0g.", PangramStatus::NotPangram), + } +} diff --git a/src/string/rabin_karp.rs b/src/string/rabin_karp.rs index 5fbcbc884b3..9901849990a 100644 --- a/src/string/rabin_karp.rs +++ b/src/string/rabin_karp.rs @@ -1,66 +1,84 @@ -const MODULUS: u16 = 101; -const BASE: u16 = 256; - -pub fn rabin_karp(target: String, pattern: String) -> Vec { - // Quick exit - if target.is_empty() || pattern.is_empty() || pattern.len() > target.len() { +//! This module implements the Rabin-Karp string searching algorithm. +//! It uses a rolling hash technique to find all occurrences of a pattern +//! within a target string efficiently. + +const MOD: usize = 101; +const RADIX: usize = 256; + +/// Finds all starting indices where the `pattern` appears in the `text`. +/// +/// # Arguments +/// * `text` - The string where the search is performed. +/// * `pattern` - The substring pattern to search for. +/// +/// # Returns +/// A vector of starting indices where the pattern is found. +pub fn rabin_karp(text: &str, pattern: &str) -> Vec { + if text.is_empty() || pattern.is_empty() || pattern.len() > text.len() { return vec![]; } - let pattern_hash = hash(pattern.as_str()); + let pat_hash = compute_hash(pattern); + let mut radix_pow = 1; - // Pre-calculate BASE^(n-1) - let mut pow_rem: u16 = 1; + // Compute RADIX^(n-1) % MOD for _ in 0..pattern.len() - 1 { - pow_rem *= BASE; - pow_rem %= MODULUS; + radix_pow = (radix_pow * RADIX) % MOD; } let mut rolling_hash = 0; - let mut ret = vec![]; - for i in 0..=target.len() - pattern.len() { + let mut result = vec![]; + for i in 0..=text.len() - pattern.len() { rolling_hash = if i == 0 { - hash(&target[0..pattern.len()]) + compute_hash(&text[0..pattern.len()]) } else { - recalculate_hash( - target.as_str(), - i - 1, - i + pattern.len() - 1, - rolling_hash, - pow_rem, - ) + update_hash(text, i - 1, i + pattern.len() - 1, rolling_hash, radix_pow) }; - if rolling_hash == pattern_hash && pattern[..] == target[i..i + pattern.len()] { - ret.push(i); + if rolling_hash == pat_hash && pattern[..] == text[i..i + pattern.len()] { + result.push(i); } } - ret + result } -// hash(s) is defined as BASE^(n-1) * s_0 + BASE^(n-2) * s_1 + ... + BASE^0 * s_(n-1) -fn hash(s: &str) -> u16 { - let mut res: u16 = 0; - for &c in s.as_bytes().iter() { - res = (res * BASE % MODULUS + c as u16) % MODULUS; - } - res +/// Calculates the hash of a string using the Rabin-Karp formula. +/// +/// # Arguments +/// * `s` - The string to calculate the hash for. +/// +/// # Returns +/// The hash value of the string modulo `MOD`. +fn compute_hash(s: &str) -> usize { + let mut hash_val = 0; + for &byte in s.as_bytes().iter() { + hash_val = (hash_val * RADIX + byte as usize) % MOD; + } + hash_val } -// new_hash = (old_hash - BASE^(n-1) * s_(i-n)) * BASE + s_i -fn recalculate_hash( +/// Updates the rolling hash when shifting the search window. +/// +/// # Arguments +/// * `s` - The full text where the search is performed. +/// * `old_idx` - The index of the character that is leaving the window. +/// * `new_idx` - The index of the new character entering the window. +/// * `old_hash` - The hash of the previous substring. +/// * `radix_pow` - The precomputed value of RADIX^(n-1) % MOD. +/// +/// # Returns +/// The updated hash for the new substring. +fn update_hash( s: &str, - old_index: usize, - new_index: usize, - old_hash: u16, - pow_rem: u16, -) -> u16 { + old_idx: usize, + new_idx: usize, + old_hash: usize, + radix_pow: usize, +) -> usize { let mut new_hash = old_hash; - let (old_ch, new_ch) = ( - s.as_bytes()[old_index] as u16, - s.as_bytes()[new_index] as u16, - ); - new_hash = (new_hash + MODULUS - pow_rem * old_ch % MODULUS) % MODULUS; - new_hash = (new_hash * BASE + new_ch) % MODULUS; + let old_char = s.as_bytes()[old_idx] as usize; + let new_char = s.as_bytes()[new_idx] as usize; + new_hash = (new_hash + MOD - (old_char * radix_pow % MOD)) % MOD; + new_hash = (new_hash * RADIX + new_char) % MOD; new_hash } @@ -68,76 +86,38 @@ fn recalculate_hash( mod tests { use super::*; - #[test] - fn hi_hash() { - let hash_result = hash("hi"); - assert_eq!(hash_result, 65); - } - - #[test] - fn abr_hash() { - let hash_result = hash("abr"); - assert_eq!(hash_result, 4); - } - - #[test] - fn bra_hash() { - let hash_result = hash("bra"); - assert_eq!(hash_result, 30); - } - - // Attribution to @pgimalac for his tests from Knuth-Morris-Pratt - #[test] - fn each_letter_matches() { - let index = rabin_karp("aaa".to_string(), "a".to_string()); - assert_eq!(index, vec![0, 1, 2]); - } - - #[test] - fn a_few_separate_matches() { - let index = rabin_karp("abababa".to_string(), "ab".to_string()); - assert_eq!(index, vec![0, 2, 4]); - } - - #[test] - fn one_match() { - let index = rabin_karp("ABC ABCDAB ABCDABCDABDE".to_string(), "ABCDABD".to_string()); - assert_eq!(index, vec![15]); - } - - #[test] - fn lots_of_matches() { - let index = rabin_karp("aaabaabaaaaa".to_string(), "aa".to_string()); - assert_eq!(index, vec![0, 1, 4, 7, 8, 9, 10]); - } - - #[test] - fn lots_of_intricate_matches() { - let index = rabin_karp("ababababa".to_string(), "aba".to_string()); - assert_eq!(index, vec![0, 2, 4, 6]); - } - - #[test] - fn not_found0() { - let index = rabin_karp("abcde".to_string(), "f".to_string()); - assert_eq!(index, vec![]); - } - - #[test] - fn not_found1() { - let index = rabin_karp("abcde".to_string(), "ac".to_string()); - assert_eq!(index, vec![]); - } - - #[test] - fn not_found2() { - let index = rabin_karp("ababab".to_string(), "bababa".to_string()); - assert_eq!(index, vec![]); + macro_rules! test_cases { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (text, pattern, expected) = $inputs; + assert_eq!(rabin_karp(text, pattern), expected); + } + )* + }; } - #[test] - fn empty_string() { - let index = rabin_karp("".to_string(), "abcdef".to_string()); - assert_eq!(index, vec![]); + test_cases! { + single_match_at_start: ("hello world", "hello", vec![0]), + single_match_at_end: ("hello world", "world", vec![6]), + single_match_in_middle: ("abc def ghi", "def", vec![4]), + multiple_matches: ("ababcabc", "abc", vec![2, 5]), + overlapping_matches: ("aaaaa", "aaa", vec![0, 1, 2]), + no_match: ("abcdefg", "xyz", vec![]), + pattern_is_entire_string: ("abc", "abc", vec![0]), + target_is_multiple_patterns: ("abcabcabc", "abc", vec![0, 3, 6]), + empty_text: ("", "abc", vec![]), + empty_pattern: ("abc", "", vec![]), + empty_text_and_pattern: ("", "", vec![]), + pattern_larger_than_text: ("abc", "abcd", vec![]), + large_text_small_pattern: (&("a".repeat(1000) + "b"), "b", vec![1000]), + single_char_match: ("a", "a", vec![0]), + single_char_no_match: ("a", "b", vec![]), + large_pattern_no_match: ("abc", "defghi", vec![]), + repeating_chars: ("aaaaaa", "aa", vec![0, 1, 2, 3, 4]), + special_characters: ("abc$def@ghi", "$def@", vec![3]), + numeric_and_alphabetic_mix: ("abc123abc456", "123abc", vec![3]), + case_sensitivity: ("AbcAbc", "abc", vec![]), } } diff --git a/src/string/reverse.rs b/src/string/reverse.rs index a8e72200787..bf17745a147 100644 --- a/src/string/reverse.rs +++ b/src/string/reverse.rs @@ -1,3 +1,12 @@ +/// Reverses the given string. +/// +/// # Arguments +/// +/// * `text` - A string slice that holds the string to be reversed. +/// +/// # Returns +/// +/// * A new `String` that is the reverse of the input string. pub fn reverse(text: &str) -> String { text.chars().rev().collect() } @@ -6,18 +15,26 @@ pub fn reverse(text: &str) -> String { mod tests { use super::*; - #[test] - fn test_simple() { - assert_eq!(reverse("racecar"), "racecar"); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $test_case; + assert_eq!(reverse(input), expected); + } + )* + }; } - #[test] - fn test_assymetric() { - assert_eq!(reverse("abcdef"), "fedcba") - } - - #[test] - fn test_sentence() { - assert_eq!(reverse("step on no pets"), "step on no pets"); + test_cases! { + test_simple_palindrome: ("racecar", "racecar"), + test_non_palindrome: ("abcdef", "fedcba"), + test_sentence_with_spaces: ("step on no pets", "step on no pets"), + test_empty_string: ("", ""), + test_single_character: ("a", "a"), + test_leading_trailing_spaces: (" hello ", " olleh "), + test_unicode_characters: ("你好", "好你"), + test_mixed_content: ("a1b2c3!", "!3c2b1a"), } } diff --git a/src/string/run_length_encoding.rs b/src/string/run_length_encoding.rs new file mode 100644 index 00000000000..1952df4c230 --- /dev/null +++ b/src/string/run_length_encoding.rs @@ -0,0 +1,78 @@ +pub fn run_length_encoding(target: &str) -> String { + if target.trim().is_empty() { + return "".to_string(); + } + let mut count: i32 = 0; + let mut base_character: String = "".to_string(); + let mut encoded_target = String::new(); + + for c in target.chars() { + if base_character == *"" { + base_character = c.to_string(); + } + if c.to_string() == base_character { + count += 1; + } else { + encoded_target.push_str(&count.to_string()); + count = 1; + encoded_target.push_str(&base_character); + base_character = c.to_string(); + } + } + encoded_target.push_str(&count.to_string()); + encoded_target.push_str(&base_character); + + encoded_target +} + +pub fn run_length_decoding(target: &str) -> String { + if target.trim().is_empty() { + return "".to_string(); + } + let mut character_count = String::new(); + let mut decoded_target = String::new(); + + for c in target.chars() { + character_count.push(c); + let is_numeric: bool = character_count.parse::().is_ok(); + + if !is_numeric { + let pop_char: char = character_count.pop().unwrap(); + decoded_target.push_str( + &pop_char + .to_string() + .repeat(character_count.parse().unwrap()), + ); + character_count = "".to_string(); + } + } + + decoded_target +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_run_length { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (raw_str, encoded) = $test_case; + assert_eq!(run_length_encoding(raw_str), encoded); + assert_eq!(run_length_decoding(encoded), raw_str); + } + )* + }; + } + + test_run_length! { + empty_input: ("", ""), + repeated_char: ("aaaaaaaaaa", "10a"), + no_repeated: ("abcdefghijk", "1a1b1c1d1e1f1g1h1i1j1k"), + regular_input: ("aaaaabbbcccccdddddddddd", "5a3b5c10d"), + two_blocks_with_same_char: ("aaabbaaaa", "3a2b4a"), + long_input: ("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbcccccdddddddddd", "200a3b5c10d"), + } +} diff --git a/src/string/shortest_palindrome.rs b/src/string/shortest_palindrome.rs new file mode 100644 index 00000000000..e143590f3fc --- /dev/null +++ b/src/string/shortest_palindrome.rs @@ -0,0 +1,119 @@ +//! This module provides functions for finding the shortest palindrome +//! that can be formed by adding characters to the left of a given string. +//! References +//! +//! - [KMP](https://www.scaler.com/topics/data-structures/kmp-algorithm/) +//! - [Prefix Functions and KPM](https://oi-wiki.org/string/kmp/) + +/// Finds the shortest palindrome that can be formed by adding characters +/// to the left of the given string `s`. +/// +/// # Arguments +/// +/// * `s` - A string slice that holds the input string. +/// +/// # Returns +/// +/// Returns a new string that is the shortest palindrome, formed by adding +/// the necessary characters to the beginning of `s`. +pub fn shortest_palindrome(s: &str) -> String { + if s.is_empty() { + return "".to_string(); + } + + let original_chars: Vec = s.chars().collect(); + let suffix_table = compute_suffix(&original_chars); + + let mut reversed_chars: Vec = s.chars().rev().collect(); + // The prefix of the original string matches the suffix of the reversed string. + let prefix_match = compute_prefix_match(&original_chars, &reversed_chars, &suffix_table); + + reversed_chars.append(&mut original_chars[prefix_match[original_chars.len() - 1]..].to_vec()); + reversed_chars.iter().collect() +} + +/// Computes the suffix table used for the KMP (Knuth-Morris-Pratt) string +/// matching algorithm. +/// +/// # Arguments +/// +/// * `chars` - A slice of characters for which the suffix table is computed. +/// +/// # Returns +/// +/// Returns a vector of `usize` representing the suffix table. Each element +/// at index `i` indicates the longest proper suffix which is also a proper +/// prefix of the substring `chars[0..=i]`. +pub fn compute_suffix(chars: &[char]) -> Vec { + let mut suffix = vec![0; chars.len()]; + for i in 1..chars.len() { + let mut j = suffix[i - 1]; + while j > 0 && chars[j] != chars[i] { + j = suffix[j - 1]; + } + suffix[i] = j + (chars[j] == chars[i]) as usize; + } + suffix +} + +/// Computes the prefix matches of the original string against its reversed +/// version using the suffix table. +/// +/// # Arguments +/// +/// * `original` - A slice of characters representing the original string. +/// * `reversed` - A slice of characters representing the reversed string. +/// * `suffix` - A slice containing the suffix table computed for the original string. +/// +/// # Returns +/// +/// Returns a vector of `usize` where each element at index `i` indicates the +/// length of the longest prefix of `original` that matches a suffix of +/// `reversed[0..=i]`. +pub fn compute_prefix_match(original: &[char], reversed: &[char], suffix: &[usize]) -> Vec { + let mut match_table = vec![0; original.len()]; + match_table[0] = usize::from(original[0] == reversed[0]); + for i in 1..original.len() { + let mut j = match_table[i - 1]; + while j > 0 && reversed[i] != original[j] { + j = suffix[j - 1]; + } + match_table[i] = j + usize::from(reversed[i] == original[j]); + } + match_table +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::string::is_palindrome; + + macro_rules! test_shortest_palindrome { + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $inputs; + assert!(is_palindrome(expected)); + assert_eq!(shortest_palindrome(input), expected); + assert_eq!(shortest_palindrome(expected), expected); + } + )* + } + } + + test_shortest_palindrome! { + empty: ("", ""), + extend_left_1: ("aacecaaa", "aaacecaaa"), + extend_left_2: ("abcd", "dcbabcd"), + unicode_1: ("അ", "അ"), + unicode_2: ("a牛", "牛a牛"), + single_char: ("x", "x"), + already_palindrome: ("racecar", "racecar"), + extend_left_3: ("abcde", "edcbabcde"), + extend_left_4: ("abca", "acbabca"), + long_string: ("abcdefg", "gfedcbabcdefg"), + repetitive: ("aaaaa", "aaaaa"), + complex: ("abacdfgdcaba", "abacdgfdcabacdfgdcaba"), + } +} diff --git a/src/string/suffix_array.rs b/src/string/suffix_array.rs new file mode 100644 index 00000000000..a89575fc8e8 --- /dev/null +++ b/src/string/suffix_array.rs @@ -0,0 +1,95 @@ +// In computer science, a suffix array is a sorted array of all suffixes of a string. +// It is a data structure used in, among others, full-text indices, data-compression algorithms, +// and the field of bibliometrics. Source: https://en.wikipedia.org/wiki/Suffix_array + +use std::cmp::Ordering; + +#[derive(Clone)] +struct Suffix { + index: usize, + rank: (i32, i32), +} + +impl Suffix { + fn cmp(&self, b: &Self) -> Ordering { + let a = self; + let ((a1, a2), (b1, b2)) = (a.rank, b.rank); + match a1.cmp(&b1) { + Ordering::Equal => { + if a2 < b2 { + Ordering::Less + } else { + Ordering::Greater + } + } + o => o, + } + } +} + +pub fn generate_suffix_array(txt: &str) -> Vec { + let n = txt.len(); + let mut suffixes: Vec = vec![ + Suffix { + index: 0, + rank: (-1, -1) + }; + n + ]; + for (i, suf) in suffixes.iter_mut().enumerate() { + suf.index = i; + suf.rank.0 = (txt.chars().nth(i).expect("this should exist") as u32 - 'a' as u32) as i32; + suf.rank.1 = if (i + 1) < n { + (txt.chars().nth(i + 1).expect("this should exist") as u32 - 'a' as u32) as i32 + } else { + -1 + } + } + suffixes.sort_by(|a, b| a.cmp(b)); + let mut ind = vec![0; n]; + let mut k = 4; + while k < 2 * n { + let mut rank = 0; + let mut prev_rank = suffixes[0].rank.0; + suffixes[0].rank.0 = rank; + ind[suffixes[0].index] = 0; + + for i in 1..n { + if suffixes[i].rank.0 == prev_rank && suffixes[i].rank.1 == suffixes[i - 1].rank.1 { + prev_rank = suffixes[i].rank.0; + suffixes[i].rank.0 = rank; + } else { + prev_rank = suffixes[i].rank.0; + rank += 1; + suffixes[i].rank.0 = rank; + } + ind[suffixes[i].index] = i; + } + for i in 0..n { + let next_index = suffixes[i].index + (k / 2); + suffixes[i].rank.1 = if next_index < n { + suffixes[ind[next_index]].rank.0 + } else { + -1 + } + } + suffixes.sort_by(|a, b| a.cmp(b)); + k *= 2; + } + let mut suffix_arr = Vec::new(); + for suf in suffixes { + suffix_arr.push(suf.index); + } + suffix_arr +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_suffix_array() { + let a = generate_suffix_array("banana"); + assert_eq!(a, vec![5, 3, 1, 0, 4, 2]); + } +} diff --git a/src/string/suffix_array_manber_myers.rs b/src/string/suffix_array_manber_myers.rs new file mode 100644 index 00000000000..4c4d58c5901 --- /dev/null +++ b/src/string/suffix_array_manber_myers.rs @@ -0,0 +1,107 @@ +pub fn generate_suffix_array_manber_myers(input: &str) -> Vec { + if input.is_empty() { + return Vec::new(); + } + let n = input.len(); + let mut suffixes: Vec<(usize, &str)> = Vec::with_capacity(n); + + for (i, _suffix) in input.char_indices() { + suffixes.push((i, &input[i..])); + } + + suffixes.sort_by_key(|&(_, s)| s); + + let mut suffix_array: Vec = vec![0; n]; + let mut rank = vec![0; n]; + + let mut cur_rank = 0; + let mut prev_suffix = &suffixes[0].1; + + for (i, suffix) in suffixes.iter().enumerate() { + if &suffix.1 != prev_suffix { + cur_rank += 1; + prev_suffix = &suffix.1; + } + rank[suffix.0] = cur_rank; + suffix_array[i] = suffix.0; + } + + let mut k = 1; + let mut new_rank: Vec = vec![0; n]; + + while k < n { + suffix_array.sort_by_key(|&x| (rank[x], rank[(x + k) % n])); + + let mut cur_rank = 0; + let mut prev = suffix_array[0]; + new_rank[prev] = cur_rank; + + for &suffix in suffix_array.iter().skip(1) { + let next = suffix; + if (rank[prev], rank[(prev + k) % n]) != (rank[next], rank[(next + k) % n]) { + cur_rank += 1; + } + new_rank[next] = cur_rank; + prev = next; + } + + std::mem::swap(&mut rank, &mut new_rank); + + k <<= 1; + } + + suffix_array +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_suffix_array() { + let input = "banana"; + let expected_result = vec![5, 3, 1, 0, 4, 2]; + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } + + #[test] + fn test_empty_string() { + let input = ""; + let expected_result: Vec = Vec::new(); + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } + + #[test] + fn test_single_character() { + let input = "a"; + let expected_result = vec![0]; + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } + #[test] + fn test_repeating_characters() { + let input = "zzzzzz"; + let expected_result = vec![5, 4, 3, 2, 1, 0]; + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } + + #[test] + fn test_long_string() { + let input = "abcdefghijklmnopqrstuvwxyz"; + let expected_result: Vec = (0..26).collect(); + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } + + #[test] + fn test_mix_of_characters() { + let input = "abracadabra!"; + let expected_result = vec![11, 10, 7, 0, 3, 5, 8, 1, 4, 6, 9, 2]; + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } + + #[test] + fn test_whitespace_characters() { + let input = " hello world "; + let expected_result = vec![12, 0, 6, 11, 2, 1, 10, 3, 4, 5, 8, 9, 7]; + assert_eq!(generate_suffix_array_manber_myers(input), expected_result); + } +} diff --git a/src/string/suffix_tree.rs b/src/string/suffix_tree.rs new file mode 100644 index 00000000000..24cedf6b197 --- /dev/null +++ b/src/string/suffix_tree.rs @@ -0,0 +1,152 @@ +// In computer science, a suffix tree (also called PAT tree or, in an earlier form, position tree) +// is a compressed trie containing all the suffixes of the given text as their keys and positions +// in the text as their values. Suffix trees allow particularly fast implementations of many +// important string operations. Source: https://en.wikipedia.org/wiki/Suffix_tree + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Node { + pub sub: String, // substring of input string + pub ch: Vec, // vector of child nodes +} + +impl Node { + fn new(sub: String, children: Vec) -> Self { + Node { + sub, + ch: children.to_vec(), + } + } + pub fn empty() -> Self { + Node { + sub: "".to_string(), + ch: vec![], + } + } +} + +pub struct SuffixTree { + pub nodes: Vec, +} + +impl SuffixTree { + pub fn new(s: &str) -> Self { + let mut suf_tree = SuffixTree { + nodes: vec![Node::empty()], + }; + for i in 0..s.len() { + let (_, substr) = s.split_at(i); + suf_tree.add_suffix(substr); + } + suf_tree + } + fn add_suffix(&mut self, suf: &str) { + let mut n = 0; + let mut i = 0; + while i < suf.len() { + let b = suf.chars().nth(i); + let mut x2 = 0; + let mut n2: usize; + loop { + let children = &self.nodes[n].ch; + if children.len() == x2 { + n2 = self.nodes.len(); + self.nodes.push(Node::new( + { + let (_, sub) = suf.split_at(i); + sub.to_string() + }, + vec![], + )); + self.nodes[n].ch.push(n2); + return; + } + n2 = children[x2]; + if self.nodes[n2].sub.chars().next() == b { + break; + } + x2 += 1; + } + let sub2 = self.nodes[n2].sub.clone(); + let mut j = 0; + while j < sub2.len() { + if suf.chars().nth(i + j) != sub2.chars().nth(j) { + let n3 = n2; + n2 = self.nodes.len(); + self.nodes.push(Node::new( + { + let (sub, _) = sub2.split_at(j); + sub.to_string() + }, + vec![n3], + )); + let (_, temp_sub) = sub2.split_at(j); + self.nodes[n3].sub = temp_sub.to_string(); + self.nodes[n].ch[x2] = n2; + break; + } + j += 1; + } + i += j; + n = n2; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_suffix_tree() { + let suf_tree = SuffixTree::new("banana$"); + assert_eq!( + suf_tree.nodes, + vec![ + Node { + sub: "".to_string(), + ch: vec![1, 8, 6, 10] + }, + Node { + sub: "banana$".to_string(), + ch: vec![] + }, + Node { + sub: "na$".to_string(), + ch: vec![] + }, + Node { + sub: "na$".to_string(), + ch: vec![] + }, + Node { + sub: "na".to_string(), + ch: vec![2, 5] + }, + Node { + sub: "$".to_string(), + ch: vec![] + }, + Node { + sub: "na".to_string(), + ch: vec![3, 7] + }, + Node { + sub: "$".to_string(), + ch: vec![] + }, + Node { + sub: "a".to_string(), + ch: vec![4, 9] + }, + Node { + sub: "$".to_string(), + ch: vec![] + }, + Node { + sub: "$".to_string(), + ch: vec![] + } + ] + ); + } +} diff --git a/src/string/z_algorithm.rs b/src/string/z_algorithm.rs index b6c3dec839e..a2825e02ddc 100644 --- a/src/string/z_algorithm.rs +++ b/src/string/z_algorithm.rs @@ -1,3 +1,83 @@ +//! This module provides functionalities to match patterns in strings +//! and compute the Z-array for a given input string. + +/// Calculates the Z-value for a given substring of the input string +/// based on a specified pattern. +/// +/// # Parameters +/// - `input_string`: A slice of elements that represents the input string. +/// - `pattern`: A slice of elements representing the pattern to match. +/// - `start_index`: The index in the input string to start checking for matches. +/// - `z_value`: The initial Z-value to be computed. +/// +/// # Returns +/// The computed Z-value indicating the length of the matching prefix. +fn calculate_z_value( + input_string: &[T], + pattern: &[T], + start_index: usize, + mut z_value: usize, +) -> usize { + let size = input_string.len(); + let pattern_size = pattern.len(); + + while (start_index + z_value) < size && z_value < pattern_size { + if input_string[start_index + z_value] != pattern[z_value] { + break; + } + z_value += 1; + } + z_value +} + +/// Initializes the Z-array value based on a previous match and updates +/// it to optimize further calculations. +/// +/// # Parameters +/// - `z_array`: A mutable slice of the Z-array to be updated. +/// - `i`: The current index in the input string. +/// - `match_end`: The index of the last character matched in the pattern. +/// - `last_match`: The index of the last match found. +/// +/// # Returns +/// The initialized Z-array value for the current index. +fn initialize_z_array_from_previous_match( + z_array: &[usize], + i: usize, + match_end: usize, + last_match: usize, +) -> usize { + std::cmp::min(z_array[i - last_match], match_end - i + 1) +} + +/// Finds the starting indices of all full matches of the pattern +/// in the Z-array. +/// +/// # Parameters +/// - `z_array`: A slice of the Z-array containing computed Z-values. +/// - `pattern_size`: The length of the pattern to find in the Z-array. +/// +/// # Returns +/// A vector containing the starting indices of full matches. +fn find_full_matches(z_array: &[usize], pattern_size: usize) -> Vec { + z_array + .iter() + .enumerate() + .filter_map(|(idx, &z_value)| (z_value == pattern_size).then_some(idx)) + .collect() +} + +/// Matches the occurrences of a pattern in an input string starting +/// from a specified index. +/// +/// # Parameters +/// - `input_string`: A slice of elements to search within. +/// - `pattern`: A slice of elements that represents the pattern to match. +/// - `start_index`: The index in the input string to start the search. +/// - `only_full_matches`: If true, only full matches of the pattern will be returned. +/// +/// # Returns +/// A vector containing the starting indices of the matches. fn match_with_z_array( input_string: &[T], pattern: &[T], @@ -8,41 +88,53 @@ fn match_with_z_array( let pattern_size = pattern.len(); let mut last_match: usize = 0; let mut match_end: usize = 0; - let mut array = vec![0usize; size]; + let mut z_array = vec![0usize; size]; + for i in start_index..size { - // getting plain z array of a string requires matching from index - // 1 instead of 0 (which gives a trivial result instead) if i <= match_end { - array[i] = std::cmp::min(array[i - last_match], match_end - i + 1); - } - while (i + array[i]) < size && array[i] < pattern_size { - if input_string[i + array[i]] != pattern[array[i]] { - break; - } - array[i] += 1; + z_array[i] = initialize_z_array_from_previous_match(&z_array, i, match_end, last_match); } - if (i + array[i]) > (match_end + 1) { - match_end = i + array[i] - 1; + + z_array[i] = calculate_z_value(input_string, pattern, i, z_array[i]); + + if i + z_array[i] > match_end + 1 { + match_end = i + z_array[i] - 1; last_match = i; } } + if !only_full_matches { - array + z_array } else { - let mut answer: Vec = vec![]; - for (idx, number) in array.iter().enumerate() { - if *number == pattern_size { - answer.push(idx); - } - } - answer + find_full_matches(&z_array, pattern_size) } } +/// Constructs the Z-array for the given input string. +/// +/// The Z-array is an array where the i-th element is the length of the longest +/// substring starting from s[i] that is also a prefix of s. +/// +/// # Parameters +/// - `input`: A slice of the input string for which the Z-array is to be constructed. +/// +/// # Returns +/// A vector representing the Z-array of the input string. pub fn z_array(input: &[T]) -> Vec { match_with_z_array(input, input, 1, false) } +/// Matches the occurrences of a given pattern in an input string. +/// +/// This function acts as a wrapper around `match_with_z_array` to provide a simpler +/// interface for pattern matching, returning only full matches. +/// +/// # Parameters +/// - `input`: A slice of the input string where the pattern will be searched. +/// - `pattern`: A slice of the pattern to search for in the input string. +/// +/// # Returns +/// A vector of indices where the pattern matches the input string. pub fn match_pattern(input: &[T], pattern: &[T]) -> Vec { match_with_z_array(input, pattern, 0, true) } @@ -51,56 +143,67 @@ pub fn match_pattern(input: &[T], pattern: &[T]) -> Vec { mod tests { use super::*; - #[test] - fn test_z_array() { - let string = "aabaabab"; - let array = z_array(string.as_bytes()); - assert_eq!(array, vec![0, 1, 0, 4, 1, 0, 1, 0]); + macro_rules! test_match_pattern { + ($($name:ident: ($input:expr, $pattern:expr, $expected:expr),)*) => { + $( + #[test] + fn $name() { + let (input, pattern, expected) = ($input, $pattern, $expected); + assert_eq!(match_pattern(input.as_bytes(), pattern.as_bytes()), expected); + } + )* + }; } - #[test] - fn pattern_in_text() { - let text: &str = concat!( - "lorem ipsum dolor sit amet, consectetur ", - "adipiscing elit, sed do eiusmod tempor ", - "incididunt ut labore et dolore magna aliqua" - ); - let pattern1 = "rem"; - let pattern2 = "em"; - let pattern3 = ";alksdjfoiwer"; - let pattern4 = "m"; - - assert_eq!(match_pattern(text.as_bytes(), pattern1.as_bytes()), vec![2]); - assert_eq!( - match_pattern(text.as_bytes(), pattern2.as_bytes()), - vec![3, 73] - ); - assert_eq!(match_pattern(text.as_bytes(), pattern3.as_bytes()), vec![]); - assert_eq!( - match_pattern(text.as_bytes(), pattern4.as_bytes()), - vec![4, 10, 23, 68, 74, 110] - ); + macro_rules! test_z_array_cases { + ($($name:ident: ($input:expr, $expected:expr),)*) => { + $( + #[test] + fn $name() { + let (input, expected) = ($input, $expected); + assert_eq!(z_array(input.as_bytes()), expected); + } + )* + }; + } - let text2 = "aaaaaaaa"; - let pattern5 = "aaa"; - assert_eq!( - match_pattern(text2.as_bytes(), pattern5.as_bytes()), + test_match_pattern! { + simple_match: ("abcabcabc", "abc", vec![0, 3, 6]), + no_match: ("abcdef", "xyz", vec![]), + single_char_match: ("aaaaaa", "a", vec![0, 1, 2, 3, 4, 5]), + overlapping_match: ("abababa", "aba", vec![0, 2, 4]), + full_string_match: ("pattern", "pattern", vec![0]), + empty_pattern: ("nonempty", " ", vec![]), + pattern_larger_than_text: ("small", "largerpattern", vec![]), + repeated_pattern_in_text: ( + "aaaaaaaa", + "aaa", vec![0, 1, 2, 3, 4, 5] - ) + ), + pattern_not_in_lipsum: ( + concat!( + "lorem ipsum dolor sit amet, consectetur ", + "adipiscing elit, sed do eiusmod tempor ", + "incididunt ut labore et dolore magna aliqua" + ), + ";alksdjfoiwer", + vec![] + ), + pattern_in_lipsum: ( + concat!( + "lorem ipsum dolor sit amet, consectetur ", + "adipiscing elit, sed do eiusmod tempor ", + "incididunt ut labore et dolore magna aliqua" + ), + "m", + vec![4, 10, 23, 68, 74, 110] + ), } - #[test] - fn long_pattern_in_text() { - let text = vec![65u8; 1e5 as usize]; - let pattern = vec![65u8; 5e4 as usize]; - - let mut expected_answer = vec![0usize; (1e5 - 5e4 + 1f64) as usize]; - for (idx, i) in expected_answer.iter_mut().enumerate() { - *i = idx; - } - assert_eq!( - match_pattern(text.as_slice(), pattern.as_slice()), - expected_answer - ); + test_z_array_cases! { + basic_z_array: ("aabaabab", vec![0, 1, 0, 4, 1, 0, 1, 0]), + empty_string: ("", vec![]), + single_char_z_array: ("a", vec![0]), + repeated_char_z_array: ("aaaaaa", vec![0, 5, 4, 3, 2, 1]), } }