CEM
Contrastive Explanation Method
CEM generates counterfactuals by searching for small, contrastive perturbations of an input instance.
Overview
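CEM identifies both pertinent positives (a minimal set of present features that is sufficient to keep the prediction) and pertinent negatives (a minimal set of features that must stay absent, because adding them would change the prediction). CEM_CF defaults to mode="PN", so the pertinent negative it finds serves as the counterfactual.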
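For a pertinent negative, CEM looks for a perturbation delta of the original input that flips the predicted class while staying small and sparse. In the original CEM formulation the objective combines a hinge term weighted by c, which stops pushing once another class beats the original class by the margin kappa, with an elastic-net penalty (beta times the L1 norm plus the squared L2 norm of delta); the paper's optional autoencoder term is omitted here. The sketch below evaluates that objective on class probabilities with NumPy. The function name `pn_loss` is hypothetical, and this is an illustration of the formulation, not the backend's code.

```python
import numpy as np


def pn_loss(probs_pert: np.ndarray, orig_class: int, delta: np.ndarray,
            kappa: float = 0.2, beta: float = 0.1, c: float = 10.0) -> float:
    """Hypothetical sketch of the CEM pertinent-negative objective.

    probs_pert: class probabilities of the perturbed input (x0 + delta)
    orig_class: index of the class originally predicted for x0
    delta:      perturbation added to the original input
    """
    # Hinge term: stays positive while the original class still wins;
    # it bottoms out at -kappa once another class leads by kappa.
    p_orig = probs_pert[orig_class]
    p_other = np.max(np.delete(probs_pert, orig_class))
    attack = max(p_orig - p_other, -kappa)
    # Elastic-net penalty: L1 encourages sparse changes, L2 small ones.
    l1 = np.sum(np.abs(delta))
    l2 = np.sum(delta ** 2)
    return float(c * attack + beta * l1 + l2)
```

The defaults mirror the kappa, beta and c_init defaults of CEM_CF; minimising this quantity over delta is the job of the inner optimiser that max_iterations and learning_rate_init configure.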
Usage
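```python
from counterfactuals.cf_methods.local_methods.cem import CEM_CF

# classifier: trained PyTorch model (PytorchBase) exposing predict_proba and input_size.
# gen_model and disc_model_criterion are not used by CEM_CF; they are
# absorbed by **kwargs and kept only for interface compatibility.
method = CEM_CF(
    gen_model=gen_model,
    disc_model=classifier,
    disc_model_criterion=criterion,
    device="cuda",
)

# Compute per-feature "no information" values (medians) from the training data
method.fit(X_train)

# instance and X_train should be scaled to [0, 1], the feature range CEM_CF uses
result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train,
)
```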
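Because CEM_CF passes a fixed feature_range=(0, 1) to the underlying CEM explainer, the data given to fit and explain should already lie in [0, 1]. A minimal min-max scaling sketch, assuming dense NumPy feature matrices; the helper names `fit_minmax` and `scale_to_unit` are hypothetical and not part of the counterfactuals package.

```python
import numpy as np


def fit_minmax(X_train: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return the per-feature minimum and maximum of the training data."""
    return X_train.min(axis=0), X_train.max(axis=0)


def scale_to_unit(X: np.ndarray, x_min: np.ndarray, x_max: np.ndarray) -> np.ndarray:
    """Map each feature into [0, 1] using training-set bounds.

    Constant features (max == min) are mapped to 0, and values outside the
    training range are clipped so that they stay inside [0, 1].
    """
    span = np.where(x_max > x_min, x_max - x_min, 1.0)
    return np.clip((X - x_min) / span, 0.0, 1.0)
```

scikit-learn's `MinMaxScaler` performs the same transformation if it is already a dependency. Whichever scaler is used, apply the transform fitted on X_train to every instance passed to explain, and invert it afterwards to read the counterfactual in original units.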
API Reference
CEM_CF
CEM_CF(disc_model, mode='PN', kappa=0.2, beta=0.1, c_init=10.0, c_steps=5, max_iterations=200, learning_rate_init=0.01, device=None, **kwargs)
Bases: BaseCounterfactualMethod, LocalCounterfactualMixin
Source code in counterfactuals/cf_methods/local_methods/cem/cem.py
```python
def __init__(
    self,
    disc_model: PytorchBase,
    mode: str = "PN",
    kappa: float = 0.2,
    beta: float = 0.1,
    c_init: float = 10.0,
    c_steps: int = 5,
    max_iterations: int = 200,
    learning_rate_init: float = 1e-2,
    device: str | None = None,
    **kwargs,  # ignore other arguments
) -> None:
    # Initialize base/mixin (moves model to device if applicable)
    super().__init__(disc_model=disc_model, device=device)
    tf.compat.v1.disable_eager_execution()
    predict_proba = lambda x: disc_model.predict_proba(x).numpy()  # noqa: E731
    num_features = disc_model.input_size
    shape = (1, num_features)
    # Set gradient clipping
    clip = (-1000.0, 1000.0)
    # Get feature ranges from model
    feature_range = (0, 1)  # Default range, should be adjusted based on data
    self.cf = CEM(
        predict_proba,
        mode=mode,
        shape=shape,
        kappa=kappa,
        beta=beta,
        feature_range=feature_range,
        max_iterations=max_iterations,
        c_init=c_init,
        c_steps=c_steps,
        learning_rate_init=learning_rate_init,
        clip=clip,
    )
```
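c_init and c_steps control an outer search over c, the weight on the prediction (hinge) term of the CEM objective, while max_iterations bounds each inner optimisation run. CEM implementations typically adjust c between runs with a bisection scheme borrowed from Carlini-Wagner attacks: when a run finds a valid pertinent negative, c is lowered to favour smaller perturbations; when it does not, c is raised. The sketch below illustrates that idea only. `search_c` and the `run_attack` callback (standing in for one inner run that returns the distance of a valid pertinent negative, or None) are hypothetical, and the backend's exact schedule may differ.

```python
from typing import Callable, Optional


def search_c(
    run_attack: Callable[[float], Optional[float]],
    c_init: float = 10.0,
    c_steps: int = 5,
) -> tuple[Optional[float], list[float]]:
    """Hypothetical bisection over the loss weight c.

    Returns the smallest distance found (None if no run succeeded) and the
    list of c values that were tried, in order.
    """
    c, lower, upper = c_init, 0.0, 1e10
    best: Optional[float] = None
    tried: list[float] = []
    for _ in range(c_steps):
        tried.append(c)
        dist = run_attack(c)
        if dist is not None:
            # Valid PN found: keep the best one and lower c to tighten it.
            if best is None or dist < best:
                best = dist
            upper = min(upper, c)
            c = (lower + upper) / 2
        else:
            # No valid PN: raise c to weight the prediction term more.
            lower = max(lower, c)
            c = (lower + upper) / 2 if upper < 1e9 else c * 10
    return best, tried
```

With a toy `run_attack` that succeeds only for c >= 3, the search probes 10, 5, 2.5, 3.75 and 3.125, homing in on the smallest weight that still flips the prediction.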
fit
Fit the underlying CEM explainer on training data, setting each feature's "no information" value to its median over X_train.
Source code in counterfactuals/cf_methods/local_methods/cem/cem.py
```python
def fit(self, X_train: np.ndarray) -> None:
    """Fit the CEM model on training data"""
    self.cf.fit(X_train, no_info_type="median")
```
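In the CEM formulation a feature counts as "absent" when it sits at its no-information value, and a pertinent negative is built by moving absent features away from that value. With no_info_type="median", the reference value of each feature is its training-set median. A minimal NumPy sketch of both ideas; `median_no_info` and `absent_features` are hypothetical helpers, and the tolerance used to decide whether a value "sits at" the reference is an assumption, not a backend setting.

```python
import numpy as np


def median_no_info(X_train: np.ndarray) -> np.ndarray:
    """Per-feature median of the training data (one value per column)."""
    return np.median(X_train, axis=0)


def absent_features(x: np.ndarray, no_info: np.ndarray, tol: float = 1e-6) -> np.ndarray:
    """Boolean mask of features whose value equals the no-information value.

    Under the CEM convention these are the features a pertinent negative
    may "add" by moving them away from no_info.
    """
    return np.abs(x - no_info) <= tol
```

Using the median rather than the mean keeps the reference values robust to outliers in X_train.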