
WACH

Weighted Actionable Counterfactual Explanations

WACH focuses on generating actionable counterfactuals with weighted feature importance.

Overview

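In the current implementation, WACH wraps a `Counterfactual` explainer configured in the constructor (see the API reference). The explainer searches for an input close to the original instance whose predicted probability for the target class lies within `tol` of `target_proba`. It adjusts the trade-off weight `lam`, starting from `lam_init`, over at most `max_lam_steps` steps. All features are assumed to lie in the range (0, 1).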
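To make the idea of an optimization-based counterfactual search concrete, here is a simplified, self-contained sketch on a toy logistic model. It is not the library's code. The model `sigmoid(w @ x + b)`, the loss `(p - target_proba)**2 + lam * ||x' - x||_1`, the fixed `lam`, and the function name `toy_counterfactual` are all assumptions made for illustration. Only the parameter names `target_proba`, `tol`, `lam`, `learning_rate`, `max_iter` and the (0, 1) feature range mirror the constructor settings.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def toy_counterfactual(
    x: np.ndarray,
    w: np.ndarray,
    b: float,
    target_proba: float = 1.0,
    tol: float = 0.51,
    lam: float = 0.1,
    learning_rate: float = 0.1,
    max_iter: int = 1000,
    feature_range: tuple = (0.0, 1.0),
):
    """Hypothetical illustration: gradient descent on
    (p(x') - target_proba)^2 + lam * ||x' - x||_1
    for a toy logistic model p(x') = sigmoid(w @ x' + b).
    Unlike the real explainer, lam stays fixed here."""
    x_cf = x.copy()
    for _ in range(max_iter):
        p = sigmoid(w @ x_cf + b)
        # Stop as soon as the target-class probability is within tol of target_proba.
        if abs(p - target_proba) <= tol:
            return x_cf, p
        # Gradient of (p - target_proba)^2 w.r.t. x' for the logistic model.
        grad_pred = 2.0 * (p - target_proba) * p * (1.0 - p) * w
        # Subgradient of the L1 distance term keeps x' close to x.
        grad_dist = lam * np.sign(x_cf - x)
        x_cf = x_cf - learning_rate * (grad_pred + grad_dist)
        # Project back into the assumed feature range.
        x_cf = np.clip(x_cf, *feature_range)
    return None, sigmoid(w @ x_cf + b)
```

For example, with `x = np.array([0.1, 0.1])`, `w = np.array([4.0, 4.0])` and `b = -2.0`, the starting probability is about 0.23. The search nudges both features upward until the probability enters the acceptance band.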

Usage

from counterfactuals.cf_methods.local_methods import WACH

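method = WACH(
    gen_model=gen_model,  # not used by WACH (absorbed by **kwargs)
    disc_model=classifier,
    disc_model_criterion=criterion,  # not used by WACH (absorbed by **kwargs)
    device="cuda",  # not used by WACH (absorbed by **kwargs)
)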

result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train
)
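The constructor hard-codes `feature_range=(0, 1)` and `shape=(1, num_features)`, so it is presumably safest to pass a single min-max scaled row as `X`, using a classifier trained on the same scaling. The sketch below shows one way to prepare such an input. The helper names `fit_min_max` and `scale_instance` are hypothetical, and the source does not state whether `explain` expects NumPy arrays or torch tensors.

```python
import numpy as np


def fit_min_max(X_train: np.ndarray):
    """Return the per-feature minimum and range computed on the training data."""
    X_train = np.asarray(X_train, dtype=np.float64)
    lo = X_train.min(axis=0)
    span = X_train.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid division by zero for constant features
    return lo, span


def scale_instance(x: np.ndarray, lo: np.ndarray, span: np.ndarray) -> np.ndarray:
    """Scale one instance into [0, 1] and reshape it to (1, num_features)."""
    scaled = np.clip((np.asarray(x, dtype=np.float64) - lo) / span, 0.0, 1.0)
    return scaled.reshape(1, -1).astype(np.float32)
```

For instance, with training rows `[0, 10]`, `[5, 20]` and `[10, 30]`, the raw instance `[5, 30]` becomes `[[0.5, 1.0]]` with shape `(1, 2)`.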

API Reference

WACH

WACH(disc_model, target_class='other', **kwargs)

Bases: BaseCounterfactualMethod, LocalCounterfactualMixin

Source code in counterfactuals/cf_methods/local_methods/wach/wach.py
def __init__(
    self,
    disc_model: PytorchBase,
    target_class: Union[int, str] = "other",  # "other": any class other than the original one
    **kwargs,  # accepted for API compatibility and ignored
) -> None:
    tf.compat.v1.disable_eager_execution()
    target_proba = 1.0
    tol = 0.51  # accept counterfactuals with |p(target) - target_proba| <= tol, i.e. p(target) >= 0.49
    self.target_class = target_class
    max_iter = 1000
    lam_init = 1e-1
    max_lam_steps = 10
    learning_rate_init = 0.1
    predict_proba = lambda x: disc_model.predict_proba(x).numpy()  # noqa: E731
    num_features = disc_model.input_size

    # TODO: Change in future to allow for different feature ranges
    feature_range = (0, 1)

    self.cf = Counterfactual(
        predict_proba,
        shape=(1, num_features),
        target_proba=target_proba,
        tol=tol,
        target_class=target_class,
        max_iter=max_iter,
        lam_init=lam_init,
        max_lam_steps=max_lam_steps,
        learning_rate_init=learning_rate_init,
        feature_range=feature_range,
    )
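The interaction between `target_proba` and `tol` determines which candidates count as valid counterfactuals. Assuming the wrapped `Counterfactual` accepts a candidate when its target-class probability lies within `tol` of `target_proba`, the settings above (`target_proba = 1.0`, `tol = 0.51`) accept any candidate with a target-class probability of roughly 0.49 or more. The helper `meets_target` below is a hypothetical illustration of that rule, not part of the library.

```python
def meets_target(p_target: float, target_proba: float = 1.0, tol: float = 0.51) -> bool:
    """Assumed acceptance rule: the target-class probability must lie
    within `tol` of `target_proba`."""
    return abs(p_target - target_proba) <= tol


# With the constructor's settings, probabilities of about 0.49 or more pass.
for p in (0.30, 0.60, 0.95):
    print(p, meets_target(p))

# A tight tolerance such as tol=0.01 would instead require p >= 0.99.
print(0.95, meets_target(0.95, tol=0.01))
```

A loose tolerance like 0.51 makes it easier to find a counterfactual but allows results that sit close to the decision boundary. A tighter tolerance demands more confident predictions for the target class.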