Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
:param data: pandas.DataFrame containing the data
:param identified_estimand: dowhy.causal_identifier.IdentifiedEstimand: an estimand using a backdoor method
for effect identification.
:param treatments: list or str: names of the treatment variables
:param outcomes: list or str: names of the outcome variables
:param variable_types: dict: A dictionary mapping variable names to their types. 'c' for continuous, 'o'
for ordered, 'd' for discrete, and 'u' for unordered discrete.
:param keep_original_treatment: bool: Whether to use `make_treatment_effective`, or to keep the original
treatment assignments.
:param params: (optional) additional method parameters
"""
self._data = data.copy()
self._target_estimand = identified_estimand
self._treatment_names = parse_state(treatments)
self._outcome_names = parse_state(outcomes)
self._estimate = None
self._variable_types = variable_types
self.num_cores = num_cores
self.point_sampler = True
self.sampler = None
self.keep_original_treatment = keep_original_treatment
if params is not None:
for key, value in params.items():
setattr(self, key, value)
self._df = self._data.copy()
if not self._variable_types:
self._infer_variable_types()
def __init__(self, treatment_variable, outcome_variable,
             estimand_type=None, estimands=None,
             backdoor_variables=None, instrumental_variables=None):
    """Record the variables and payload describing an identified estimand.

    :param treatment_variable: treatment name(s); normalized with ``parse_state``.
    :param outcome_variable: outcome name(s); normalized with ``parse_state``.
    :param estimand_type: stored unchanged as ``self.estimand_type``.
    :param estimands: stored unchanged as ``self.estimands``.
    :param backdoor_variables: backdoor variable name(s); normalized with ``parse_state``.
    :param instrumental_variables: instrument name(s); normalized with ``parse_state``.
    """
    # Every variable-name argument goes through parse_state (presumably turning a
    # single name into a list -- confirm against its definition).
    treatments = parse_state(treatment_variable)
    outcomes = parse_state(outcome_variable)
    backdoors = parse_state(backdoor_variables)
    instruments = parse_state(instrumental_variables)

    self.treatment_variable = treatments
    self.outcome_variable = outcomes
    self.backdoor_variables = backdoors
    self.instrumental_variables = instruments

    # The estimand description itself is kept exactly as supplied.
    self.estimand_type = estimand_type
    self.estimands = estimands
    # Not determined here; starts out unset.
    self.identifier_method = None
:param graph: path to DOT file containing a DAG or a string containing
a DAG specification in DOT format
:param common_causes: names of common causes of treatment and outcome
:param instruments: names of instrumental variables for the effect of
treatment on outcome
:param effect_modifiers: names of variables that can modify the treatment effect (useful for heterogeneous treatment effect estimation)
:param estimand_type: the type of estimand requested (currently only "nonparametric-ate" is supported). In the future, may support other specific parametric forms of identification.
:param proceed_when_unidentifiable: Binary flag indicating whether identification should proceed by ignoring potential unobserved confounders.
:param missing_nodes_as_confounders: Binary flag indicating whether variables in the dataframe that are not included in the causal graph should be automatically included as confounder nodes.
:returns: an instance of CausalModel class
"""
self._data = data
self._treatment = parse_state(treatment)
self._outcome = parse_state(outcome)
self._estimand_type = estimand_type
self._proceed_when_unidentifiable = proceed_when_unidentifiable
self._missing_nodes_as_confounders = missing_nodes_as_confounders
if 'logging_level' in kwargs:
logging.basicConfig(level=kwargs['logging_level'])
else:
logging.basicConfig(level=logging.INFO)
# TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
self.logger = logging.getLogger(__name__)
if graph is None:
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
self._common_causes = parse_state(common_causes)
self._instruments = parse_state(instruments)
self._effect_modifiers = parse_state(effect_modifiers)
:param num_cores: int: if the inference method only supports sampling a point at a time, this will parallelize
sampling.
:param variable_types: dict: The dictionary containing the variable types. Must contain the union of the causal
state, control variables, and the outcome.
:param outcome: str: The outcome variable.
:param params: dict: extra parameters to set as attributes on the sampler object
:param dot_graph: str: A string specifying the causal graph.
:param common_causes: list: A list of strings containing the variable names to control for.
:param estimand_type: str: 'nonparametric-ate' is the only one currently supported. Others may be added later, to allow for specific, parametric estimands.
:param proceed_when_unidentifiable: bool: A flag to override user prompts to proceed when effects aren't
identifiable with the assumptions provided.
:param stateful: bool: Whether to retain state. By default, the do operation is stateless.
:return: pandas.DataFrame: A DataFrame containing the sampled outcome
"""
x, keep_original_treatment = self.parse_x(x)
outcome = parse_state(outcome)
if not stateful or method != self._method:
self.reset()
if not self._causal_model:
self._causal_model = CausalModel(self._obj,
[xi for xi in x.keys()],
outcome,
graph=dot_graph,
common_causes=common_causes,
instruments=None,
estimand_type=estimand_type,
proceed_when_unidentifiable=proceed_when_unidentifiable)
#self._identified_estimand = self._causal_model.identify_effect()
if not self._sampler:
self._method = method
do_sampler_class = do_samplers.get_class_object(method + "_sampler")
self._sampler = do_sampler_class(self._obj,
self._estimand_type = estimand_type
self._proceed_when_unidentifiable = proceed_when_unidentifiable
self._missing_nodes_as_confounders = missing_nodes_as_confounders
if 'logging_level' in kwargs:
logging.basicConfig(level=kwargs['logging_level'])
else:
logging.basicConfig(level=logging.INFO)
# TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
self.logger = logging.getLogger(__name__)
if graph is None:
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
self._common_causes = parse_state(common_causes)
self._instruments = parse_state(instruments)
self._effect_modifiers = parse_state(effect_modifiers)
if common_causes is not None and instruments is not None:
self._graph = CausalGraph(
self._treatment,
self._outcome,
common_cause_names=self._common_causes,
instrument_names=self._instruments,
effect_modifier_names = self._effect_modifiers,
observed_node_names=self._data.columns.tolist()
)
elif common_causes is not None:
self._graph = CausalGraph(
self._treatment,
self._outcome,
common_cause_names=self._common_causes,
effect_modifier_names = self._effect_modifiers,
observed_node_names=self._data.columns.tolist()
def __init__(self,
treatment_name, outcome_name,
graph=None,
common_cause_names=None,
instrument_names=None,
effect_modifier_names=None,
observed_node_names=None,
missing_nodes_as_confounders=False):
self.treatment_name = parse_state(treatment_name)
self.outcome_name = parse_state(outcome_name)
instrument_names = parse_state(instrument_names)
common_cause_names = parse_state(common_cause_names)
effect_modifier_names = parse_state(effect_modifier_names)
self.logger = logging.getLogger(__name__)
if graph is None:
self._graph = nx.DiGraph()
self._graph = self.build_graph(common_cause_names,
instrument_names, effect_modifier_names)
elif re.match(r".*\.dot", graph):
# load dot file
try:
import pygraphviz as pgv
self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
except Exception as e:
self.logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...")
try:
import pydot
self._graph = nx.DiGraph(nx.drawing.nx_pydot.read_dot(graph))
self._outcome = parse_state(outcome)
self._estimand_type = estimand_type
self._proceed_when_unidentifiable = proceed_when_unidentifiable
self._missing_nodes_as_confounders = missing_nodes_as_confounders
if 'logging_level' in kwargs:
logging.basicConfig(level=kwargs['logging_level'])
else:
logging.basicConfig(level=logging.INFO)
# TODO: move the logging level argument to a json file. Tue 20 Feb 2018 06:56:27 PM DST
self.logger = logging.getLogger(__name__)
if graph is None:
self.logger.warning("Causal Graph not provided. DoWhy will construct a graph based on data inputs.")
self._common_causes = parse_state(common_causes)
self._instruments = parse_state(instruments)
self._effect_modifiers = parse_state(effect_modifiers)
if common_causes is not None and instruments is not None:
self._graph = CausalGraph(
self._treatment,
self._outcome,
common_cause_names=self._common_causes,
instrument_names=self._instruments,
effect_modifier_names = self._effect_modifiers,
observed_node_names=self._data.columns.tolist()
)
elif common_causes is not None:
self._graph = CausalGraph(
self._treatment,
self._outcome,
common_cause_names=self._common_causes,
effect_modifier_names = self._effect_modifiers,
def __init__(self,
treatment_name, outcome_name,
graph=None,
common_cause_names=None,
instrument_names=None,
effect_modifier_names=None,
observed_node_names=None,
missing_nodes_as_confounders=False):
self.treatment_name = parse_state(treatment_name)
self.outcome_name = parse_state(outcome_name)
instrument_names = parse_state(instrument_names)
common_cause_names = parse_state(common_cause_names)
effect_modifier_names = parse_state(effect_modifier_names)
self.logger = logging.getLogger(__name__)
if graph is None:
self._graph = nx.DiGraph()
self._graph = self.build_graph(common_cause_names,
instrument_names, effect_modifier_names)
elif re.match(r".*\.dot", graph):
# load dot file
try:
import pygraphviz as pgv
self._graph = nx.DiGraph(nx.drawing.nx_agraph.read_dot(graph))
except Exception as e:
self.logger.error("Pygraphviz cannot be loaded. " + str(e) + "\nTrying pydot...")
def get_causes(self, nodes, remove_edges = None):
    """Return the union of the ancestors of ``nodes``.

    :param nodes: node name or list of node names; normalized with ``parse_state``.
    :param remove_edges: optional dict with ``"sources"`` and ``"targets"`` entries.
        When given, every source -> target edge is removed from a *copy* of
        ``self._graph`` before ancestors are computed, so the stored graph is
        not modified.
    :returns: set containing every node reported by ``get_ancestors`` for any
        of ``nodes``.
    """
    wanted = parse_state(nodes)

    # Stays None when no edges are removed; it is forwarded to get_ancestors as-is.
    pruned_graph = None
    if remove_edges is not None:
        # caution: shallow copy of the attributes
        pruned_graph = self._graph.copy()
        edge_sources = parse_state(remove_edges["sources"])
        edge_targets = parse_state(remove_edges["targets"])
        for src in edge_sources:
            for dst in edge_targets:
                # Errors from remove_edge (e.g. a missing edge) are not caught here.
                pruned_graph.remove_edge(src, dst)

    return set().union(*(self.get_ancestors(node, new_graph=pruned_graph)
                         for node in wanted))
def get_common_causes(self, nodes1, nodes2):
    """
    Return the common causes of ``nodes1`` and ``nodes2``.

    Assume that nodes1 causes nodes2 (e.g., nodes1 are the treatments and nodes2 are the outcomes).

    A node is returned when it is an ancestor of some node in ``nodes1`` AND it is
    either a parent of some node in ``nodes2`` that is not itself in ``nodes1``, or
    an ancestor of such a parent.

    :param nodes1: node name or list of node names; normalized with ``parse_state``.
    :param nodes2: node name or list of node names; normalized with ``parse_state``.
    :returns: list of the common-cause nodes. Order is unspecified (built from a set).
    """
    # TODO Refactor to remove this from here and only implement this logic in causalIdentifier. Unnecessary assumption of nodes1 to be causing nodes2.
    nodes1 = parse_state(nodes1)
    nodes2 = parse_state(nodes2)
    # Set view of nodes1 for O(1) membership tests inside the nested loop below.
    nodes1_set = set(nodes1)

    # Grow the sets in place; rebinding with ``union`` copied the whole set each time.
    causes_1 = set()
    for node in nodes1:
        causes_1.update(self.get_ancestors(node))

    causes_2 = set()
    for node in nodes2:
        # Cannot simply compute ancestors, since that will also include nodes1 and its parents (e.g. instruments)
        for parent in self.get_parents(node):
            if parent not in nodes1_set:
                causes_2.add(parent)
                causes_2.update(self.get_ancestors(parent))
    # NOTE(review): if a parent is itself a descendant of nodes1 (e.g. a mediator),
    # its ancestors presumably include nodes1's own ancestors, which then survive
    # the intersection below -- confirm this is intended.
    return list(causes_1.intersection(causes_2))