Skip to content

ReViCE

Regional Variant of PPCEF (Group PPCEF)

ReViCE generates counterfactuals for groups of similar instances.

Overview

ReViCE extends PPCEF to handle groups, finding counterfactual transformations that work for clusters of similar instances.

Usage

from counterfactuals.cf_methods.group_methods.group_ppcef import RPPCEF

# NOTE(review): the RPPCEF signature documented below lists `cf_method_type`
# as its first required argument; it is not passed here. Confirm the intended
# value and add it, otherwise this call raises TypeError.
method = RPPCEF(
    gen_model=gen_model,
    disc_model=classifier,
    disc_model_criterion=criterion,
    device="cuda"
)

# NOTE(review): per the source listing below, `explain` has no `n_groups`
# parameter and unconditionally raises NotImplementedError;
# `explain_dataloader` is the implemented entry point. Confirm and update
# this example accordingly.
result = method.explain(
    X=X_test,
    y_origin=y_test,
    y_target=target_class,
    X_train=X_train,
    y_train=y_train,
    n_groups=5
)

API Reference

RPPCEF

RPPCEF(cf_method_type, gen_model, disc_model, disc_model_criterion, init_cf_method_from_kmeans=False, K=None, X=None, device=None, actionable_features=None, **kwargs)

Bases: BaseCounterfactualMethod, GroupCounterfactualMixin

Source code in counterfactuals/cf_methods/group_methods/group_ppcef/rppcef.py
def __init__(
    self,
    cf_method_type: str,
    gen_model: GenerativePytorchMixin,
    disc_model: PytorchBase,
    disc_model_criterion: torch.nn.modules.loss._Loss,
    init_cf_method_from_kmeans: bool = False,
    K: int = None,
    X: np.ndarray = None,
    device: str = None,
    # TODO: improve naming
    actionable_features: list = None,
    **kwargs,
):
    """
    Set up the method and build the ``delta`` module that is later optimized.

    Args:
        cf_method_type: Forwarded to ``_init_cf_method``; presumably selects
            which delta parametrisation is built (TODO confirm valid values).
        gen_model: Generative model, forwarded to the base initializer.
        disc_model: Discriminative model, forwarded to the base initializer.
        disc_model_criterion: Loss for the discriminative model, forwarded to
            the base initializer.
        init_cf_method_from_kmeans: Forwarded to ``_init_cf_method``.
        K: Forwarded to ``_init_cf_method``; presumably the number of groups
            (TODO confirm).
        X: Forwarded to ``_init_cf_method``; presumably the data used for the
            k-means initialisation (TODO confirm).
        device: Forwarded to the base initializer.
        actionable_features: Stored unchanged on ``self.actionable_features``.
        **kwargs: Extra keyword arguments forwarded to the base initializer.
    """
    # Initialize mixins / base to set up models and device
    super().__init__(
        gen_model=gen_model,
        disc_model=disc_model,
        disc_model_criterion=disc_model_criterion,
        device=device,
        **kwargs,
    )

    # Stored before _init_cf_method runs, so the helper can already see it.
    self.actionable_features = actionable_features
    # initialize delta after we know device/other attributes
    self.delta = self._init_cf_method(cf_method_type, K, init_cf_method_from_kmeans, X)
    # Starts empty; explain_dataloader resets it again at the start of each run.
    self.loss_components_logs = {}

explain

explain(X, y_origin, y_target, X_train, y_train)

Not implemented for this class: as the source below shows, calling this method always raises NotImplementedError.

Source code in counterfactuals/cf_methods/group_methods/group_ppcef/rppcef.py
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray,
    y_target: np.ndarray,
    X_train: np.ndarray,
    y_train: np.ndarray,
):
    """
    Placeholder for the array-based explanation entry point.

    None of the arguments are used; the call always fails.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "This method is not implemented for this class."
    raise NotImplementedError(message)

explain_dataloader

explain_dataloader(dataloader, alpha, alpha_s, alpha_k, log_prob_threshold, epochs=1000, lr=0.0005, patience=100, patience_eps=0.001)

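Optimizes the counterfactual shift `delta` with Adam for up to `epochs` epochs on each batch of the dataloader, keeping the generative and discriminative models frozen.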

Source code in counterfactuals/cf_methods/group_methods/group_ppcef/rppcef.py
def explain_dataloader(
    self,
    dataloader: DataLoader,
    alpha: int,
    alpha_s: int,
    alpha_k: int,
    log_prob_threshold: float,
    epochs: int = 1000,
    lr: float = 0.0005,
    patience: int = 100,
    patience_eps: float = 1e-3,
):
    """
    Optimize ``self.delta`` on every batch yielded by ``dataloader``.

    The generative model (and the discriminative model, when set) are put in
    eval mode and their parameters frozen; only ``self.delta`` is updated.
    For each batch a fresh Adam optimizer is created and ``_search_step`` is
    minimised for up to ``epochs`` epochs.

    Args:
        dataloader: Yields ``(xs_origin, contexts_origin)`` pairs of tensors.
        alpha: Forwarded to ``_search_step``.
        alpha_s: Forwarded to ``_search_step``.
        alpha_k: Forwarded to ``_search_step``.
        log_prob_threshold: Forwarded to ``_search_step``.
        epochs: Maximum number of optimisation epochs per batch.
        lr: Learning rate of the Adam optimizer.
        patience: Number of consecutive non-improving epochs tolerated before
            the patience budget is considered exhausted. The budget has to be
            exhausted twice before the loop actually stops.
        patience_eps: Minimum decrease of the mean loss that counts as an
            improvement.

    Returns:
        A tuple ``(delta, x_origs, y_origs, y_target)``: ``self.delta``
        itself, followed by NumPy arrays with the concatenated inputs, the
        original labels as a column and the target labels ``|1 - label|``.

    Raises:
        ValueError: raised by ``np.concatenate`` when ``dataloader`` yields
            no batches.
    """
    # Early stopping is not evaluated during the first ``warmup_epochs``
    # epochs of a batch; while warming up, min_loss simply tracks the latest
    # loss. NOTE: with ``epochs <= 1000`` (including the default) the loop
    # therefore always runs for the full number of epochs.
    warmup_epochs = 1000

    self.loss_components_logs = {}
    self.gen_model.eval()
    for param in self.gen_model.parameters():
        param.requires_grad = False

    if self.disc_model is not None:
        self.disc_model.eval()
        for param in self.disc_model.parameters():
            param.requires_grad = False

    target_class = []
    original = []
    original_class = []
    for xs_origin, contexts_origin in dataloader:
        xs_origin = xs_origin.to(self.device)
        contexts_origin = contexts_origin.to(self.device).reshape(-1, 1)
        # |1 - y| swaps 0 and 1; assumes binary 0/1 labels - TODO confirm.
        contexts_target = torch.abs(1 - contexts_origin)

        # The inputs are constants of the optimisation; only delta is trained.
        xs_origin.requires_grad = False

        optimizer = torch.optim.Adam(self.delta.parameters(), lr=lr)

        min_loss = float("inf")
        # Explicit per-batch reset so the counter neither depends on the
        # warm-up branch for its first binding nor leaks between batches.
        patience_counter = 0
        # Set once the patience budget has been exhausted for the first time.
        dist_flag = False

        for epoch in (epoch_pbar := tqdm(range(epochs), dynamic_ncols=True)):
            optimizer.zero_grad()
            loss_components = self._search_step(
                self.delta,
                xs_origin,
                contexts_origin,
                contexts_target,
                alpha=alpha,
                alpha_s=alpha_s,
                alpha_k=alpha_k,
                log_prob_threshold=log_prob_threshold,
            )
            mean_loss = loss_components["loss"].mean()
            mean_loss.backward()
            optimizer.step()

            self._log_loss_components(loss_components)

            loss = loss_components["loss"].detach().cpu().mean().item()
            # Progress bar description
            epoch_pbar.set_description(
                ", ".join(
                    f"{k}: {v.detach().cpu().mean().item():.4f}"
                    for k, v in loss_components.items()
                )
            )
            # Early stopping handling
            if loss < (min_loss - patience_eps) or epoch < warmup_epochs:
                min_loss = loss
                patience_counter = 0
                continue

            patience_counter += 1
            if patience_counter > patience:
                if dist_flag:
                    break
                # First exhaustion: grant one more full patience window.
                patience_counter = 0
                dist_flag = True

        original.append(xs_origin.detach().cpu().numpy())
        original_class.append(contexts_origin.detach().cpu().numpy())
        target_class.append(contexts_target.detach().cpu().numpy())

    x_origs = np.concatenate(original, axis=0)
    y_origs = np.concatenate(original_class, axis=0)
    y_target = np.concatenate(target_class, axis=0)
    # NOTE(review): counterfactuals are not materialised here; presumably the
    # caller applies the returned delta to x_origs - confirm against callers.
    return self.delta, x_origs, y_origs, y_target