
CCHVAE

Conditional Counterfactual Hierarchical VAE

CCHVAE uses a hierarchical variational autoencoder for counterfactual generation.

Overview

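CCHVAE trains (or loads) a VAE on the model's data and encodes each factual instance into the latent space. It then searches outward from that point. At each iteration it samples candidate latent codes in a progressively larger region, decodes them back into feature space, and keeps the candidate that is closest to the factual under the chosen L_p norm and that receives a different prediction from the model. Because candidates are decoded from the learned latent space, they tend to stay close to the data distribution, which makes them more plausible than unconstrained perturbations.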
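The latent-space traversal can be pictured with the simplified sketch below. It assumes generic `encode`, `decode` and `predict` callables standing in for the trained VAE and the black-box model, and it samples candidates in a growing L2 ring around the encoded factual. This is an illustration of the idea, not the library's `_counterfactual_search`: the ring-sampling scheme is an assumption, and categorical-feature handling is omitted. The keyword arguments mirror the documented hyperparameters.

```python
import numpy as np


def sample_in_ring(center, low, high, n_samples, rng):
    """Draw n_samples points whose L2 distance from center lies in [low, high)."""
    dim = center.shape[0]
    # Random directions on the unit sphere.
    directions = rng.normal(size=(n_samples, dim))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Radii drawn uniformly between the inner and outer bound (an assumption).
    radii = rng.uniform(low, high, size=(n_samples, 1))
    return center + directions * radii


def latent_search(x, encode, decode, predict, step=0.1, n_search_samples=300,
                  max_iter=2000, p_norm=1, clamp=True, seed=0):
    """Return the closest decoded candidate whose predicted label differs from x's."""
    rng = np.random.default_rng(seed)
    z = encode(x)
    original_label = predict(x[None, :])[0]
    low, high = 0.0, step
    for _ in range(max_iter):
        z_candidates = sample_in_ring(z, low, high, n_search_samples, rng)
        x_candidates = decode(z_candidates)
        if clamp:
            # Mirrors the "clamp" hyperparameter: keep features in [0, 1].
            x_candidates = np.clip(x_candidates, 0.0, 1.0)
        flipped = predict(x_candidates) != original_label
        if flipped.any():
            valid = x_candidates[flipped]
            # Pick the valid candidate nearest to the factual under L_p.
            distances = np.linalg.norm(valid - x, ord=p_norm, axis=1)
            return valid[np.argmin(distances)]
        # No label flip yet: move the ring outward by one step.
        low, high = high, high + step
    return None
```

For a quick experiment, identity functions can replace the VAE (`encode=lambda v: v.copy()`, `decode=lambda z: z`) and a threshold rule such as `lambda X: (X.sum(axis=1) > 1.0).astype(int)` can replace the classifier.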

Usage

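from counterfactuals.cf_methods.local_methods.c_chvae import CCHVAE

# ml_model: an MLModel wrapper with the "pytorch" backend (the only supported backend).
# hyperparams: a dict with the keys described under Notes. It is required in
# practice, because the constructor indexes it directly.
method = CCHVAE(
    mlmodel=ml_model,
    hyperparams=hyperparams,
)

# factuals: a pandas DataFrame of instances to explain.
result = method.get_counterfactuals(
    factuals=factuals
)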

API Reference

CCHVAE

CCHVAE(mlmodel, hyperparams=None)

Implementation of CCHVAE.

This class implements the Counterfactuals via Conditional Variational Autoencoders (CCHVAE) method for generating model-agnostic counterfactual explanations for tabular data, following Pawelczyk et al. (2020).

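Parameters:

  • mlmodel (MLModel, required): Black-box model wrapper used for prediction and data access.
  • hyperparams (Dict, default: None): Dictionary of hyperparameters; see Notes for details. Although the signature defaults to None, the constructor reads keys from this dictionary directly, so passing None raises a TypeError.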
Notes

Hyperparameters (hyperparams) control initialization and search behavior:

  • "data_name" (str): Name of the dataset.
  • "n_search_samples" (int, default: 300): Number of candidate counterfactuals sampled per iteration.
  • "p_norm" (int in {1, 2}): L_p norm used for distance calculation.
  • "step" (float, default: 0.1): Step size by which the search radius is expanded.
  • "max_iter" (int, default: 2000): Maximum number of iterations per factual instance.
  • "clamp" (bool, default: True): If True, feature values are clamped to [0, 1].
  • "binary_cat_features" (bool, default: True): If True, categorical encoding uses drop-if-binary.
  • "vae_params" (Dict): Parameters for the VAE:
      ◦ "layers" (List[int]): Number of neurons per layer.
      ◦ "train" (bool, default: True): Whether to train a new VAE.
      ◦ "kl_weight" (float, default: 0.3): Weight of the KL divergence term in the VAE loss.
      ◦ "lambda_reg" (float, default: 1e-6): Regularization weight for the VAE.
      ◦ "epochs" (int, default: 5): Number of training epochs for the VAE.
      ◦ "lr" (float, default: 1e-3): Learning rate for the VAE optimizer.
      ◦ "batch_size" (int, default: 32): Batch size for VAE training.

The constructor reads "data_name", "n_search_samples", "p_norm", "step", "max_iter", "clamp" and "vae_params" directly from the dictionary, so these keys must be present; __init__ does not fill in the defaults listed above.
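Because the constructor indexes hyperparams directly, a complete dictionary has to be passed. The sketch below assembles one from the defaults documented in the Notes list. `build_hyperparams`, `DOCUMENTED_DEFAULTS` and `REQUIRED_KEYS` are hypothetical names, not part of the library; whether any other part of the library merges defaults is not covered here.

```python
import copy

# Defaults as documented in the Notes. "data_name", "p_norm" and
# vae_params["layers"] have no documented default, so callers must supply them.
DOCUMENTED_DEFAULTS = {
    "n_search_samples": 300,
    "step": 0.1,
    "max_iter": 2000,
    "clamp": True,
    "binary_cat_features": True,
    "vae_params": {
        "train": True,
        "kl_weight": 0.3,
        "lambda_reg": 1e-6,
        "epochs": 5,
        "lr": 1e-3,
        "batch_size": 32,
    },
}

# Top-level keys that CCHVAE.__init__ reads directly.
REQUIRED_KEYS = ["data_name", "n_search_samples", "p_norm", "step",
                 "max_iter", "clamp", "vae_params"]


def build_hyperparams(user_params):
    """Hypothetical helper: overlay user_params on the documented defaults."""
    params = copy.deepcopy(DOCUMENTED_DEFAULTS)  # never mutate the defaults
    for key, value in user_params.items():
        if key == "vae_params":
            params["vae_params"].update(value)  # merge nested VAE settings
        else:
            params[key] = value
    missing = [k for k in REQUIRED_KEYS if k not in params]
    if missing:
        raise KeyError(f"missing hyperparameters: {missing}")
    if params["p_norm"] not in (1, 2):
        raise ValueError("p_norm must be 1 or 2")
    if "layers" not in params["vae_params"]:
        raise KeyError("vae_params['layers'] must be provided")
    return params
```

For example, `build_hyperparams({"data_name": "my_dataset", "p_norm": 1, "vae_params": {"layers": [13, 512, 256, 8]}})` yields a full dictionary. The dataset name and layer sizes are illustrative placeholders only.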

References

Pawelczyk, M., Broelemann, K., & Kasneci, G. (2020). Learning Model-Agnostic Counterfactual Explanations for Tabular Data. In Proceedings of The Web Conference 2020.

Raises:

  • ValueError: If mlmodel.backend is not supported (only "pytorch" is).
  • FileNotFoundError: If VAE loading is requested but the model file is missing.

Source code in counterfactuals/cf_methods/local_methods/c_chvae/c_chvae.py
def __init__(self, mlmodel: MLModel, hyperparams: Dict = None) -> None:
    """Initializes the CCHVAE method.

    Args:
      mlmodel: Model wrapper providing prediction and dataset utilities.
      hyperparams: Hyperparameter dictionary

    Raises:
      ValueError: If the provided model backend is unsupported.
      FileNotFoundError: If VAE loading is requested but the model file is missing.
    """
    supported_backends = ["pytorch"]
    if mlmodel.backend not in supported_backends:
        raise ValueError(f"{mlmodel.backend} is not in supported backends {supported_backends}")

    self._mlmodel = mlmodel
    self._params = hyperparams

    self._n_search_samples = self._params["n_search_samples"]
    self._p_norm = self._params["p_norm"]
    self._step = self._params["step"]
    self._max_iter = self._params["max_iter"]
    self._clamp = self._params["clamp"]

    vae_params = self._params["vae_params"]
    self._generative_model = self._load_vae(
        self._mlmodel.data.df, vae_params, self._mlmodel, self._params["data_name"]
    )

get_counterfactuals

get_counterfactuals(factuals)

Generates counterfactuals for the given factual instances with validation.

This method applies the internal search for each factual row, checks the validity of found counterfactuals, and returns them in the original feature order of the model.

Parameters:

  • factuals (DataFrame, required): DataFrame of factual instances.

Returns:

  • DataFrame: Validated counterfactual instances, aligned to factuals.

Source code in counterfactuals/cf_methods/local_methods/c_chvae/c_chvae.py
def get_counterfactuals(self, factuals: pd.DataFrame) -> pd.DataFrame:
    """Generates counterfactuals for the given factual instances with validation.

    This method applies the internal search for each factual row, checks the
    validity of found counterfactuals, and returns them in the original feature
    order of the model.

    Args:
      factuals: DataFrame of factual instances.

    Returns:
      DataFrame containing validated counterfactual instances aligned to `factuals`.
    """
    factuals = self._mlmodel.get_ordered_features(factuals)

    encoded_feature_names = self._mlmodel.data.categorical
    cat_features_indices = [
        factuals.columns.get_loc(feature) for feature in encoded_feature_names
    ]

    df_cfs = factuals.apply(
        lambda x: self._counterfactual_search(
            self._step, x.reshape((1, -1)), cat_features_indices
        ),
        raw=True,
        axis=1,
    )

    df_cfs = check_counterfactuals(self._mlmodel, df_cfs, factuals.index)
    df_cfs = self._mlmodel.get_ordered_features(df_cfs)
    return df_cfs
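The row-wise loop in get_counterfactuals relies on `DataFrame.apply` with `raw=True`. This option hands each row to the lambda as a NumPy array, which is why the source can call `x.reshape((1, -1))`. The toy sketch below reproduces the pattern. The column names are made up, `categorical` stands in for `mlmodel.data.categorical`, and `dummy_search` is a hypothetical stand-in for `_counterfactual_search`.

```python
import numpy as np
import pandas as pd

factuals = pd.DataFrame(
    {"age": [0.2, 0.7], "income": [0.5, 0.1], "sex_Male": [1.0, 0.0]}
)
categorical = ["sex_Male"]  # stand-in for mlmodel.data.categorical

# Positional indices of the categorical columns, as computed in the source.
cat_features_indices = [factuals.columns.get_loc(f) for f in categorical]


def dummy_search(step, x, cat_indices):
    """Hypothetical search stand-in: x has shape (1, n_features)."""
    candidate = x.copy()  # do not modify the factual row
    continuous = [i for i in range(x.shape[1]) if i not in cat_indices]
    candidate[0, continuous] += step  # nudge only the continuous features
    return candidate.ravel()  # 1-D row, so apply() rebuilds a DataFrame


df_cfs = factuals.apply(
    lambda x: dummy_search(0.1, x.reshape((1, -1)), cat_features_indices),
    raw=True,
    axis=1,
)
```

With `raw=True`, pandas returns a DataFrame with the original index and columns when every call returns a 1-D array of the row's length. That is why the stand-in flattens its result before returning it.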

get_counterfactuals_without_check

get_counterfactuals_without_check(factuals)

Generates counterfactuals without running the post-hoc validity checks.

This method is identical to get_counterfactuals except that it skips check_counterfactuals. The raw search outputs are returned without validation, reordered to the model's feature order.

Parameters:

  • factuals (DataFrame, required): DataFrame of factual instances.

Returns:

  • DataFrame: Counterfactual instances, aligned to factuals.

Source code in counterfactuals/cf_methods/local_methods/c_chvae/c_chvae.py
def get_counterfactuals_without_check(self, factuals: pd.DataFrame) -> pd.DataFrame:
    """Generates counterfactuals without running the post-hoc validity checks.

    This is similar to `get_counterfactuals` but skips `check_counterfactuals`,
    returning the raw counterfactual outputs projected back to the model's
    original feature order.

    Args:
      factuals: DataFrame of factual instances.

    Returns:
      DataFrame containing counterfactual instances aligned to `factuals`.
    """
    factuals = self._mlmodel.get_ordered_features(factuals)

    encoded_feature_names = self._mlmodel.data.categorical
    cat_features_indices = [
        factuals.columns.get_loc(feature) for feature in encoded_feature_names
    ]

    df_cfs = factuals.apply(
        lambda x: self._counterfactual_search(
            self._step, x.reshape((1, -1)), cat_features_indices
        ),
        raw=True,
        axis=1,
    )

    df_cfs = self._mlmodel.get_ordered_features(df_cfs)
    return df_cfs
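The exact behavior of check_counterfactuals is not documented on this page. The sketch below illustrates one kind of post-hoc validity check that the unchecked variant skips: it masks every row whose predicted label did not change. `mask_invalid_counterfactuals` is a hypothetical helper, and `predict` is assumed to map a 2-D array to class labels; this is not the library's implementation.

```python
import numpy as np
import pandas as pd


def mask_invalid_counterfactuals(predict, factuals, counterfactuals):
    """Hypothetical check: keep a row only if its predicted label flipped.

    Rows that fail, or that already contain NaN, are set entirely to NaN,
    so the output stays aligned with factuals.index.
    """
    checked = counterfactuals.copy()
    checked.index = factuals.index
    complete = checked.notna().all(axis=1)
    flipped = pd.Series(False, index=checked.index)
    if complete.any():
        cf_labels = predict(checked[complete].to_numpy())
        fact_labels = predict(factuals[complete].to_numpy())
        flipped[complete] = cf_labels != fact_labels
    checked.loc[~flipped, :] = np.nan  # invalidate rows without a label flip
    return checked
```

Running such a check on the output of get_counterfactuals_without_check shows which raw candidates would survive validation.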