Skip to content

CEM

Contrastive Explanation Method

CEM generates counterfactuals by finding contrastive perturbations.

Overview

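CEM identifies pertinent positives (PP: a minimal set of features whose presence is sufficient to keep the original prediction) and pertinent negatives (PN: a minimal set of features that must remain absent for the prediction to hold — adding them changes the predicted class). A single `CEM_CF` instance searches for one of the two, selected by its `mode` argument (`"PN"` by default); the PN search is what yields a counterfactual.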

Usage

from counterfactuals.cf_methods.local_methods.cem import CEM_CF

# gen_model and disc_model_criterion are swallowed by CEM_CF's **kwargs and
# ignored; the constructor only uses disc_model and device here, and the CEM
# hyperparameters (mode, kappa, beta, ...) keep their defaults.
method = CEM_CF(
    gen_model=gen_model,
    disc_model=classifier,
    disc_model_criterion=criterion,
    device="cuda"
)

# NOTE(review): confirm whether explain() fits the explainer itself or whether
# method.fit(X_train) must be called first.
result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train
)

API Reference

CEM_CF

CEM_CF(disc_model, mode='PN', kappa=0.2, beta=0.1, c_init=10.0, c_steps=5, max_iterations=200, learning_rate_init=0.01, device=None, **kwargs)

Bases: BaseCounterfactualMethod, LocalCounterfactualMixin

Source code in counterfactuals/cf_methods/local_methods/cem/cem.py
def __init__(
    self,
    disc_model: PytorchBase,
    mode: str = "PN",
    kappa: float = 0.2,
    beta: float = 0.1,
    c_init: float = 10.0,
    c_steps: int = 5,
    max_iterations: int = 200,
    learning_rate_init: float = 1e-2,
    device: str | None = None,
    feature_range: tuple = (0, 1),
    clip: tuple = (-1000.0, 1000.0),
    **kwargs,  # ignore other arguments
) -> None:
    """Wrap a PyTorch classifier in a ``CEM`` explainer.

    Args:
        disc_model: Classifier to explain. Must expose ``predict_proba`` and
            ``input_size`` (the number of input features).
        mode: Explanation mode forwarded to ``CEM``; ``"PN"`` (pertinent
            negative) by default, ``"PP"`` for pertinent positives.
        kappa, beta, c_init, c_steps, max_iterations, learning_rate_init:
            Optimisation hyperparameters forwarded unchanged to ``CEM``.
        device: Device forwarded to the base class, which moves the model
            there if applicable.
        feature_range: ``(min, max)`` bounds on feature values passed to
            ``CEM``. The default ``(0, 1)`` preserves the previously
            hard-coded value and presumably assumes features scaled to
            [0, 1]; pass the real data range otherwise.
        clip: Gradient clipping bounds passed to ``CEM``.
        **kwargs: Accepted and ignored.

    Note:
        Calls ``tf.compat.v1.disable_eager_execution()``, which switches
        TensorFlow into graph mode for the whole process, not just for
        this explainer.
    """
    # Initialize base/mixin (moves model to device if applicable)
    super().__init__(disc_model=disc_model, device=device)

    # CEM is constructed in TF graph mode, so eager execution is disabled first.
    tf.compat.v1.disable_eager_execution()

    def predict_proba(x):
        """Return class probabilities for ``x`` as a NumPy array."""
        probs = disc_model.predict_proba(x)
        # ``Tensor.numpy()`` raises for tensors on a GPU or that require grad
        # (e.g. with device="cuda"), so detach and copy to host memory first.
        if hasattr(probs, "detach"):
            probs = probs.detach().cpu()
        return probs.numpy()

    self.cf = CEM(
        predict_proba,
        mode=mode,
        shape=(1, disc_model.input_size),  # batch of a single instance
        kappa=kappa,
        beta=beta,
        feature_range=feature_range,
        max_iterations=max_iterations,
        c_init=c_init,
        c_steps=c_steps,
        learning_rate_init=learning_rate_init,
        clip=clip,
    )

fit

fit(X_train)

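Fit the CEM explainer on training data. By default, the per-feature medians of `X_train` are used as the "no information" baseline values for the search.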

Source code in counterfactuals/cf_methods/local_methods/cem/cem.py
def fit(self, X_train: np.ndarray, no_info_type: str = "median") -> None:
    """Fit the CEM explainer on training data.

    Args:
        X_train: Training features, forwarded to ``self.cf.fit``.
        no_info_type: Statistic of ``X_train`` (``"median"`` or ``"mean"``)
            that CEM uses as the per-feature "no information" value.
            Defaults to ``"median"``, the previously hard-coded choice.
    """
    self.cf.fit(X_train, no_info_type=no_info_type)