WACH
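Counterfactual explanations following Wachter et al. (2017)
WACH searches for a counterfactual by gradient-based optimization. It looks for an input close to the original instance that the classifier assigns to a different (target) class.
Overview
WACH wraps a Counterfactual explainer (alibi's implementation of the Wachter et al. method). The search trades off two terms: the squared gap between the classifier's probability for the target class and the desired probability target_proba, and the L1 distance from the original instance. A weight λ balances the two terms. It starts at lam_init and is adjusted at most max_lam_steps times. All features may change, bounded only by feature_range.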
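Wachter et al. pose the search as minimizing λ·(f_t(x′) − p_t)² + ‖x′ − x‖₁ and increasing λ until the prediction term is satisfied. The Python below is a minimal, self-contained sketch of that objective on a toy two-feature logistic-regression classifier. Everything in it is hypothetical: W, B, predict_target, descend and wachter_counterfactual are illustration names, not part of this library. The λ schedule (multiply by 10 until a valid counterfactual appears) is a simplification; alibi tunes λ with its own schedule.

```python
import math
from typing import Sequence

# Hypothetical toy classifier: p(target class | x) = sigmoid(W . x + B).
W = [2.0, -1.0]
B = -0.5


def predict_target(x: Sequence[float]) -> float:
    """Probability of the target class (class 1) under the toy model."""
    z = sum(w * v for w, v in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-z))


def descend(x0, lam, target_proba, lr, max_iter, feature_range):
    """Subgradient descent on lam * (f_t(x') - p_t)^2 + ||x' - x0||_1."""
    lo, hi = feature_range
    x = list(x0)  # start the search at the original instance
    for _ in range(max_iter):
        p = predict_target(x)
        # Prediction term via the chain rule: dp/dx_i = p * (1 - p) * W[i].
        coef = 2.0 * lam * (p - target_proba) * p * (1.0 - p)
        new_x = []
        for w, v, v0 in zip(W, x, x0):
            l1_grad = (v > v0) - (v < v0)  # sign(v - v0); 0 at the kink
            new_x.append(v - lr * (coef * w + l1_grad))
        # Project back into feature_range after every step.
        x = [min(hi, max(lo, v)) for v in new_x]
    return x


def wachter_counterfactual(x0, target_proba=1.0, tol=0.51, lam_init=0.1,
                           max_lam_steps=10, lr=0.1, max_iter=1000,
                           feature_range=(0.0, 1.0)):
    """Raise lam tenfold until the result lies within tol of target_proba."""
    lam = lam_init
    for _ in range(max_lam_steps):
        x = descend(x0, lam, target_proba, lr, max_iter, feature_range)
        if abs(predict_target(x) - target_proba) <= tol:
            return x, lam
        lam *= 10.0
    return None  # no valid counterfactual within max_lam_steps
```

Starting from x = (0.2, 0.6), where the toy model gives p ≈ 0.33, a small λ leaves the L1 term dominant and the point barely moves. Larger values of λ push the prediction toward the target class, and every step is clipped to feature_range.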
Usage
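from counterfactuals.cf_methods.local_methods import WACH

# gen_model, classifier, criterion, instance, X_train and y_train are assumed
# to be defined beforehand; features should already be scaled to [0, 1].
method = WACH(
    gen_model=gen_model,  # accepted but ignored (absorbed by **kwargs)
    disc_model=classifier,
    disc_model_criterion=criterion,  # accepted but ignored
    device="cuda",  # accepted but ignored
)
result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train,
)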
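The constructor hard-codes feature_range = (0, 1) (marked with a TODO in the source listing), so the explainer searches for counterfactuals inside the unit hypercube. As a result, instance and X_train should already be scaled to [0, 1]. The code below is a dependency-free min-max scaling sketch. fit_min_max, transform and inverse_transform are hypothetical helpers; scikit-learn's MinMaxScaler provides the same mapping.

```python
from typing import List, Sequence, Tuple


def fit_min_max(rows: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Per-feature minimum and maximum, computed on the training rows."""
    cols = list(zip(*rows))
    return [min(c) for c in cols], [max(c) for c in cols]


def transform(rows, mins, maxs):
    """Map each feature to [0, 1]; constant features map to 0.0."""
    return [
        [(v - lo) / (hi - lo) if hi > lo else 0.0
         for v, lo, hi in zip(row, mins, maxs)]
        for row in rows
    ]


def inverse_transform(rows, mins, maxs):
    """Map scaled values (e.g. a returned counterfactual) back to original units."""
    return [[lo + v * (hi - lo) for v, lo, hi in zip(row, mins, maxs)]
            for row in rows]
```

Fit the scaler on X_train only, apply it to both X_train and instance, and map the counterfactual back with inverse_transform. Values outside the training range map outside [0, 1], so consider clipping them before the search.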
API Reference
WACH
WACH(disc_model, target_class='other', **kwargs)
Bases: BaseCounterfactualMethod, LocalCounterfactualMixin
Source code in counterfactuals/cf_methods/local_methods/wach/wach.py
# Excerpt; relies on module-level imports of tensorflow as tf, typing.Union,
# PytorchBase and alibi's Counterfactual.
def __init__(
    self,
    disc_model: PytorchBase,
    target_class: Union[int, str] = "other",  # any class other than origin will do
    **kwargs,  # ignore other arguments
) -> None:
    # alibi's Counterfactual runs in TF1 graph mode; this call disables
    # eager execution globally for the whole process.
    tf.compat.v1.disable_eager_execution()
    target_proba = 1.0
    tol = 0.51  # accept counterfactuals with p(target) >= 1.0 - 0.51 = 0.49
    self.target_class = target_class
    max_iter = 1000
    lam_init = 1e-1
    max_lam_steps = 10
    learning_rate_init = 0.1
    # Expose the PyTorch model as a function returning NumPy probabilities.
    predict_proba = lambda x: disc_model.predict_proba(x).numpy()  # noqa: E731
    num_features = disc_model.input_size
    # TODO: Change in future to allow for different feature ranges
    feature_range = (0, 1)
    self.cf = Counterfactual(
        predict_proba,
        shape=(1, num_features),
        target_proba=target_proba,
        tol=tol,
        target_class=target_class,
        max_iter=max_iter,
        lam_init=lam_init,
        max_lam_steps=max_lam_steps,
        learning_rate_init=learning_rate_init,
        feature_range=feature_range,
    )
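The hard-coded target_class, target_proba and tol together decide when a candidate is accepted. The sketch below illustrates those semantics under two assumptions. First, tol is an absolute tolerance around target_proba, which is how alibi documents it. Second, target_class = "other" resolves to the most probable class other than the original one. It is not alibi's implementation, and target_probability and is_accepted are hypothetical helpers.

```python
from typing import Sequence, Union


def target_probability(probs: Sequence[float], origin: int,
                       target_class: Union[int, str] = "other") -> float:
    """Probability the search drives toward target_proba (hypothetical helper).

    "other": any class except the original will do, so take the most likely
    of those classes; an int selects that specific class.
    """
    if target_class == "other":
        return max(p for c, p in enumerate(probs) if c != origin)
    return probs[target_class]


def is_accepted(probs: Sequence[float], origin: int,
                target_class: Union[int, str] = "other",
                target_proba: float = 1.0, tol: float = 0.51) -> bool:
    """Accept when the target probability lies within tol of target_proba."""
    p = target_probability(probs, origin, target_class)
    return abs(p - target_proba) <= tol
```

With target_proba = 1.0 and tol = 0.51, a two-class counterfactual is accepted once the other class reaches a probability of 0.49 or more, which is roughly when the prediction flips. A stricter setting such as tol = 0.01 would demand a confident prediction of 0.99 or more.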