How to use the `causalml.match.NearestNeighborMatch` class in causalml

To help you get started, we've selected a few causalml examples based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

Source: uber/causalml — causalml/match.py (view on GitHub)
def single_match(self, score_cols, pihat_threshold, caliper):
        """Run one nearest-neighbor matching pass on a score-trimmed subset of ``self.df``.

        Only the rows whose ``self.ps_col`` value is strictly below
        ``pihat_threshold`` are handed to the matcher. Matching is done with
        replacement and with the supplied caliper.

        Args:
            score_cols: Column names forwarded to ``NearestNeighborMatch.match``
                as ``score_cols``.
            pihat_threshold: Exclusive upper bound on ``self.ps_col``; rows at
                or above it are excluded before matching.
            caliper: Caliper forwarded to the ``NearestNeighborMatch`` constructor.

        Returns:
            Whatever ``NearestNeighborMatch.match`` returns for the trimmed rows.
        """
        nn = NearestNeighborMatch(caliper=caliper, replace=True)
        # Strict "<" comparison: a row exactly at the threshold is dropped.
        keep = self.df[self.ps_col] < pihat_threshold
        trimmed = self.df[keep]
        return nn.match(
            data=trimmed,
            treatment_col=self.treatment_col,
            score_cols=score_cols,
        )
Source: uber/causalml — causalml/match.py (view on GitHub)
# NOTE(review): fragment of a larger script entry point. Its start (where
# `df`, `args`, `logger` and the UPPER_CASE constants are defined) is not
# visible here. The visible part scores each row with a propensity model,
# matches on that score, logs covariate balance before and after matching,
# and writes the matched rows to CSV.
# Log the input DataFrame's shape and its first rows.
logger.info('shape: {}\n{}'.format(df.shape, df.head()))

    # Elastic-net propensity model with a fixed random_state (42) for reproducible results.
    pm = ElasticNetPropensityModel(random_state=42)
    # Treatment assignment values taken from the treatment column.
    w = df[args.treatment_col].values
    # Feature matrix for the propensity model. PROPENSITY_FEATURE_TRANSFORMATIONS
    # is defined outside this view (presumably per-feature preprocessing; confirm).
    X = load_data(data=df,
                  features=args.feature_cols,
                  transformations=PROPENSITY_FEATURE_TRANSFORMATIONS)

    logger.info('Scoring with a propensity model: {}'.format(pm))
    # Fit on (X, w) and store the predicted scores as a new column of df
    # (df is modified in place). This column is the only matching score below.
    df[SCORE_COL] = pm.fit_predict(X, w)

    # Covariate balance table on MATCHING_COVARIATES before matching.
    logger.info('Balance before matching:\n{}'.format(create_table_one(data=df,
                                                                       treatment_col=args.treatment_col,
                                                                       features=MATCHING_COVARIATES)))
    logger.info('Matching based on the propensity score with the nearest neighbor model')
    # Replacement and ratio come from `args` (presumably parsed command-line
    # options; `ratio` presumably sets matches per treated unit; confirm).
    psm = NearestNeighborMatch(replace=args.replace,
                               ratio=args.ratio,
                               random_state=42)
    # Match on SCORE_COL only. Presumably matching is done separately within
    # each value of args.groupby_col; see NearestNeighborMatch.match_by_group.
    matched = psm.match_by_group(data=df,
                                 treatment_col=args.treatment_col,
                                 score_cols=[SCORE_COL],
                                 groupby_col=args.groupby_col)
    logger.info('shape: {}\n{}'.format(matched.shape, matched.head()))

    # Same balance table on the matched data, for comparison with the one above.
    logger.info('Balance after matching:\n{}'.format(create_table_one(data=matched,
                                                                      treatment_col=args.treatment_col,
                                                                      features=MATCHING_COVARIATES)))
    # Persist the matched rows without the DataFrame index.
    matched.to_csv(args.output_file, index=False)
    logger.info('Matched data saved as {}'.format(args.output_file))