|
| 1 | +import logging |
| 2 | +import warnings |
| 3 | +from typing import List, Optional |
| 4 | + |
| 5 | +from dowhy.causal_refuter import CausalRefuter |
| 6 | +from dowhy.causal_refuters.assess_overlap_overrule import OverlapConfig, OverruleAnalyzer, SupportConfig |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | + |
class AssessOverlap(CausalRefuter):
    """Assess Overlap

    This class implements the OverRule algorithm for assessing support and overlap via Boolean Rulesets, from [1].

    [1] Oberst, M., Johansson, F., Wei, D., Gao, T., Brat, G., Sontag, D., & Varshney, K. (2020). Characterization of
    Overlap in Observational Studies. In S. Chiappa & R. Calandra (Eds.), Proceedings of the Twenty Third International
    Conference on Artificial Intelligence and Statistics (Vol. 108, pp. 788–798). PMLR. https://arxiv.org/abs/1907.04138
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the parameters required for the refuter.

        All positional and keyword arguments are forwarded unchanged to the `CausalRefuter` initializer. The keyword
        arguments listed below are additionally stored on the instance and later handed to
        `assess_support_and_overlap_overrule` by `refute_estimate`. See dowhy.causal_refuters.assess_overlap_overrule
        for the definition of the `SupportConfig` and `OverlapConfig` dataclasses that define optimization
        hyperparameters.

        .. warning::
            This method is only compatible with estimators that use backdoor adjustment, and will attempt to acquire
            the set of backdoor variables via `self._target_estimand.get_backdoor_variables()`.

        :param cat_feats: List of categorical features, all others will be discretized, defaults to an empty list
        :type cat_feats: List[str]
        :param support_config: DataClass with configuration options for learning support rules, defaults to None
        :type support_config: SupportConfig
        :param overlap_config: DataClass with configuration options for learning overlap rules, defaults to None
        :type overlap_config: OverlapConfig
        :param overlap_eps: Defines the range of propensity scores for a point to be considered in the overlap
            region, with the range defined as `(overlap_eps, 1 - overlap_eps)`, defaults to 0.1
        :type overlap_eps: float
        :param overrule_verbose: Enable verbose logging of optimization output, defaults to False
        :type overrule_verbose: bool
        :param support_only: Only fit rules to describe the support region (do not fit overlap rules), defaults to
            False
        :type support_only: bool
        :param overlap_only: Only fit rules to describe the overlap region (do not fit support rules), defaults to
            False
        :type overlap_only: bool
        :raises ValueError: If `overlap_eps` is smaller than 0 or larger than 1
        """
        super().__init__(*args, **kwargs)
        # TODO: Check that the target estimand has backdoor variables?
        self._backdoor_vars = self._target_estimand.get_backdoor_variables()

        # `kwargs` has already been forwarded in full to the base initializer above, so the
        # OverRule-specific options are only read here (removing them would have no effect).
        self._cat_feats = kwargs.get("cat_feats", [])
        self._support_config = kwargs.get("support_config", None)
        self._overlap_config = kwargs.get("overlap_config", None)
        self._overlap_eps = kwargs.get("overlap_eps", 0.1)
        if self._overlap_eps < 0 or self._overlap_eps > 1:
            raise ValueError(f"Value of `overlap_eps` must be in [0, 1], got {self._overlap_eps}")
        self._support_only = kwargs.get("support_only", False)
        self._overlap_only = kwargs.get("overlap_only", False)
        self._overrule_verbose = kwargs.get("overrule_verbose", False)

    def refute_estimate(self, show_progress_bar=False):
        """
        Learn overlap and support rules.

        :param show_progress_bar: Not implemented; a warning is emitted if set to True and the analysis proceeds
            without a progress bar, defaults to False
        :type show_progress_bar: bool
        :returns: object of class OverruleAnalyzer
        """
        if show_progress_bar:
            # stacklevel=2 attributes the warning to the caller that requested the progress bar.
            warnings.warn("No progress bar is available for OverRule", stacklevel=2)

        return assess_support_and_overlap_overrule(
            data=self._data,
            backdoor_vars=self._backdoor_vars,
            treatment_name=self._treatment_name,
            cat_feats=self._cat_feats,
            overlap_config=self._overlap_config,
            support_config=self._support_config,
            overlap_eps=self._overlap_eps,
            support_only=self._support_only,
            overlap_only=self._overlap_only,
            verbose=self._overrule_verbose,
        )
| 79 | + |
| 80 | + |
def assess_support_and_overlap_overrule(
    data,
    backdoor_vars: List[str],
    treatment_name: str,
    cat_feats: Optional[List[str]] = None,
    overlap_config: Optional[OverlapConfig] = None,
    support_config: Optional[SupportConfig] = None,
    overlap_eps: float = 0.1,
    support_only: bool = False,
    overlap_only: bool = False,
    verbose: bool = False,
):
    """
    Learn support and overlap rules using OverRule.

    :param data: Data containing backdoor variables and treatment name
    :param backdoor_vars: List of backdoor variables. Support and overlap rules will only be learned with respect to
        these variables
    :type backdoor_vars: List[str]
    :param treatment_name: Treatment name
    :type treatment_name: str
    :param cat_feats: Categorical features, defaults to None, which is treated as an empty list
    :type cat_feats: Optional[List[str]]
    :param overlap_config: Configuration for learning overlap rules
    :type overlap_config: OverlapConfig
    :param support_config: Configuration for learning support rules
    :type support_config: SupportConfig
    :param overlap_eps: Defines the range of propensity scores for a point to be considered in the overlap
        region, with the range defined as `(overlap_eps, 1 - overlap_eps)`, defaults to 0.1
    :type overlap_eps: float
    :param support_only: Only fit the support region, defaults to False
    :type support_only: bool
    :param overlap_only: Only fit the overlap region, defaults to False
    :type overlap_only: bool
    :param verbose: Enable verbose logging of optimization output, defaults to False
    :type verbose: bool
    :returns: The `OverruleAnalyzer` instance after `fit(data)` has been called on it
    """
    # Build a fresh list per call instead of sharing one mutable default object across calls.
    if cat_feats is None:
        cat_feats = []

    analyzer = OverruleAnalyzer(
        backdoor_vars=backdoor_vars,
        treatment_name=treatment_name,
        cat_feats=cat_feats,
        overlap_config=overlap_config,
        support_config=support_config,
        overlap_eps=overlap_eps,
        support_only=support_only,
        overlap_only=overlap_only,
        verbose=verbose,
    )
    analyzer.fit(data)
    return analyzer
0 commit comments