Skip to content

DiCoFlex

Diverse Counterfactuals with Flexible Constraints

DiCoFlex generates diverse counterfactual explanations while respecting flexible feature constraints.

Overview

DiCoFlex extends standard counterfactual generation by:

  1. Producing multiple diverse counterfactuals
  2. Supporting flexible actionability constraints
  3. Balancing diversity with validity

Usage

from counterfactuals.cf_methods.local_methods.DiCoFlex import DiCoFlex

method = DiCoFlex(
    gen_model=gen_model,
    disc_model=classifier,
    class_to_index=class_to_index,
    mask_vectors=mask_vectors,
    params=params,
    device="cuda"
)

result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train
)

Parameters

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| gen_model | BaseGenerator | required | Trained generative model |
| disc_model | BaseClassifier | required | Trained classifier |
| n_counterfactuals | int | 5 | Number of CFs to generate |

API Reference

DiCoFlex

DiCoFlex(gen_model, disc_model, class_to_index, mask_vectors, params, device=None)

Bases: BaseCounterfactualMethod, LocalCounterfactualMixin

Sample-based counterfactual generator backed by a conditional flow.

Source code in counterfactuals/cf_methods/local_methods/DiCoFlex/method.py
def __init__(
    self,
    gen_model: GenerativePytorchMixin,
    disc_model: PytorchBase,
    class_to_index: Dict[int, int],
    mask_vectors: list[np.ndarray],
    params: DiCoFlexParams,
    device: Optional[str] = None,
) -> None:
    """Configure the DiCoFlex sampler.

    Validates the mask/params configuration, stores it on the instance,
    and moves both models to the resolved device.

    Args:
        gen_model: Trained conditional generative model.
        disc_model: Trained discriminative model (classifier).
        class_to_index: Mapping from class label to its internal index.
        mask_vectors: Non-empty list of actionability mask vectors.
        params: DiCoFlex hyper-parameters (mask_index, target_class, ...).
        device: Torch device string; defaults to "cpu" when not given.

    Raises:
        ValueError: If no mask vectors are supplied, ``params.mask_index``
            is out of range, or ``params.target_class`` is not a key of
            ``class_to_index``.
    """
    super().__init__(gen_model=gen_model, disc_model=disc_model, device=device)

    # Fail fast on misconfiguration before storing any state.
    if not mask_vectors:
        raise ValueError("At least one mask vector must be supplied for DiCoFlex.")
    if params.mask_index >= len(mask_vectors):
        raise ValueError("mask_index exceeds available mask vectors.")
    if params.target_class not in class_to_index:
        raise ValueError(
            f"Target class {params.target_class} not observed in the training data."
        )

    self.class_to_index = class_to_index
    self.mask_vectors = mask_vectors
    self.params = params
    self.device = device if device else "cpu"
    # Both models must live on the device the sampler will run on.
    self.gen_model.to(self.device)
    self.disc_model.to(self.device)

explain

explain(X, y_origin, y_target, X_train=None, y_train=None, **kwargs)

Generate counterfactuals for the provided samples.

Source code in counterfactuals/cf_methods/local_methods/DiCoFlex/method.py
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray,
    y_target: np.ndarray,
    X_train: Optional[np.ndarray] = None,
    y_train: Optional[np.ndarray] = None,
    **kwargs,
) -> ExplanationResult:
    """Generate counterfactuals for the provided samples.

    ``X_train``/``y_train`` and any extra keyword arguments are accepted
    for interface compatibility but are not used by this implementation.
    """
    features = np.asarray(X, dtype=np.float32)
    origin_labels = np.asarray(y_origin).reshape(-1, 1)
    target_labels = np.asarray(y_target).reshape(-1, 1)

    # The sampler returns a fixed 8-tuple; unpack it by position.
    (
        cf_batch,
        y_target_flat,
        x_orig_flat,
        y_origin_flat,
        target_probs,
        valid_mask,
        log_probs,
        group_ids,
    ) = self._sample_counterfactuals(features, origin_labels, target_labels)

    # Scalar summaries plus raw per-sample masks/groups for downstream logging.
    logs = {
        "sampling/mean_target_probability": float(target_probs.mean()),
        "sampling/valid_ratio": float(valid_mask.mean()),
        "sampling/log_prob_mean": float(log_probs.mean()),
        "model_returned_mask": valid_mask.tolist(),
        "cf_group_ids": group_ids.tolist(),
    }
    return ExplanationResult(
        x_cfs=cf_batch,
        y_cf_targets=y_target_flat,
        x_origs=x_orig_flat,
        y_origs=y_origin_flat,
        logs=logs,
        cf_group_ids=group_ids,
    )

explain_dataloader

explain_dataloader(dataloader, epochs, lr, patience_eps=1e-05, **kwargs)

Adapter around explain() for DataLoader inputs.

Source code in counterfactuals/cf_methods/local_methods/DiCoFlex/method.py
def explain_dataloader(
    self,
    dataloader,
    epochs: int,
    lr: float,
    patience_eps: float = 1e-5,
    **kwargs,
) -> "ExplanationResult":
    """Adapter around explain() for DataLoader inputs.

    Collects every ``(x, y)`` batch from ``dataloader`` into one array
    pair and delegates to :meth:`explain`, targeting
    ``self.params.target_class`` for every sample.  ``epochs``, ``lr``
    and ``patience_eps`` are accepted for interface compatibility only;
    remaining keyword arguments are forwarded to :meth:`explain`.

    Args:
        dataloader: Iterable of ``(features, labels)`` tensor batches.
        epochs: Unused; kept for a common method signature.
        lr: Unused; kept for a common method signature.
        patience_eps: Unused; kept for a common method signature.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    xs, ys = [], []
    for batch_x, batch_y in dataloader:
        # detach().cpu() lets tensors that live on GPU (device="cuda") or
        # carry gradients convert cleanly; a no-op for plain CPU tensors.
        xs.append(batch_x.detach().cpu().numpy())
        ys.append(batch_y.detach().cpu().numpy())
    if not xs:
        # np.vstack([]) would raise a cryptic error; fail with a clear one.
        raise ValueError("explain_dataloader received an empty dataloader.")
    X = np.vstack(xs)
    y_origin = np.concatenate(ys)
    # Every sample gets the same configured target class (dtype follows y_origin).
    y_target = np.full_like(y_origin, fill_value=self.params.target_class)
    return self.explain(
        X=X,
        y_origin=y_origin,
        y_target=y_target,
        **kwargs,
    )