Skip to content

Commit 3114151

Browse files
authored Jan 21, 2025
[FEA] Generalized Adjustment Criterion (#1292)
* This PR adds support for identifying generalized (non-backdoor) adjustment sets. Specifically, it adds support for finding a minimal adjustment set if one exists (it is guaranteed to find a set if one does exist). Ongoing work in the pywhy-graphs library to enumerate all m-separating sets in causal graphs will later unlock the ability to enumerate all generalized adjustment sets.
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* adding default case
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* adding minimal test
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* poe format
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* adding test, throwing on unsupported
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* tweaks
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* dependency bump
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* delete misc files
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* fix dictionary mapping
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* make test check python version
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* adding another test
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* adding docs
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* restore notebooks I don't want to change
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* remove extraneous comment
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* remove comment and print statement from example notebook
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* add comma
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
* address typos
  Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
---------
Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
1 parent ffb761f commit 3114151

15 files changed

+740
-108
lines changed
 

‎docs/source/example_notebooks/dowhy_generalized_covariate_adjustment_example.ipynb

+296
Large diffs are not rendered by default.

‎dowhy/causal_identifier/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
AutoIdentifier,
33
BackdoorAdjustment,
44
EstimandType,
5-
construct_backdoor_estimand,
5+
GeneralizedAdjustment,
6+
construct_adjustment_estimand,
67
construct_frontdoor_estimand,
78
construct_iv_estimand,
89
identify_effect_auto,
@@ -16,11 +17,12 @@
1617
"identify_effect_auto",
1718
"identify_effect_id",
1819
"BackdoorAdjustment",
20+
"GeneralizedAdjustment",
1921
"EstimandType",
2022
"IdentifiedEstimand",
2123
"IDIdentifier",
2224
"identify_effect",
23-
"construct_backdoor_estimand",
25+
"construct_adjustment_estimand",
2426
"construct_frontdoor_estimand",
2527
"construct_iv_estimand",
2628
]
dowhy/causal_identifier/adjustment_set.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
class AdjustmentSet:
    """Class for storing an adjustment set.

    An instance records which criterion produced the set (see ``BACKDOOR`` and
    ``GENERAL``), the variables in the set, and optionally how many paths are
    blocked by observed nodes.
    """

    BACKDOOR = "backdoor"
    # General adjustment sets generalize backdoor sets, but we will differentiate
    # between the two given the ubiquity of the backdoor criterion.
    GENERAL = "general"

    def __init__(
        self,
        adjustment_type,
        adjustment_variables,
        num_paths_blocked_by_observed_nodes=None,
    ):
        """Store the description of a single adjustment set.

        :param adjustment_type: the criterion behind this set, e.g. ``AdjustmentSet.BACKDOOR``
            or ``AdjustmentSet.GENERAL``
        :param adjustment_variables: the variables to adjust for; stored as given, without copying
        :param num_paths_blocked_by_observed_nodes: optional number of paths blocked by observed nodes
        """
        self.adjustment_type = adjustment_type
        self.adjustment_variables = adjustment_variables
        self.num_paths_blocked_by_observed_nodes = num_paths_blocked_by_observed_nodes

    def __repr__(self):
        # A readable representation makes lists of adjustment sets inspectable in logs and debuggers.
        return (
            f"{type(self).__name__}("
            f"adjustment_type={self.adjustment_type!r}, "
            f"adjustment_variables={self.adjustment_variables!r}, "
            f"num_paths_blocked_by_observed_nodes={self.num_paths_blocked_by_observed_nodes!r})"
        )

    def get_adjustment_type(self):
        """Return the technique associated with this adjustment set (backdoor, etc.)"""
        return self.adjustment_type

    def get_adjustment_variables(self):
        """Return the adjustment variables, in the same container that was passed to the constructor"""
        return self.adjustment_variables

    def get_num_paths_blocked_by_observed_nodes(self):
        """Return the number of paths blocked by observed nodes, or None if it was not provided"""
        return self.num_paths_blocked_by_observed_nodes

‎dowhy/causal_identifier/auto_identifier.py

+142-45
Large diffs are not rendered by default.

‎dowhy/causal_identifier/backdoor.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import networkx as nx
22

3+
from dowhy.causal_identifier.adjustment_set import AdjustmentSet
34
from dowhy.utils.graph_operations import adjacency_matrix_to_adjacency_list
45

56

