PPCEF

Plausible Probabilistic Counterfactual Explanations with Flows

PPCEF is the flagship method of this library, generating counterfactuals that are both valid and plausible by leveraging normalizing flows.

Overview

PPCEF optimizes counterfactuals to lie in high-density regions of the data distribution, ensuring they represent realistic inputs rather than adversarial examples.

Key Innovation

Unlike proximity-only methods, PPCEF uses a generative model (normalizing flow) to assess and maximize the plausibility of generated counterfactuals.
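As a toy illustration of the plausibility idea (a standard Gaussian log-density standing in for a trained flow; this is not the library's API), points near the data distribution score a higher log-probability than far-away, adversarial-looking points:

```python
import numpy as np

# Stand-in for a trained flow's log-density: a standard 2-D Gaussian.
# A real normalizing flow exposes an analogous log_prob(x) method.
def log_prob(x):
    return -0.5 * (np.sum(x**2) + x.size * np.log(2 * np.pi))

inlier = np.array([0.1, -0.2])   # plausible, near the bulk of the data
outlier = np.array([6.0, -5.0])  # adversarial-looking, far from the data

# PPCEF's plausibility term favors points with higher log-probability,
# so optimization is pushed toward the inlier regime.
```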

Algorithm

The method minimizes a combined objective:

\[ \mathcal{L} = \alpha \cdot \mathcal{L}_{\text{validity}} + \beta \cdot \mathcal{L}_{\text{proximity}} + \gamma \cdot \mathcal{L}_{\text{plausibility}} \]

where:

  • \(\mathcal{L}_{\text{validity}}\): cross-entropy loss toward the target class
  • \(\mathcal{L}_{\text{proximity}}\): distance between the counterfactual and the original instance
  • \(\mathcal{L}_{\text{plausibility}}\): negative log-likelihood of the counterfactual under the flow
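A minimal NumPy sketch of this objective (the helper names are hypothetical; a Gaussian log-density stands in for the flow, and the classifier is reduced to raw logits):

```python
import numpy as np

def combined_loss(x_cf, x_orig, logits, y_target, log_prob_fn,
                  alpha=1.0, beta=0.5, gamma=1.0):
    # Validity: cross-entropy of the classifier logits w.r.t. the target class
    p = np.exp(logits - logits.max())
    p = p / p.sum()
    validity = -np.log(p[y_target])
    # Proximity: L2 distance between counterfactual and original instance
    proximity = np.linalg.norm(x_cf - x_orig)
    # Plausibility: negative log-likelihood under the generative model
    plausibility = -log_prob_fn(x_cf)
    return alpha * validity + beta * proximity + gamma * plausibility

# Standard-Gaussian log-density as a stand-in for the flow (assumption)
log_gauss = lambda x: -0.5 * (x @ x + x.size * np.log(2 * np.pi))

loss = combined_loss(np.array([0.1, 0.2]), np.array([0.0, 0.0]),
                     np.array([0.3, 1.2]), y_target=1, log_prob_fn=log_gauss)
```

In the real method this loss is minimized with gradient descent over a perturbation added to the original instance (see explain_dataloader below).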

Usage

import torch
from torch.utils.data import DataLoader, TensorDataset

from counterfactuals.cf_methods.local_methods import PPCEF
from counterfactuals.models.generators import MaskedAutoregressiveFlow
from counterfactuals.models.classifiers import MLPClassifier

# Initialize models (both must already be trained)
gen_model = MaskedAutoregressiveFlow(...)
classifier = MLPClassifier(...)

# Create PPCEF instance
method = PPCEF(
    gen_model=gen_model,
    disc_model=classifier,
    disc_model_criterion=torch.nn.CrossEntropyLoss(),
    device="cuda"
)

# Counterfactual search runs over a dataloader of (instance, label) batches;
# explain is not implemented for PPCEF, so use explain_dataloader instead
dataloader = DataLoader(TensorDataset(X_test, y_test), batch_size=64)
result = method.explain_dataloader(
    dataloader,
    epochs=100,
    lr=0.01,
    alpha=1.0,   # forwarded to the search step (see Parameters)
    beta=0.5
)

Parameters

Parameter    Type            Default    Description
gen_model    BaseGenerator   required   Trained generative model (flow)
disc_model   BaseClassifier  required   Trained classifier
epochs       int             1000       Optimization iterations (explain_dataloader)
lr           float           0.0005     Learning rate (explain_dataloader)
alpha        float           1.0        Validity loss weight
beta         float           0.5        Proximity loss weight

Strengths

  • High plausibility of generated counterfactuals
  • Works well with tabular data
  • Supports actionability constraints

Limitations

  • Requires training a generative model
  • Slower than simple optimization methods
  • Performance depends on flow quality

References

  • [Paper citation placeholder]

API Reference

PPCEF

PPCEF(gen_model, disc_model, disc_model_criterion, device=None)

Bases: BaseCounterfactualMethod, LocalCounterfactualMixin

Source code in counterfactuals/cf_methods/local_methods/ppcef/ppcef.py
def __init__(
    self,
    gen_model: GenerativePytorchMixin,
    disc_model: PytorchBase,
    disc_model_criterion,
    device=None,
):
    self.disc_model_criterion = disc_model_criterion
    self.gen_model = gen_model
    self.disc_model = disc_model
    self.device = device if device is not None else "cpu"
    self.gen_model.to(self.device)
    self.disc_model.to(self.device)
    self.beta = 0

explain

explain(X, y_origin, y_target, X_train, y_train)

Explains the model's prediction for a given input. Not implemented for PPCEF; use explain_dataloader instead.

Source code in counterfactuals/cf_methods/local_methods/ppcef/ppcef.py
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray,
    y_target: np.ndarray,
    X_train: np.ndarray,
    y_train: np.ndarray,
):
    """
    Explains the model's prediction for a given input.
    """
    raise NotImplementedError("This method is not implemented for this class.")

explain_dataloader

explain_dataloader(dataloader, epochs=1000, lr=0.0005, patience_eps=1e-05, **search_step_kwargs)

Search counterfactual explanations for the given dataloader.

Source code in counterfactuals/cf_methods/local_methods/ppcef/ppcef.py
def explain_dataloader(
    self,
    dataloader: DataLoader,
    epochs: int = 1000,
    lr: float = 0.0005,
    patience_eps: float = 1e-5,
    **search_step_kwargs,
):
    """
    Search counterfactual explanations for the given dataloader.
    """
    self.epochs = epochs
    self.gen_model.eval()
    for param in self.gen_model.parameters():
        param.requires_grad = False

    if self.disc_model:
        self.disc_model.eval()
        for param in self.disc_model.parameters():
            param.requires_grad = False

    deltas = []
    target_class = []
    original = []
    original_class = []
    for xs_origin, contexts_origin in dataloader:
        xs_origin = xs_origin.to(self.device)
        contexts_origin = contexts_origin.to(self.device)

        contexts_origin = contexts_origin.reshape(-1, 1)
        # Flip binary labels: the target class is the opposite of the origin
        contexts_target = torch.abs(1 - contexts_origin)

        xs_origin = torch.as_tensor(xs_origin)
        xs_origin.requires_grad = False
        delta = torch.zeros_like(xs_origin, requires_grad=True)

        optimizer = optim.Adam([delta], lr=lr)
        loss_components_logging = {}

        for epoch in (epoch_pbar := tqdm(range(epochs))):
            search_step_kwargs["epoch"] = epoch
            optimizer.zero_grad()
            loss_components = self._search_step(
                delta,
                xs_origin,
                contexts_origin,
                contexts_target,
                **search_step_kwargs,
            )
            mean_loss = loss_components["loss"].mean()
            mean_loss.backward()
            optimizer.step()

            for loss_name, loss in loss_components.items():
                loss_components_logging.setdefault(f"cf_search/{loss_name}", []).append(
                    loss.mean().detach().cpu().item()
                )

            disc_loss = loss_components["loss_disc"].detach().cpu().mean().item()
            prob_loss = loss_components["max_inner"].detach().cpu().mean().item()
            epoch_pbar.set_description(
                f"Discriminator loss: {disc_loss:.4f}, Prob loss: {prob_loss:.4f}"
            )
            # if disc_loss < patience_eps and prob_loss < patience_eps:
            #     break

        deltas.append(delta.detach().cpu().numpy())
        original.append(xs_origin.detach().cpu().numpy())
        original_class.append(contexts_origin.detach().cpu().numpy())
        target_class.append(contexts_target.detach().cpu().numpy())

    deltas = np.concatenate(deltas, axis=0)
    originals = np.concatenate(original, axis=0)
    original_classes = np.concatenate(original_class, axis=0)
    target_classes = np.concatenate(target_class, axis=0)
    x_cfs = originals + deltas

    return ExplanationResult(
        x_cfs=x_cfs,
        y_cf_targets=target_classes,
        x_origs=originals,
        y_origs=original_classes,
        logs=loss_components_logging,
    )
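Note that the loop above optimizes a perturbation delta rather than the counterfactual itself: xs_origin stays frozen and the counterfactual is recovered as x_cfs = originals + deltas. A pure-NumPy toy with a quadratic loss (not the real objective) shows the same parameterization:

```python
import numpy as np

x = np.array([1.0, -2.0])        # original instance (frozen, no gradients)
target = np.array([0.5, 0.5])    # pretend minimizer of the combined loss

delta = np.zeros_like(x)         # only delta receives updates, as in PPCEF
lr = 0.1
for _ in range(200):
    grad = 2.0 * (x + delta - target)   # gradient of ||x + delta - target||^2
    delta -= lr * grad

x_cf = x + delta                 # recovered counterfactual
```

Keeping the original fixed and learning only the perturbation makes the proximity term trivial to express (a norm of delta) and guarantees the counterfactual is always reported relative to its origin.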