|
4 | 4 | the future.
|
5 | 5 | """
|
6 | 6 | from enum import Enum, auto
|
7 |
| -from typing import Any, Callable, Dict, Optional, Set, Tuple, Union |
| 7 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
8 | 8 |
|
9 | 9 | import networkx as nx
|
10 | 10 | import numpy as np
|
@@ -59,7 +59,7 @@ def refute_causal_structure(
|
59 | 59 | if parents and non_descendants:
|
60 | 60 | # test local Markov condition, null hypothesis: conditional independence
|
61 | 61 | lmc_p_value = conditional_independence_test(
|
62 |
| - data[node].values, data[non_descendants].values, data[parents].values |
| 62 | + data[node].to_numpy(), data[non_descendants].to_numpy(), data[parents].to_numpy() |
63 | 63 | )
|
64 | 64 | lmc_test_result = dict(p_value=lmc_p_value)
|
65 | 65 | all_p_values.append(lmc_p_value)
|
@@ -160,8 +160,8 @@ def refute_invertible_model(
|
160 | 160 | )
|
161 | 161 |
|
162 | 162 |
|
163 |
| -def _get_non_descendants(causal_graph: DirectedGraph, node: Any, exclude_parents: bool = False) -> Set[Any]: |
| 163 | +def _get_non_descendants(causal_graph: DirectedGraph, node: Any, exclude_parents: bool = False) -> List[Any]: |
164 | 164 | nodes_to_exclude = nx.descendants(causal_graph, node).union({node})
|
165 | 165 | if exclude_parents:
|
166 | 166 | nodes_to_exclude = nodes_to_exclude.union(causal_graph.predecessors(node))
|
167 |
| - return set(causal_graph.nodes).difference(nodes_to_exclude) |
| 167 | + return list(set(causal_graph.nodes).difference(nodes_to_exclude)) |
0 commit comments