@@ -113,11 +114,13 @@ def get_backdoor_vars(self):
113114
self._path_search(adjlist, node1, node2, path_dict)
114115
if len(path_dict) != 0:
115116
obj = HittingSetAlgorithm(path_dict[(node1, node2)].get_condition_vars(), self._colliders)
116-
117-
backdoor_set = {}
118-
backdoor_set["backdoor_set"] = tuple(obj.find_set())
119-
backdoor_set["num_paths_blocked_by_observed_nodes"] = obj.num_sets()
120-
backdoor_sets.append(backdoor_set)
117+
backdoor_sets.append(
118+
AdjustmentSet(
119+
adjustment_type=AdjustmentSet.BACKDOOR,
120+
adjustment_variables=tuple(obj.find_set()),
121+
num_paths_blocked_by_observed_nodes=obj.num_sets(),
122+
)
123+
)
121124

122125
return backdoor_sets
123126

‎dowhy/causal_identifier/identified_estimand.py

+16
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,22 @@ def __init__(
1717
estimand_type=None,
1818
estimands=None,
1919
backdoor_variables=None,
20+
general_adjustment_variables=None,
2021
instrumental_variables=None,
2122
frontdoor_variables=None,
2223
mediator_variables=None,
2324
mediation_first_stage_confounders=None,
2425
mediation_second_stage_confounders=None,
2526
default_backdoor_id=None,
27+
default_adjustment_set_id=None,
2628
identifier_method=None,
2729
no_directed_path=False,
2830
):
2931
self.identifier = identifier
3032
self.treatment_variable = parse_state(treatment_variable)
3133
self.outcome_variable = parse_state(outcome_variable)
3234
self.backdoor_variables = backdoor_variables
35+
self.general_adjustment_variables = general_adjustment_variables
3336
self.instrumental_variables = parse_state(instrumental_variables)
3437
self.frontdoor_variables = parse_state(frontdoor_variables)
3538
self.mediator_variables = parse_state(mediator_variables)
@@ -38,6 +41,7 @@ def __init__(
3841
self.estimand_type = estimand_type
3942
self.estimands = estimands
4043
self.default_backdoor_id = default_backdoor_id
44+
self.default_adjustment_set_id = default_adjustment_set_id
4145
self.identifier_method = identifier_method
4246
self.no_directed_path = no_directed_path
4347

@@ -78,6 +82,13 @@ def get_instrumental_variables(self):
7882
"""Return a list containing the instrumental variables (if present)"""
7983
return self.instrumental_variables
8084

85+
def get_general_adjustment_variables(self, key: Optional[str] = None):
    """Return the general adjustment variables of one adjustment set.

    :param key: name of the adjustment set to look up; when omitted, the set stored under
        ``default_adjustment_set_id`` is used

    :returns: the variables stored for the selected set, or an empty list when no key is given
        and no general adjustment sets have been recorded
    :raises KeyError: if the requested (or default) key is not among the recorded sets
    """
    if key is not None:
        return self.general_adjustment_variables[key]
    # Both general_adjustment_variables and default_adjustment_set_id default to None,
    # so an estimand without general adjustment sets must not be subscripted here.
    if not self.general_adjustment_variables:
        return []
    return self.general_adjustment_variables[self.default_adjustment_set_id]
91+
8192
def __deepcopy__(self, memo):
8293
return IdentifiedEstimand(
8394
self.identifier, # not deep copied
@@ -86,10 +97,12 @@ def __deepcopy__(self, memo):
8697
estimand_type=copy.deepcopy(self.estimand_type),
8798
estimands=copy.deepcopy(self.estimands),
8899
backdoor_variables=copy.deepcopy(self.backdoor_variables),
100+
general_adjustment_variables=copy.deepcopy(self.general_adjustment_variables),
89101
instrumental_variables=copy.deepcopy(self.instrumental_variables),
90102
frontdoor_variables=copy.deepcopy(self.frontdoor_variables),
91103
mediator_variables=copy.deepcopy(self.mediator_variables),
92104
default_backdoor_id=copy.deepcopy(self.default_backdoor_id),
105+
default_adjustment_set_id=copy.deepcopy(self.default_adjustment_set_id),
93106
identifier_method=copy.deepcopy(self.identifier_method),
94107
)
95108

@@ -112,6 +125,9 @@ def __str__(self, only_target_estimand: bool = False, show_all_backdoor_sets: bo
112125
# Just show the default backdoor set
113126
if k.startswith("backdoor") and k != "backdoor":
114127
continue
128+
# Just show the default generalized adjustment set
129+
if k.startswith("general") and k != "general_adjustment":
130+
continue
115131
if only_target_estimand and k != self.identifier_method:
116132
continue
117133
s += "\n### Estimand : {0}\n".format(i)

‎dowhy/graph.py

+54
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""This module defines the fundamental interfaces and functions related to causal graphs."""
22

3+
import copy
34
import itertools
45
import logging
56
import re
@@ -187,13 +188,66 @@ def is_blocked(graph: nx.DiGraph, path, conditioned_nodes):
187188
return False
188189

189190

191+
def get_ancestors(graph: nx.DiGraph, nodes):
    """Return the union of the ancestors of all the given nodes.

    :param graph: the directed graph to search
    :param nodes: iterable of node names whose ancestors are collected

    :returns: a set with the ancestors (as computed by ``nx.ancestors``) of every node in ``nodes``;
        empty when ``nodes`` is empty
    """
    ancestors = set()
    for node_name in nodes:
        # Update in place instead of rebuilding the accumulated set on every iteration.
        ancestors.update(nx.ancestors(graph, node_name))
    return ancestors
196+
197+
190198
def get_descendants(graph: nx.DiGraph, nodes):
    """Return the union of the descendants of all the given nodes.

    :param graph: the directed graph to search
    :param nodes: iterable of node names whose descendants are collected

    :returns: a set with the descendants (as computed by ``nx.descendants``) of every node in ``nodes``;
        empty when ``nodes`` is empty
    """
    descendants = set()
    for node_name in nodes:
        # Update in place instead of rebuilding the accumulated set on every iteration.
        descendants.update(nx.descendants(graph, node_name))
    return descendants
195203

196204

205+
def get_proper_causal_path_nodes(graph: nx.DiGraph, action_nodes, outcome_nodes):
    """Method to get the proper causal path nodes, as described in van der Zander et al. "Constructing Separators and
    Adjustment Sets in Ancestral Graphs", Section 4.1. do_surgery() is not used here; the edge removals are applied
    to private copies of the graph, so the graph passed in is left untouched.

    :param graph: the causal graph in question
    :param action_nodes: the action nodes
    :param outcome_nodes: the outcome nodes

    :returns: the set of nodes that lie on proper causal paths from X to Y
    """
    # Materialize once so the action nodes can be reused below even if an iterator was passed.
    actions = list(action_nodes)

    # 1) Create a pair of modified graphs by removing inbound and outbound arrows from the action nodes, respectively.
    # The copies never leave this function and only have edges removed, so a structural
    # graph.copy() is sufficient; node/edge attribute payloads do not need to be deep copied.
    graph_post_interv = graph.copy()  # remove incoming arrows to our action nodes
    graph_post_interv.remove_edges_from(list(graph_post_interv.in_edges(actions)))
    graph_with_action_nodes_as_sinks = graph.copy()  # remove outbound arrows from our action nodes
    graph_with_action_nodes_as_sinks.remove_edges_from(list(graph_with_action_nodes_as_sinks.out_edges(actions)))

    # 2) Use the modified graphs to identify the nodes which lie on proper causal paths from the
    # action nodes to the outcome nodes.
    de_x = get_descendants(graph_post_interv, actions)
    an_y = get_ancestors(graph_with_action_nodes_as_sinks, outcome_nodes).union(outcome_nodes)
    # The action nodes themselves are never part of the result.
    return (de_x - set(actions)) & an_y
229+
230+
231+
def get_proper_backdoor_graph(graph: nx.DiGraph, action_nodes, outcome_nodes):
    """Method to get the proper backdoor graph from a causal graph, as described in van der Zander et al. "Constructing Separators and
    Adjustment Sets in Ancestral Graphs", Section 4.1. We cannot use do_surgery() since we require deep copies of the given graph.

    :param graph: the causal graph in question
    :param action_nodes: the action nodes
    :param outcome_nodes: the outcome nodes

    :returns: a new graph which is the proper backdoor graph of the original
    """
    # Materialize once so the action nodes can be reused below even if an iterator was passed.
    actions = list(action_nodes)

    # Compute the proper causal path nodes a single time; evaluating the call inside the
    # comprehension would repeat it (and its graph copies) for every action node.
    proper_causal_path_nodes = get_proper_causal_path_nodes(graph, actions, outcome_nodes)

    # Remove the edges from the action_nodes to the proper causal path nodes. Pairs that are
    # not edges of the graph are skipped silently by remove_edges_from.
    graph_pbd = copy.deepcopy(graph)
    graph_pbd.remove_edges_from([(u, v) for u in actions for v in proper_causal_path_nodes])
    return graph_pbd
249+
250+
197251
def check_dseparation(graph: nx.DiGraph, nodes1, nodes2, nodes3, new_graph=None, dseparation_algo="default"):
198252
if dseparation_algo == "default":
199253
if new_graph is None:

‎poetry.lock

+24-35
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎pyproject.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,10 @@ pandas = [
7171
{version = "<2.0", python = "<3.9"},
7272
{version = ">1.0", python = ">=3.9"}
7373
]
74-
networkx = ">=2.8.5"
74+
networkx = [
75+
{version = ">=3.3", python = ">=3.10"},
76+
{version = ">=2.8.5", python = "<3.10"}
77+
]
7578
sympy = ">=1.10.1"
7679
scikit-learn = ">1.0"
7780
pydot = { version = "^1.4.2", optional = true }

‎tests/causal_identifiers/base.py

+38-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
from dowhy.graph import build_graph_from_str
44

5-
from .example_graphs import TEST_FRONTDOOR_GRAPH_SOLUTIONS, TEST_GRAPH_SOLUTIONS
5+
from .example_graphs import (
6+
TEST_FRONTDOOR_GRAPH_SOLUTIONS,
7+
TEST_GRAPH_SOLUTIONS,
8+
TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT,
9+
)
610

711

812
class IdentificationTestGraphSolution(object):
@@ -34,15 +38,39 @@ def __init__(
3438
observed_variables,
3539
valid_frontdoor_sets,
3640
invalid_frontdoor_sets,
41+
action_nodes=None,
42+
outcome_nodes=None,
3743
):
44+
if outcome_nodes is None:
45+
outcome_nodes = ["Y"]
46+
if action_nodes is None:
47+
action_nodes = ["X"]
3848
self.graph = build_graph_from_str(graph_str)
39-
self.action_nodes = ["X"]
40-
self.outcome_nodes = ["Y"]
49+
self.action_nodes = action_nodes
50+
self.outcome_nodes = outcome_nodes
4151
self.observed_nodes = observed_variables
4252
self.valid_frontdoor_sets = valid_frontdoor_sets
4353
self.invalid_frontdoor_sets = invalid_frontdoor_sets
4454

4555

56+
class IdentificationTestGeneralCovariateAdjustmentGraphSolution(object):
    """Test case bundling a graph with its expected general covariate adjustment sets."""

    def __init__(
        self,
        graph_str,
        observed_variables,
        action_nodes,
        outcome_nodes,
        minimal_adjustment_sets,
        exhaustive_adjustment_sets=None,
    ):
        # Expected answers for this test case.
        self.minimal_adjustment_sets = minimal_adjustment_sets
        self.exhaustive_adjustment_sets = exhaustive_adjustment_sets

        # Problem definition: which nodes are observed, acted upon, and measured.
        self.observed_nodes = observed_variables
        self.action_nodes = action_nodes
        self.outcome_nodes = outcome_nodes

        # The graph is built from its string description.
        self.graph = build_graph_from_str(graph_str)
72+
73+
4674
@pytest.fixture(params=TEST_GRAPH_SOLUTIONS.keys())
def example_graph_solution(request):
    """Build one IdentificationTestGraphSolution per key of TEST_GRAPH_SOLUTIONS."""
    solution_kwargs = TEST_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestGraphSolution(**solution_kwargs)
@@ -51,3 +79,10 @@ def example_graph_solution(request):
5179
@pytest.fixture(params=TEST_FRONTDOOR_GRAPH_SOLUTIONS.keys())
def example_frontdoor_graph_solution(request):
    """Build one IdentificationTestFrontdoorGraphSolution per key of TEST_FRONTDOOR_GRAPH_SOLUTIONS."""
    solution_kwargs = TEST_FRONTDOOR_GRAPH_SOLUTIONS[request.param]
    return IdentificationTestFrontdoorGraphSolution(**solution_kwargs)
82+
83+
84+
@pytest.fixture(params=TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT.keys())
def example_complete_adjustment_graph_solution(request):
    """Build one general covariate adjustment test case per key of TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT."""
    solution_kwargs = TEST_GRAPH_SOLUTIONS_COMPLETE_ADJUSTMENT[request.param]
    return IdentificationTestGeneralCovariateAdjustmentGraphSolution(**solution_kwargs)

0 commit comments

Comments
 (0)
Please sign in to comment.