From b7108958aa9bf7c74f17a33c4d046da004cd1d6b Mon Sep 17 00:00:00 2001 From: Carlos Villar Date: Thu, 20 Oct 2022 22:02:53 +0200 Subject: [PATCH 1/5] Added Viterbi algorithm Fixes: #7465 Squashed commits --- dynamic_programming/viterbi.py | 324 +++++++++++++++++++++++++++++++++ 1 file changed, 324 insertions(+) create mode 100644 dynamic_programming/viterbi.py diff --git a/dynamic_programming/viterbi.py b/dynamic_programming/viterbi.py new file mode 100644 index 000000000000..64565ae09ede --- /dev/null +++ b/dynamic_programming/viterbi.py @@ -0,0 +1,324 @@ +from collections.abc import Callable +from typing import Any, Dict, List, Tuple + + +def viterbi( + observations_space: List[str], + states_space: List[str], + initial_probabilities: Dict[str, float], + transition_probabilities: Dict[str, Dict[str, float]], + emission_probabilities: Dict[str, Dict[str, float]], +) -> List[str]: + """ + Viterbi Algorithm, to find the most likely path of + states from the start and the expected output. + https://en.wikipedia.org/wiki/Viterbi_algorithm + + Wikipedia example + >>> observations = ["normal", "cold", "dizzy"] + >>> states = ["Healthy", "Fever"] + >>> start_p = {"Healthy": 0.6, "Fever": 0.4} + >>> trans_p = { + ... "Healthy": {"Healthy": 0.7, "Fever": 0.3}, + ... "Fever": {"Healthy": 0.4, "Fever": 0.6}, + ... } + >>> emit_p = { + ... "Healthy": {"normal": 0.5, "cold": 0.4, "dizzy": 0.1}, + ... "Fever": {"normal": 0.1, "cold": 0.3, "dizzy": 0.6}, + ... } + >>> viterbi(observations, states, start_p, trans_p, emit_p) + ['Healthy', 'Healthy', 'Fever'] + + # >>> viterbi((), states, start_p, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: There's an empty parameter + # + # >>> viterbi(observations, (), start_p, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: There's an empty parameter + # + # >>> viterbi(observations, states, {}, trans_p, emit_p) + # Traceback (most recent call last): + # ... 
+ # ValueError: There's an empty parameter + # + # >>> viterbi(observations, states, start_p, {}, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: There's an empty parameter + # + # >>> viterbi(observations, states, start_p, trans_p, {}) + # Traceback (most recent call last): + # ... + # ValueError: There's an empty parameter + # + # >>> viterbi("invalid", states, start_p, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: observations_space must be a list + # + # >>> viterbi(("valid", 123), states, start_p, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: observations_space must be a list of strings + # + # >>> viterbi(observations, "invalid", start_p, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: states_space must be a list + # + # >>> viterbi(observations, ("valid", 123), start_p, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: states_space must be a list of strings + # + # >>> viterbi(observations, states, "invalid", trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: initial_probabilities must be a dict + # + # >>> viterbi(observations, states, {2:2}, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: initial_probabilities all keys must be strings + # + # >>> viterbi(observations, states, {"a":2}, trans_p, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: initial_probabilities all values must be float + # + # >>> viterbi(observations, states, start_p, "invalid", emit_p) + # Traceback (most recent call last): + # ... + # ValueError: transition_probabilities must be a dict + # + # >>> viterbi(observations, states, start_p, {"a":2}, emit_p) + # Traceback (most recent call last): + # ... 
+ # ValueError: transition_probabilities all values must be dict + # + # >>> viterbi(observations, states, start_p, {2:{2:2}}, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: transition_probabilities all keys must be strings + # + # >>> viterbi(observations, states, start_p, {"a":{2:2}}, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: transition_probabilities all keys must be strings + # + # >>> viterbi(observations, states, start_p, {"a":{"b":2}}, emit_p) + # Traceback (most recent call last): + # ... + # ValueError: transition_probabilities nested dictionary all values must be float + # + # >>> viterbi(observations, states, start_p, trans_p, "invalid") + # Traceback (most recent call last): + # ... + # ValueError: emission_probabilities must be a dict + # + # >>> viterbi(observations, states, start_p, trans_p, None) + # Traceback (most recent call last): + # ... + # ValueError: There's an empty parameter + + """ + _validation( + observations_space, + states_space, + initial_probabilities, + transition_probabilities, + emission_probabilities, + ) + # Creates data structures and fill initial step + pointers, probabilities = _initialise_probabilities_and_pointers( + observations_space, + states_space, + initial_probabilities, + emission_probabilities, + ) + + # Function for the process forward calculations + def _prior_state(observation: str, prior_observation: str, state: str) -> Callable: + def _func(k_state: str) -> float: + return ( + probabilities[(k_state, prior_observation)] + * transition_probabilities[k_state][state] + * emission_probabilities[state][observation] + ) + + return _func + + # Fills the data structure with the probabilities of + # different transitions and pointers to previous states + _process_forward( + observations_space, states_space, _prior_state, probabilities, pointers + ) + + # The final observation + last_state = _extract_final_state(observations_space, states_space, probabilities) + + # 
Process pointers backwards + return _extract_best_path(observations_space, last_state, pointers) + + +def _validation( + observations_space: Any, + states_space: Any, + initial_probabilities: Any, + transition_probabilities: Any, + emission_probabilities: Any, +) -> None: + _validate_not_empty( + observations_space, + states_space, + initial_probabilities, + transition_probabilities, + emission_probabilities, + ) + _validate_lists(observations_space, states_space) + _validate_dicts( + initial_probabilities, transition_probabilities, emission_probabilities + ) + + +def _validate_not_empty( + observations_space: Any, + states_space: Any, + initial_probabilities: Any, + transition_probabilities: Any, + emission_probabilities: Any, +) -> None: + if not all( + [ + observations_space, + states_space, + initial_probabilities, + transition_probabilities, + emission_probabilities, + ] + ): + raise ValueError("There's an empty parameter") + + +def _validate_lists(observations_space: Any, states_space: Any) -> None: + _validate_list(observations_space, "observations_space") + _validate_list(states_space, "states_space") + + +def _validate_list(_object: Any, var_name: str) -> None: + if not isinstance(_object, list): + raise ValueError(f"{var_name} must be a list") + else: + for x in _object: + if not isinstance(x, str): + raise ValueError(f"{var_name} must be a list of strings") + + +def _validate_dicts( + initial_probabilities: Any, + transition_probabilities: Any, + emission_probabilities: Any, +) -> None: + _validate_dict(initial_probabilities, "initial_probabilities", float) + _validate_nested_dict(transition_probabilities, "transition_probabilities") + _validate_nested_dict(emission_probabilities, "emission_probabilities") + + +def _validate_nested_dict(_object: Any, var_name: str) -> None: + _validate_dict(_object, var_name, dict) + for x in _object.values(): + _validate_dict(x, var_name, float, True) + + +def _validate_dict(_object: Any, var_name: str, value_type: 
type, nested: bool = False): + if not isinstance(_object, dict): + raise ValueError(f"{var_name} must be a dict") + if not all(isinstance(x, str) for x in _object): + raise ValueError(f"{var_name} all keys must be strings") + if not all(isinstance(x, value_type) for x in _object.values()): + nested_text = "nested dictionary " if nested else "" + raise ValueError( + f"{var_name} {nested_text}all values must be {value_type.__name__}" + ) + + +def _initialise_probabilities_and_pointers( + observations_space: List[str], + states_space: List[str], + initial_probabilities: Dict[str, float], + emission_probabilities: Dict[str, Dict[str, float]], +) -> Tuple[dict, dict]: + probabilities = {} + pointers = {} + for state in states_space: + observation = observations_space[0] + probabilities[(state, observation)] = ( + initial_probabilities[state] * emission_probabilities[state][observation] + ) + pointers[(state, observation)] = None + return pointers, probabilities + + +def _process_forward( + observations_space: List[str], + states_space: List[str], + _prior_state: Callable, + probabilities: dict, + pointers: dict, +) -> None: + for o in range(1, len(observations_space)): + observation = observations_space[o] + prior_observation = observations_space[o - 1] + for state in states_space: + arg_max = _arg_max( + _prior_state(observation, prior_observation, state), states_space + ) + + probabilities[(state, observation)] = _prior_state( + observation, prior_observation, state + )(arg_max) + pointers[(state, observation)] = arg_max + + +def _extract_final_state(observations_space, states_space, probabilities): + final_observation = observations_space[len(observations_space) - 1] + + def _best_final_state(k_state: str) -> float: + return probabilities[(k_state, final_observation)] + + last_state = _arg_max(_best_final_state, states_space) + return last_state + + +def _extract_best_path( + observations_space: List[str], + last_observation: str, + pointers: dict, +) -> List[str]: + 
previous = last_observation + result = [] + for o in range(len(observations_space) - 1, -1, -1): + result.append(previous) + previous = pointers[previous, observations_space[o]] + result.reverse() + return result + + +def _arg_max(prior_state: Callable, states_space: List[str]) -> str: + arg_max = "" + max_probability = -1 + for k_state in states_space: + probability = prior_state(k_state) + if probability > max_probability: + max_probability = probability + arg_max = k_state + return arg_max + + +if __name__ == "__main__": + from doctest import testmod + + testmod() From e6d825eaee9d3ad86e30c415a1f7da36ed39351c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Oct 2022 10:53:00 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dynamic_programming/viterbi.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/dynamic_programming/viterbi.py b/dynamic_programming/viterbi.py index 64565ae09ede..283e529c5e9d 100644 --- a/dynamic_programming/viterbi.py +++ b/dynamic_programming/viterbi.py @@ -3,12 +3,12 @@ def viterbi( - observations_space: List[str], - states_space: List[str], - initial_probabilities: Dict[str, float], - transition_probabilities: Dict[str, Dict[str, float]], - emission_probabilities: Dict[str, Dict[str, float]], -) -> List[str]: + observations_space: list[str], + states_space: list[str], + initial_probabilities: dict[str, float], + transition_probabilities: dict[str, dict[str, float]], + emission_probabilities: dict[str, dict[str, float]], +) -> list[str]: """ Viterbi Algorithm, to find the most likely path of states from the start and the expected output. 
@@ -246,11 +246,11 @@ def _validate_dict(_object: Any, var_name: str, value_type: type, nested: bool = def _initialise_probabilities_and_pointers( - observations_space: List[str], - states_space: List[str], - initial_probabilities: Dict[str, float], - emission_probabilities: Dict[str, Dict[str, float]], -) -> Tuple[dict, dict]: + observations_space: list[str], + states_space: list[str], + initial_probabilities: dict[str, float], + emission_probabilities: dict[str, dict[str, float]], +) -> tuple[dict, dict]: probabilities = {} pointers = {} for state in states_space: @@ -263,8 +263,8 @@ def _initialise_probabilities_and_pointers( def _process_forward( - observations_space: List[str], - states_space: List[str], + observations_space: list[str], + states_space: list[str], _prior_state: Callable, probabilities: dict, pointers: dict, @@ -294,10 +294,10 @@ def _best_final_state(k_state: str) -> float: def _extract_best_path( - observations_space: List[str], + observations_space: list[str], last_observation: str, pointers: dict, -) -> List[str]: +) -> list[str]: previous = last_observation result = [] for o in range(len(observations_space) - 1, -1, -1): @@ -307,7 +307,7 @@ def _extract_best_path( return result -def _arg_max(prior_state: Callable, states_space: List[str]) -> str: +def _arg_max(prior_state: Callable, states_space: list[str]) -> str: arg_max = "" max_probability = -1 for k_state in states_space: From 99fed0ca5142d79f37e6648ba7369f97494297cd Mon Sep 17 00:00:00 2001 From: Carlos Villar Date: Sat, 22 Oct 2022 13:46:56 +0200 Subject: [PATCH 3/5] Added doctest for validators --- dynamic_programming/viterbi.py | 345 ++++++++++++++++++++++----------- 1 file changed, 231 insertions(+), 114 deletions(-) diff --git a/dynamic_programming/viterbi.py b/dynamic_programming/viterbi.py index 64565ae09ede..7658d303611a 100644 --- a/dynamic_programming/viterbi.py +++ b/dynamic_programming/viterbi.py @@ -10,119 +10,119 @@ def viterbi( emission_probabilities: Dict[str, 
Dict[str, float]], ) -> List[str]: """ - Viterbi Algorithm, to find the most likely path of - states from the start and the expected output. - https://en.wikipedia.org/wiki/Viterbi_algorithm - - Wikipedia example - >>> observations = ["normal", "cold", "dizzy"] - >>> states = ["Healthy", "Fever"] - >>> start_p = {"Healthy": 0.6, "Fever": 0.4} - >>> trans_p = { - ... "Healthy": {"Healthy": 0.7, "Fever": 0.3}, - ... "Fever": {"Healthy": 0.4, "Fever": 0.6}, - ... } - >>> emit_p = { - ... "Healthy": {"normal": 0.5, "cold": 0.4, "dizzy": 0.1}, - ... "Fever": {"normal": 0.1, "cold": 0.3, "dizzy": 0.6}, - ... } - >>> viterbi(observations, states, start_p, trans_p, emit_p) - ['Healthy', 'Healthy', 'Fever'] - - # >>> viterbi((), states, start_p, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: There's an empty parameter - # - # >>> viterbi(observations, (), start_p, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: There's an empty parameter - # - # >>> viterbi(observations, states, {}, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: There's an empty parameter - # - # >>> viterbi(observations, states, start_p, {}, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: There's an empty parameter - # - # >>> viterbi(observations, states, start_p, trans_p, {}) - # Traceback (most recent call last): - # ... - # ValueError: There's an empty parameter - # - # >>> viterbi("invalid", states, start_p, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: observations_space must be a list - # - # >>> viterbi(("valid", 123), states, start_p, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: observations_space must be a list of strings - # - # >>> viterbi(observations, "invalid", start_p, trans_p, emit_p) - # Traceback (most recent call last): - # ... 
- # ValueError: states_space must be a list - # - # >>> viterbi(observations, ("valid", 123), start_p, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: states_space must be a list of strings - # - # >>> viterbi(observations, states, "invalid", trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: initial_probabilities must be a dict - # - # >>> viterbi(observations, states, {2:2}, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: initial_probabilities all keys must be strings - # - # >>> viterbi(observations, states, {"a":2}, trans_p, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: initial_probabilities all values must be float - # - # >>> viterbi(observations, states, start_p, "invalid", emit_p) - # Traceback (most recent call last): - # ... - # ValueError: transition_probabilities must be a dict - # - # >>> viterbi(observations, states, start_p, {"a":2}, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: transition_probabilities all values must be dict - # - # >>> viterbi(observations, states, start_p, {2:{2:2}}, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: transition_probabilities all keys must be strings - # - # >>> viterbi(observations, states, start_p, {"a":{2:2}}, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: transition_probabilities all keys must be strings - # - # >>> viterbi(observations, states, start_p, {"a":{"b":2}}, emit_p) - # Traceback (most recent call last): - # ... - # ValueError: transition_probabilities nested dictionary all values must be float - # - # >>> viterbi(observations, states, start_p, trans_p, "invalid") - # Traceback (most recent call last): - # ... - # ValueError: emission_probabilities must be a dict - # - # >>> viterbi(observations, states, start_p, trans_p, None) - # Traceback (most recent call last): - # ... 
- # ValueError: There's an empty parameter + Viterbi Algorithm, to find the most likely path of + states from the start and the expected output. + https://en.wikipedia.org/wiki/Viterbi_algorithm + + Wikipedia example + >>> observations = ["normal", "cold", "dizzy"] + >>> states = ["Healthy", "Fever"] + >>> start_p = {"Healthy": 0.6, "Fever": 0.4} + >>> trans_p = { + ... "Healthy": {"Healthy": 0.7, "Fever": 0.3}, + ... "Fever": {"Healthy": 0.4, "Fever": 0.6}, + ... } + >>> emit_p = { + ... "Healthy": {"normal": 0.5, "cold": 0.4, "dizzy": 0.1}, + ... "Fever": {"normal": 0.1, "cold": 0.3, "dizzy": 0.6}, + ... } + >>> viterbi(observations, states, start_p, trans_p, emit_p) + ['Healthy', 'Healthy', 'Fever'] + + >>> viterbi((), states, start_p, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + + >>> viterbi(observations, (), start_p, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + + >>> viterbi(observations, states, {}, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + + >>> viterbi(observations, states, start_p, {}, emit_p) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + + >>> viterbi(observations, states, start_p, trans_p, {}) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + + >>> viterbi("invalid", states, start_p, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: observations_space must be a list + + >>> viterbi(["valid", 123], states, start_p, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: observations_space must be a list of strings + + >>> viterbi(observations, "invalid", start_p, trans_p, emit_p) + Traceback (most recent call last): + ... 
+ ValueError: states_space must be a list + + >>> viterbi(observations, ["valid", 123], start_p, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: states_space must be a list of strings + + >>> viterbi(observations, states, "invalid", trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: initial_probabilities must be a dict + + >>> viterbi(observations, states, {2:2}, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: initial_probabilities all keys must be strings + + >>> viterbi(observations, states, {"a":2}, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: initial_probabilities all values must be float + + >>> viterbi(observations, states, start_p, "invalid", emit_p) + Traceback (most recent call last): + ... + ValueError: transition_probabilities must be a dict + + >>> viterbi(observations, states, start_p, {"a":2}, emit_p) + Traceback (most recent call last): + ... + ValueError: transition_probabilities all values must be dict + + >>> viterbi(observations, states, start_p, {2:{2:2}}, emit_p) + Traceback (most recent call last): + ... + ValueError: transition_probabilities all keys must be strings + + >>> viterbi(observations, states, start_p, {"a":{2:2}}, emit_p) + Traceback (most recent call last): + ... + ValueError: transition_probabilities all keys must be strings + + >>> viterbi(observations, states, start_p, {"a":{"b":2}}, emit_p) + Traceback (most recent call last): + ... + ValueError: transition_probabilities nested dictionary all values must be float + + >>> viterbi(observations, states, start_p, trans_p, "invalid") + Traceback (most recent call last): + ... + ValueError: emission_probabilities must be a dict + + >>> viterbi(observations, states, start_p, trans_p, None) + Traceback (most recent call last): + ... 
+ ValueError: There's an empty parameter """ _validation( @@ -171,6 +171,25 @@ def _validation( transition_probabilities: Any, emission_probabilities: Any, ) -> None: + """ + >>> observations = ["normal", "cold", "dizzy"] + >>> states = ["Healthy", "Fever"] + >>> start_p = {"Healthy": 0.6, "Fever": 0.4} + >>> trans_p = { + ... "Healthy": {"Healthy": 0.7, "Fever": 0.3}, + ... "Fever": {"Healthy": 0.4, "Fever": 0.6}, + ... } + >>> emit_p = { + ... "Healthy": {"normal": 0.5, "cold": 0.4, "dizzy": 0.1}, + ... "Fever": {"normal": 0.1, "cold": 0.3, "dizzy": 0.6}, + ... } + >>> _validation(observations, states, start_p, trans_p, emit_p) + + >>> _validation([], states, start_p, trans_p, emit_p) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + """ _validate_not_empty( observations_space, states_space, @@ -191,6 +210,18 @@ def _validate_not_empty( transition_probabilities: Any, emission_probabilities: Any, ) -> None: + """ + >>> _validate_not_empty(["a"], ["b"], {"c":0.5}, {"d": {"e": 0.6}}, {"f": {"g": 0.7}}) + + >>> _validate_not_empty(["a"], ["b"], {"c":0.5}, {}, {"f": {"g": 0.7}}) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + >>> _validate_not_empty(["a"], ["b"], None, {"d": {"e": 0.6}}, {"f": {"g": 0.7}}) + Traceback (most recent call last): + ... + ValueError: There's an empty parameter + """ if not all( [ observations_space, @@ -204,11 +235,37 @@ def _validate_not_empty( def _validate_lists(observations_space: Any, states_space: Any) -> None: + """ + >>> _validate_lists(["a"], ["b"]) + + >>> _validate_lists(1234, ["b"]) + Traceback (most recent call last): + ... + ValueError: observations_space must be a list + + >>> _validate_lists(["a"], [3]) + Traceback (most recent call last): + ... 
+ ValueError: states_space must be a list of strings + """ _validate_list(observations_space, "observations_space") _validate_list(states_space, "states_space") def _validate_list(_object: Any, var_name: str) -> None: + """ + >>> _validate_list(["a"], "mock_name") + + >>> _validate_list("a", "mock_name") + Traceback (most recent call last): + ... + ValueError: mock_name must be a list + >>> _validate_list([0.5], "mock_name") + Traceback (most recent call last): + ... + ValueError: mock_name must be a list of strings + + """ if not isinstance(_object, list): raise ValueError(f"{var_name} must be a list") else: @@ -222,18 +279,78 @@ def _validate_dicts( transition_probabilities: Any, emission_probabilities: Any, ) -> None: + """ + >>> _validate_dicts({"c":0.5}, {"d": {"e": 0.6}}, {"f": {"g": 0.7}}) + + >>> _validate_dicts("invalid", {"d": {"e": 0.6}}, {"f": {"g": 0.7}}) + Traceback (most recent call last): + ... + ValueError: initial_probabilities must be a dict + >>> _validate_dicts({"c":0.5}, {2: {"e": 0.6}}, {"f": {"g": 0.7}}) + Traceback (most recent call last): + ... + ValueError: transition_probabilities all keys must be strings + >>> _validate_dicts({"c":0.5}, {"d": {"e": 0.6}}, {"f": {2: 0.7}}) + Traceback (most recent call last): + ... + ValueError: emission_probabilities all keys must be strings + >>> _validate_dicts({"c":0.5}, {"d": {"e": 0.6}}, {"f": {"g": "h"}}) + Traceback (most recent call last): + ... + ValueError: emission_probabilities nested dictionary all values must be float + """ _validate_dict(initial_probabilities, "initial_probabilities", float) _validate_nested_dict(transition_probabilities, "transition_probabilities") _validate_nested_dict(emission_probabilities, "emission_probabilities") def _validate_nested_dict(_object: Any, var_name: str) -> None: + """ + >>> _validate_nested_dict({"a":{"b": 0.5}}, "mock_name") + + >>> _validate_nested_dict("invalid", "mock_name") + Traceback (most recent call last): + ... 
+ ValueError: mock_name must be a dict + >>> _validate_nested_dict({"a": 8}, "mock_name") + Traceback (most recent call last): + ... + ValueError: mock_name all values must be dict + >>> _validate_nested_dict({"a":{2: 0.5}}, "mock_name") + Traceback (most recent call last): + ... + ValueError: mock_name all keys must be strings + >>> _validate_nested_dict({"a":{"b": 4}}, "mock_name") + Traceback (most recent call last): + ... + ValueError: mock_name nested dictionary all values must be float + """ _validate_dict(_object, var_name, dict) for x in _object.values(): _validate_dict(x, var_name, float, True) def _validate_dict(_object: Any, var_name: str, value_type: type, nested: bool = False): + """ + >>> _validate_dict({"b": 0.5}, "mock_name", float) + + >>> _validate_dict("invalid", "mock_name", float) + Traceback (most recent call last): + ... + ValueError: mock_name must be a dict + >>> _validate_dict({"a": 8}, "mock_name", dict) + Traceback (most recent call last): + ... + ValueError: mock_name all values must be dict + >>> _validate_dict({2: 0.5}, "mock_name",float, True) + Traceback (most recent call last): + ... + ValueError: mock_name all keys must be strings + >>> _validate_dict({"b": 4}, "mock_name", float,True) + Traceback (most recent call last): + ... 
+ ValueError: mock_name nested dictionary all values must be float + """ if not isinstance(_object, dict): raise ValueError(f"{var_name} must be a dict") if not all(isinstance(x, str) for x in _object): @@ -250,7 +367,7 @@ def _initialise_probabilities_and_pointers( states_space: List[str], initial_probabilities: Dict[str, float], emission_probabilities: Dict[str, Dict[str, float]], -) -> Tuple[dict, dict]: +) -> tuple[dict, dict]: probabilities = {} pointers = {} for state in states_space: From b17c499c687829b309189903353b287b89d7db7e Mon Sep 17 00:00:00 2001 From: Carlos Villar Date: Sat, 22 Oct 2022 14:07:13 +0200 Subject: [PATCH 4/5] moved all extracted functions to the main function --- dynamic_programming/viterbi.py | 199 +++++++++++++-------------------- 1 file changed, 78 insertions(+), 121 deletions(-) diff --git a/dynamic_programming/viterbi.py b/dynamic_programming/viterbi.py index 7658d303611a..ef2519f22f4d 100644 --- a/dynamic_programming/viterbi.py +++ b/dynamic_programming/viterbi.py @@ -1,14 +1,13 @@ -from collections.abc import Callable -from typing import Any, Dict, List, Tuple +from typing import Any def viterbi( - observations_space: List[str], - states_space: List[str], - initial_probabilities: Dict[str, float], - transition_probabilities: Dict[str, Dict[str, float]], - emission_probabilities: Dict[str, Dict[str, float]], -) -> List[str]: + observations_space: list, + states_space: list, + initial_probabilities: dict, + transition_probabilities: dict, + emission_probabilities: dict, +) -> list: """ Viterbi Algorithm, to find the most likely path of states from the start and the expected output. 
@@ -133,35 +132,65 @@ def viterbi( emission_probabilities, ) # Creates data structures and fill initial step - pointers, probabilities = _initialise_probabilities_and_pointers( - observations_space, - states_space, - initial_probabilities, - emission_probabilities, - ) + probabilities: dict = {} + pointers: dict = {} + for state in states_space: + observation = observations_space[0] + probabilities[(state, observation)] = ( + initial_probabilities[state] * emission_probabilities[state][observation] + ) + pointers[(state, observation)] = None - # Function for the process forward calculations - def _prior_state(observation: str, prior_observation: str, state: str) -> Callable: - def _func(k_state: str) -> float: - return ( - probabilities[(k_state, prior_observation)] - * transition_probabilities[k_state][state] + # Fills the data structure with the probabilities of + # different transitions and pointers to previous states + for o in range(1, len(observations_space)): + observation = observations_space[o] + prior_observation = observations_space[o - 1] + for state in states_space: + # Calculates the argmax for probability function + arg_max = "" + max_probability = -1 + for k_state in states_space: + probability = ( + probabilities[(k_state, prior_observation)] + * transition_probabilities[k_state][state] + * emission_probabilities[state][observation] + ) + if probability > max_probability: + max_probability = probability + arg_max = k_state + + # Update probabilities and pointers dicts + probabilities[(state, observation)] = ( + probabilities[(arg_max, prior_observation)] + * transition_probabilities[arg_max][state] * emission_probabilities[state][observation] ) - return _func - - # Fills the data structure with the probabilities of - # different transitions and pointers to previous states - _process_forward( - observations_space, states_space, _prior_state, probabilities, pointers - ) + pointers[(state, observation)] = arg_max # The final observation - last_state = 
_extract_final_state(observations_space, states_space, probabilities) + final_observation = observations_space[len(observations_space) - 1] + + # argmax for given final observation + arg_max = "" + max_probability = -1 + for k_state in states_space: + probability = probabilities[(k_state, final_observation)] + if probability > max_probability: + max_probability = probability + arg_max = k_state + last_state = arg_max # Process pointers backwards - return _extract_best_path(observations_space, last_state, pointers) + previous = last_state + result = [] + for o in range(len(observations_space) - 1, -1, -1): + result.append(previous) + previous = pointers[previous, observations_space[o]] + result.reverse() + + return result def _validation( @@ -211,7 +240,8 @@ def _validate_not_empty( emission_probabilities: Any, ) -> None: """ - >>> _validate_not_empty(["a"], ["b"], {"c":0.5}, {"d": {"e": 0.6}}, {"f": {"g": 0.7}}) + >>> _validate_not_empty(["a"], ["b"], {"c":0.5}, + ... {"d": {"e": 0.6}}, {"f": {"g": 0.7}}) >>> _validate_not_empty(["a"], ["b"], {"c":0.5}, {}, {"f": {"g": 0.7}}) Traceback (most recent call last): @@ -332,25 +362,25 @@ def _validate_nested_dict(_object: Any, var_name: str) -> None: def _validate_dict(_object: Any, var_name: str, value_type: type, nested: bool = False): """ - >>> _validate_dict({"b": 0.5}, "mock_name", float) + >>> _validate_dict({"b": 0.5}, "mock_name", float) - >>> _validate_dict("invalid", "mock_name", float) - Traceback (most recent call last): - ... - ValueError: mock_name must be a dict - >>> _validate_dict({"a": 8}, "mock_name", dict) - Traceback (most recent call last): - ... - ValueError: mock_name all values must be dict - >>> _validate_dict({2: 0.5}, "mock_name",float, True) - Traceback (most recent call last): - ... - ValueError: mock_name all keys must be strings - >>> _validate_dict({"b": 4}, "mock_name", float,True) - Traceback (most recent call last): - ... 
- ValueError: mock_name nested dictionary all values must be float - """ + >>> _validate_dict("invalid", "mock_name", float) + Traceback (most recent call last): + ... + ValueError: mock_name must be a dict + >>> _validate_dict({"a": 8}, "mock_name", dict) + Traceback (most recent call last): + ... + ValueError: mock_name all values must be dict + >>> _validate_dict({2: 0.5}, "mock_name",float, True) + Traceback (most recent call last): + ... + ValueError: mock_name all keys must be strings + >>> _validate_dict({"b": 4}, "mock_name", float,True) + Traceback (most recent call last): + ... + ValueError: mock_name nested dictionary all values must be float + """ if not isinstance(_object, dict): raise ValueError(f"{var_name} must be a dict") if not all(isinstance(x, str) for x in _object): @@ -362,79 +392,6 @@ def _validate_dict(_object: Any, var_name: str, value_type: type, nested: bool = ) -def _initialise_probabilities_and_pointers( - observations_space: List[str], - states_space: List[str], - initial_probabilities: Dict[str, float], - emission_probabilities: Dict[str, Dict[str, float]], -) -> tuple[dict, dict]: - probabilities = {} - pointers = {} - for state in states_space: - observation = observations_space[0] - probabilities[(state, observation)] = ( - initial_probabilities[state] * emission_probabilities[state][observation] - ) - pointers[(state, observation)] = None - return pointers, probabilities - - -def _process_forward( - observations_space: List[str], - states_space: List[str], - _prior_state: Callable, - probabilities: dict, - pointers: dict, -) -> None: - for o in range(1, len(observations_space)): - observation = observations_space[o] - prior_observation = observations_space[o - 1] - for state in states_space: - arg_max = _arg_max( - _prior_state(observation, prior_observation, state), states_space - ) - - probabilities[(state, observation)] = _prior_state( - observation, prior_observation, state - )(arg_max) - pointers[(state, observation)] = 
arg_max - - -def _extract_final_state(observations_space, states_space, probabilities): - final_observation = observations_space[len(observations_space) - 1] - - def _best_final_state(k_state: str) -> float: - return probabilities[(k_state, final_observation)] - - last_state = _arg_max(_best_final_state, states_space) - return last_state - - -def _extract_best_path( - observations_space: List[str], - last_observation: str, - pointers: dict, -) -> List[str]: - previous = last_observation - result = [] - for o in range(len(observations_space) - 1, -1, -1): - result.append(previous) - previous = pointers[previous, observations_space[o]] - result.reverse() - return result - - -def _arg_max(prior_state: Callable, states_space: List[str]) -> str: - arg_max = "" - max_probability = -1 - for k_state in states_space: - probability = prior_state(k_state) - if probability > max_probability: - max_probability = probability - arg_max = k_state - return arg_max - - if __name__ == "__main__": from doctest import testmod From 5ddae7cd8601f5f169fcc393586a49aa6dd3e69b Mon Sep 17 00:00:00 2001 From: Carlos Villar Date: Sat, 22 Oct 2022 14:46:52 +0200 Subject: [PATCH 5/5] Forgot a type hint --- dynamic_programming/viterbi.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dynamic_programming/viterbi.py b/dynamic_programming/viterbi.py index ef2519f22f4d..93ab845e2ae8 100644 --- a/dynamic_programming/viterbi.py +++ b/dynamic_programming/viterbi.py @@ -360,7 +360,9 @@ def _validate_nested_dict(_object: Any, var_name: str) -> None: _validate_dict(x, var_name, float, True) -def _validate_dict(_object: Any, var_name: str, value_type: type, nested: bool = False): +def _validate_dict( + _object: Any, var_name: str, value_type: type, nested: bool = False +) -> None: """ >>> _validate_dict({"b": 0.5}, "mock_name", float)