
API Reference

Complete API documentation for the Counterfactuals library.

The API reference documentation is auto-generated from source code docstrings.

Core Modules

Counterfactual Methods

The main counterfactual explanation methods.

counterfactual_base

ExplanationResult dataclass

ExplanationResult(x_cfs, y_cf_targets, x_origs, y_origs, logs=None, cf_group_ids=None)

Data structure for storing the result of a counterfactual explanation.

This dataclass encapsulates all the important outputs from a counterfactual explanation process, including the generated counterfactuals, their targets, the original instances, and any additional logging information.

Attributes:

x_cfs (ndarray): Generated counterfactual examples.
y_cf_targets (ndarray): Target labels/values for the counterfactuals.
x_origs (ndarray): Original input instances.
y_origs (ndarray): Original labels/values for the input instances.
logs (Optional[Dict[str, Any]]): Additional logging information such as loss curves, convergence metrics, or method-specific data.
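Since the result type is a plain dataclass, consuming it is straightforward. The sketch below uses a local stand-in dataclass mirroring the documented fields (exact import paths are not shown in this reference, so nothing here is imported from the library) and computes a common downstream quantity, the per-instance distance between originals and counterfactuals:

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

import numpy as np


@dataclass
class ExplanationResult:
    # Local mirror of the library's dataclass fields, for illustration only.
    x_cfs: np.ndarray
    y_cf_targets: np.ndarray
    x_origs: np.ndarray
    y_origs: np.ndarray
    logs: Optional[Dict[str, Any]] = None
    cf_group_ids: Optional[np.ndarray] = None


# Two original instances and their generated counterfactuals.
result = ExplanationResult(
    x_cfs=np.array([[0.9, 0.1], [0.2, 0.8]]),
    y_cf_targets=np.array([1, 0]),
    x_origs=np.array([[0.1, 0.9], [0.8, 0.2]]),
    y_origs=np.array([0, 1]),
    logs={"loss_curve": [1.0, 0.5, 0.1]},
)

# Per-instance L2 distance between originals and counterfactuals.
distances = np.linalg.norm(result.x_cfs - result.x_origs, axis=1)
print(distances.shape)  # one distance per explained instance
```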

BaseCounterfactualMethod

BaseCounterfactualMethod(gen_model=None, disc_model=None, disc_model_criterion=None, device=None, **kwargs)

Bases: ABC

Abstract base class for all counterfactual explanation methods.

This class defines the interface that all counterfactual methods must implement. It provides a consistent API for fitting, explaining, and generating counterfactuals across different methodological approaches.

The class supports both individual explanations and batch processing through DataLoader objects, making it suitable for various use cases from single instance explanations to large-scale evaluations.

Attributes:

gen_model: Generative model used for counterfactual generation (if applicable).
disc_model (PytorchBase): Discriminative/classification model to be explained.
disc_model_criterion: Loss function for the discriminative model.
device (str): Computing device ('cpu' or 'cuda') for PyTorch operations.

Parameters:

gen_model (Optional[Any], default None): Generative model for CF generation. Can be None for methods that don't use generative models.
disc_model (Optional[PytorchBase], default None): The model to be explained. Should be a PyTorch-based model wrapped in our PytorchBase interface.
disc_model_criterion (Optional[Any], default None): Loss function for the discriminative model. Required by optimization-based methods.
device (Optional[str], default None): Device for computation. Defaults to 'cpu'.
**kwargs: Additional method-specific parameters.
Source code in counterfactuals/cf_methods/counterfactual_base.py
def __init__(
    self,
    gen_model: Optional[Any] = None,
    disc_model: Optional[PytorchBase] = None,
    disc_model_criterion: Optional[Any] = None,
    device: Optional[str] = None,
    **kwargs,
) -> None:
    """
    Initialize the counterfactual method.

    Args:
        gen_model (Optional[Any]): Generative model for CF generation. Can be None
            for methods that don't use generative models.
        disc_model (Optional[PytorchBase]): The model to be explained. Should be
            a PyTorch-based model wrapped in our PytorchBase interface.
        disc_model_criterion (Optional[Any]): Loss function for the discriminative
            model. Required by optimization-based methods.
        device (Optional[str]): Device for computation. Defaults to 'cpu'.
        **kwargs: Additional method-specific parameters.
    """
    self.gen_model = gen_model
    self.disc_model = disc_model
    self.disc_model_criterion = disc_model_criterion
    self.device = device or "cpu"

    # Move models to device if they exist and have a .to() method
    if self.gen_model is not None and hasattr(self.gen_model, "to"):
        self.gen_model.to(self.device)
    if self.disc_model is not None and hasattr(self.disc_model, "to"):
        self.disc_model.to(self.device)

fit

fit(X_train, y_train, **kwargs)

Fit the counterfactual method on training data.

This method allows the counterfactual explanation method to learn from training data. This might involve training auxiliary models, learning data distributions, or other preparatory steps.

Parameters:

X_train (ndarray, required): Training features with shape (n_samples, n_features).
y_train (ndarray, required): Training labels with shape (n_samples,).
**kwargs: Additional method-specific parameters.
Source code in counterfactuals/cf_methods/counterfactual_base.py
def fit(self, X_train: np.ndarray, y_train: np.ndarray, **kwargs) -> None:
    """
    Fit the counterfactual method on training data.

    This method allows the counterfactual explanation method to learn
    from training data. This might involve training auxiliary models,
    learning data distributions, or other preparatory steps.

    Args:
        X_train (np.ndarray): Training features with shape (n_samples, n_features).
        y_train (np.ndarray): Training labels with shape (n_samples,).
        **kwargs: Additional method-specific parameters.
    """
    pass

explain abstractmethod

explain(X, y_origin, y_target, X_train=None, y_train=None, **kwargs)

Generate counterfactual explanations for given instances.

This is the core method that generates counterfactual explanations for input instances. It should return counterfactuals that, when passed through the model, produce the desired target outcomes.

Parameters:

X (ndarray, required): Input instances to explain with shape (n_instances, n_features).
y_origin (ndarray, required): Original predictions/labels for X with shape (n_instances,).
y_target (ndarray, required): Desired target predictions/labels with shape (n_instances,).
X_train (Optional[ndarray], default None): Training data, if needed by the method.
y_train (Optional[ndarray], default None): Training labels, if needed by the method.
**kwargs: Additional method-specific parameters.

Returns:

ExplanationResult: Object containing counterfactuals, targets, originals, and any additional logging information.

Raises:

NotImplementedError: If the method is not implemented by the subclass.

Source code in counterfactuals/cf_methods/counterfactual_base.py
@abstractmethod
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray,
    y_target: np.ndarray,
    X_train: Optional[np.ndarray] = None,
    y_train: Optional[np.ndarray] = None,
    **kwargs,
) -> ExplanationResult:
    """
    Generate counterfactual explanations for given instances.

    This is the core method that generates counterfactual explanations
    for input instances. It should return counterfactuals that, when
    passed through the model, produce the desired target outcomes.

    Args:
        X (np.ndarray): Input instances to explain with shape (n_instances, n_features).
        y_origin (np.ndarray): Original predictions/labels for X with shape (n_instances,).
        y_target (np.ndarray): Desired target predictions/labels with shape (n_instances,).
        X_train (Optional[np.ndarray]): Training data, if needed by the method.
        y_train (Optional[np.ndarray]): Training labels, if needed by the method.
        **kwargs: Additional method-specific parameters.

    Returns:
        ExplanationResult: Object containing counterfactuals, targets, originals,
            and any additional logging information.

    Raises:
        NotImplementedError: If the method is not implemented by the subclass.
    """
    raise NotImplementedError("Subclasses must implement the explain method")
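To make the explain contract concrete, here is a hedged sketch of a toy subclass: a nearest-neighbor method that, for each instance, returns the closest training point already carrying the target label. The base class is stubbed locally so the sketch is self-contained; real code would subclass the library's BaseCounterfactualMethod and wrap the output in ExplanationResult.

```python
import numpy as np


class BaseCounterfactualMethod:
    # Local stand-in for the library's abstract base, for illustration only.
    def __init__(self, **kwargs): ...


class NearestNeighborCF(BaseCounterfactualMethod):
    """Toy method: the counterfactual is the nearest training point
    that already has the target label."""

    def fit(self, X_train, y_train, **kwargs):
        self.X_train, self.y_train = X_train, y_train

    def explain(self, X, y_origin, y_target, X_train=None, y_train=None, **kwargs):
        X_train = X_train if X_train is not None else self.X_train
        y_train = y_train if y_train is not None else self.y_train
        x_cfs = np.empty_like(X, dtype=float)
        for i, (x, t) in enumerate(zip(X, y_target)):
            candidates = X_train[y_train == t]           # points of the target class
            dists = np.linalg.norm(candidates - x, axis=1)
            x_cfs[i] = candidates[np.argmin(dists)]      # closest one wins
        # A real implementation would return an ExplanationResult here.
        return x_cfs


X_train = np.array([[0.0, 0.0], [1.0, 1.0], [0.9, 1.1]])
y_train = np.array([0, 1, 1])
method = NearestNeighborCF()
method.fit(X_train, y_train)
cfs = method.explain(
    np.array([[0.1, 0.1]]), y_origin=np.array([0]), y_target=np.array([1])
)
print(cfs)  # nearest class-1 training point
```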

explain_dataloader abstractmethod

explain_dataloader(dataloader, epochs, lr, patience_eps=1e-05, **search_step_kwargs)

Generate counterfactual explanations for data provided via DataLoader.

This method is designed for batch processing of counterfactual generation, particularly useful for optimization-based methods that require iterative search procedures. It processes data in batches and typically involves gradient-based optimization.

Parameters:

dataloader (DataLoader, required): PyTorch DataLoader containing (X, y) pairs where X are instances to explain and y are their labels.
epochs (int, required): Maximum number of optimization epochs per instance.
lr (float, required): Learning rate for optimization procedures.
patience_eps (Union[float, int], default 1e-05): Convergence threshold. When loss drops below this value, optimization can terminate early.
**search_step_kwargs: Additional parameters passed to the search step function, such as regularization weights, constraints, etc.

Returns:

ExplanationResult: Object containing all generated counterfactuals, their targets, original instances, and detailed logging information including loss curves and convergence metrics.

Raises:

NotImplementedError: If the method is not implemented by the subclass.

Source code in counterfactuals/cf_methods/counterfactual_base.py
@abstractmethod
def explain_dataloader(
    self,
    dataloader: DataLoader,
    epochs: int,
    lr: float,
    patience_eps: Union[float, int] = 1e-5,
    **search_step_kwargs,
) -> ExplanationResult:
    """
    Generate counterfactual explanations for data provided via DataLoader.

    This method is designed for batch processing of counterfactual generation,
    particularly useful for optimization-based methods that require iterative
    search procedures. It processes data in batches and typically involves
    gradient-based optimization.

    Args:
        dataloader (DataLoader): PyTorch DataLoader containing (X, y) pairs
            where X are instances to explain and y are their labels.
        epochs (int): Maximum number of optimization epochs per instance.
        lr (float): Learning rate for optimization procedures.
        patience_eps (Union[float, int]): Convergence threshold. When loss
            drops below this value, optimization can terminate early.
        **search_step_kwargs: Additional parameters passed to the search
            step function, such as regularization weights, constraints, etc.

    Returns:
        ExplanationResult: Object containing all generated counterfactuals,
            their targets, original instances, and detailed logging information
            including loss curves and convergence metrics.

    Raises:
        NotImplementedError: If the method is not implemented by the subclass.
    """
    raise NotImplementedError("Subclasses must implement the explain_dataloader method")
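The epochs/lr/patience_eps semantics described above can be illustrated with a minimal gradient-descent search loop. This is a NumPy sketch against a fixed logistic model, not the library's actual search step: the candidate counterfactual is nudged until the model's probability for the target class rises, stopping early if the loss falls below patience_eps.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def search_counterfactual(x, w, b, epochs=500, lr=0.1, patience_eps=1e-5):
    """Sketch of an optimization-based search step: push
    sigmoid(w @ x_cf + b) towards 1 (the target class)."""
    x_cf = x.copy()
    losses = []
    for _ in range(epochs):
        p = sigmoid(w @ x_cf + b)
        loss = -np.log(p)           # cross-entropy against target class 1
        losses.append(loss)
        if loss < patience_eps:     # early termination, as documented above
            break
        grad = -(1.0 - p) * w       # d(loss)/d(x_cf) for the logistic model
        x_cf -= lr * grad
    return x_cf, losses


w, b = np.array([2.0, -1.0]), 0.0
x = np.array([-1.0, 1.0])           # starts firmly in class 0
x_cf, losses = search_counterfactual(x, w, b)
print(sigmoid(w @ x_cf + b) > 0.5)  # the counterfactual flips the prediction
```

In the library, the same loop runs per batch over the DataLoader, with **search_step_kwargs controlling extras such as regularization weights.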

Datasets

Dataset loading and configuration utilities.

file_dataset

FileDataset

FileDataset(config_path, samples_keep=None)

Bases: DatasetBase

File dataset loader compatible with DatasetBase.

config_path: Path to the dataset configuration file.
samples_keep: Optional cap on the number of samples to keep; falls back to the config's samples_keep value when None.
Source code in counterfactuals/datasets/file_dataset.py
def __init__(
    self,
    config_path: Path,
    samples_keep: Optional[int] = None,
):
    """Initializes the File dataset with OmegaConf config.
    Args:
        config_path: Path to the dataset configuration file.
        samples_keep: Optional cap on the number of samples to keep; falls
            back to the config's samples_keep value when None.
    """
    super().__init__(config_path=config_path)
    self.samples_keep = samples_keep if samples_keep is not None else self.config.samples_keep
    self.initial_transform_pipeline: Optional[InitialTransformPipeline] = (
        build_initial_transform_pipeline(self.config.initial_transforms)
    )
    self.one_hot_feature_groups: dict[str, list[str]] = {}

    raw_data = self._load_csv(self.config.raw_data_path)
    context = self._apply_initial_transforms(raw_data)

    if self.samples_keep > 0 and len(context.data) > self.samples_keep:
        context.data = context.data.sample(self.samples_keep, random_state=42).reset_index(
            drop=True
        )

    self.raw_data = context.data
    self._update_metadata_from_context(context)
    self.X, self.y = self.preprocess(self.raw_data)

preprocess

preprocess(raw_data)

Preprocesses raw data into feature and target arrays.

Parameters:

raw_data (DataFrame, required): Raw dataset as a pandas DataFrame.

Returns:

tuple[ndarray, ndarray]: Tuple (X, y) as numpy arrays.

Source code in counterfactuals/datasets/file_dataset.py
def preprocess(self, raw_data: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Preprocesses raw data into feature and target arrays.

    Args:
        raw_data: Raw dataset as a pandas DataFrame.

    Returns:
        Tuple (X, y) as numpy arrays.
    """
    data = raw_data.copy()
    if self.config.target_mapping:
        data[self.config.target] = data[self.config.target].replace(self.config.target_mapping)

    X = data[self.features].to_numpy()
    y = data[self.config.target].to_numpy()
    self.X, self.y = X, y
    return X, y
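The preprocessing step above maps string labels through the config's target_mapping before splitting features and target. A self-contained sketch of the same pattern, using illustrative stand-ins for the OmegaConf config values (the names target, target_mapping, and features mirror the config keys but the values are invented):

```python
import numpy as np
import pandas as pd

# Illustrative stand-ins for the dataset config values.
target = "outcome"
target_mapping = {"no": 0, "yes": 1}
features = ["age", "income"]

raw = pd.DataFrame(
    {"age": [25, 40], "income": [30_000, 52_000], "outcome": ["no", "yes"]}
)

data = raw.copy()
data[target] = data[target].replace(target_mapping)  # map string labels to ints

X = data[features].to_numpy()
y = data[target].to_numpy()
print(X.shape, y.tolist())
```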

Models

Classifiers

logistic_regression

multilayer_perceptron

Generative Models

kde

Kernel density estimation models and kernels.

GenerativeModel

Bases: ABC, Module

Base class inherited by all generative models in pytorch-generative.

Provides
  • An abstract sample() method which is implemented by subclasses that support generating samples.
  • Variables self._c, self._h, self._w which store the shape of the (first) image Tensor the model was trained with. Note that forward() must have been called at least once and the input must be an image for these variables to be available.
  • A device property which returns the device of the model's parameters.

__call__

__call__(x, *args, **kwargs)

Saves input tensor attributes so they can be accessed during sampling.

Source code in counterfactuals/models/generative/kde.py
def __call__(self, x, *args, **kwargs):
    """Saves input tensor attributes so they can be accessed during sampling."""
    if getattr(self, "_c", None) is None and x.dim() == 4:
        _, c, h, w = x.shape
        self._create_shape_buffers(c, h, w)
    return super().__call__(x, *args, **kwargs)

load_state_dict

load_state_dict(state_dict, strict=True)

Registers dynamic buffers before loading the model state.

Source code in counterfactuals/models/generative/kde.py
def load_state_dict(self, state_dict, strict=True):
    """Registers dynamic buffers before loading the model state."""
    if "_c" in state_dict and not getattr(self, "_c", None):
        c, h, w = state_dict["_c"], state_dict["_h"], state_dict["_w"]
        self._create_shape_buffers(c, h, w)
    super().load_state_dict(state_dict, strict)

Kernel

Kernel(bandwidth=1.0)

Bases: ABC, Module

Base class which defines the interface for all kernels.

Parameters:

bandwidth (default 1.0): The kernel's (band)width.
Source code in counterfactuals/models/generative/kde.py
def __init__(self, bandwidth=1.0):
    """Initializes a new Kernel.

    Args:
        bandwidth: The kernel's (band)width.
    """
    super().__init__()
    self.bandwidth = bandwidth

forward abstractmethod

forward(test_Xs, train_Xs)

Computes log p(x) for each x in test_Xs given train_Xs.

Source code in counterfactuals/models/generative/kde.py
@abc.abstractmethod
def forward(self, test_Xs, train_Xs):
    """Computes log p(x) for each x in test_Xs given train_Xs."""

sample abstractmethod

sample(train_Xs)

Generates samples from the kernel distribution.

Source code in counterfactuals/models/generative/kde.py
@abc.abstractmethod
def sample(self, train_Xs):
    """Generates samples from the kernel distribution."""

ParzenWindowKernel

ParzenWindowKernel(bandwidth=1.0)

Bases: Kernel

Implementation of the Parzen window kernel.

Source code in counterfactuals/models/generative/kde.py
def __init__(self, bandwidth=1.0):
    """Initializes a new Kernel.

    Args:
        bandwidth: The kernel's (band)width.
    """
    super().__init__()
    self.bandwidth = bandwidth

GaussianKernel

GaussianKernel(bandwidth=1.0)

Bases: Kernel

Implementation of the Gaussian kernel.

Source code in counterfactuals/models/generative/kde.py
def __init__(self, bandwidth=1.0):
    """Initializes a new Kernel.

    Args:
        bandwidth: The kernel's (band)width.
    """
    super().__init__()
    self.bandwidth = bandwidth
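The kernel's forward contract is to compute log p(x) for each test point given the training points. For a Gaussian kernel this is the standard KDE density, p(x) = (1/n) Σᵢ N(x; xᵢ, h²I). A NumPy sketch of that formula, written under the assumption that the library's GaussianKernel follows it (the stable log-mean-exp is a standard implementation detail, not taken from the source):

```python
import numpy as np


def gaussian_kde_log_prob(test_Xs, train_Xs, bandwidth=1.0):
    """log p(x) for each test point under a Gaussian KDE:
    p(x) = mean_i N(x; train_i, bandwidth^2 * I)."""
    n, d = train_Xs.shape
    # Pairwise squared distances, shape (n_test, n_train).
    sq = ((test_Xs[:, None, :] - train_Xs[None, :, :]) ** 2).sum(-1)
    log_norm = -0.5 * d * np.log(2 * np.pi * bandwidth**2)
    log_kernels = log_norm - sq / (2 * bandwidth**2)
    # Numerically stable log-mean-exp over the training points.
    m = log_kernels.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(log_kernels - m).mean(axis=1, keepdims=True))).ravel()


train = np.array([[0.0], [2.0]])
logp = gaussian_kde_log_prob(np.array([[0.0]]), train, bandwidth=1.0)
```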

KernelDensityEstimator

KernelDensityEstimator(train_Xs, kernel=None)

Bases: GenerativeModel

The KernelDensityEstimator model.

Parameters:

train_Xs (required): The "training" data to use when estimating probabilities.
kernel (default None): The kernel to place on each of the train_Xs.
Source code in counterfactuals/models/generative/kde.py
def __init__(self, train_Xs, kernel=None):
    """Initializes a new KernelDensityEstimator.

    Args:
        train_Xs: The "training" data to use when estimating probabilities.
        kernel: The kernel to place on each of the train_Xs.
    """
    super().__init__()
    self.kernel = kernel or GaussianKernel()
    self.train_Xs = nn.Parameter(train_Xs, requires_grad=False)
    assert len(self.train_Xs.shape) == 2, "Input must have exactly two axes."

KDE

KDE(bandwidth=0.1, **kwargs)

Bases: PytorchBase, GenerativePytorchMixin

Kernel density estimator that keeps a separate KernelDensityEstimator per context (e.g. per class label), stored in an nn.ModuleDict.

Source code in counterfactuals/models/generative/kde.py
def __init__(self, bandwidth: float = 0.1, **kwargs):
    super(KDE, self).__init__(None, None)
    self.bandwidth = bandwidth
    self.models = nn.ModuleDict()

forward

forward(x, context=None)

Forward pass with optional context for compatibility.

Source code in counterfactuals/models/generative/kde.py
def forward(self, x: torch.Tensor, context: torch.Tensor = None):
    """Forward pass with optional context for compatibility."""
    if context is None:
        # If no context provided, try to use all available models
        # This is a fallback for PytorchBase compatibility
        if len(self.models) == 1:
            model = next(iter(self.models.values()))
            return model(x).view(-1)
        else:
            raise ValueError("Context must be provided when multiple models exist")

    preds = torch.zeros_like(context, dtype=torch.float32)
    for i in range(x.shape[0]):
        model = self._get_model_for_context(context[i].item())
        preds[i] = model(x[i].unsqueeze(0))
    return preds.view(-1)

predict_log_proba

predict_log_proba(X_test, context=None)

Predict log probabilities for input data.

Source code in counterfactuals/models/generative/kde.py
def predict_log_proba(
    self, X_test: np.ndarray, context: Optional[np.ndarray] = None
) -> np.ndarray:
    """Predict log probabilities for input data."""
    # Convert to torch tensor if needed
    assert context is not None, "Context must be provided for KDE"
    X_test = torch.from_numpy(X_test).float()
    context = torch.from_numpy(context).float()
    preds = torch.zeros_like(context, dtype=torch.float32)
    for i in range(X_test.shape[0]):
        model = self._get_model_for_context(context[i].item())
        preds[i] = model(X_test[i].unsqueeze(0))
    return preds.cpu().numpy()
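The per-context lookup in forward and predict_log_proba amounts to fitting one class-conditional density per label and evaluating the model matching each instance's context. A compact NumPy sketch of that behavior (the helper names here are invented for illustration; the library's version uses per-context KernelDensityEstimator modules):

```python
import numpy as np


def fit_per_context_kdes(X, y, bandwidth=0.1):
    """One Gaussian-KDE log-density function per context (class label)."""
    def make_logp(train):
        d = train.shape[1]
        def logp(x):
            sq = ((x[None, :] - train) ** 2).sum(-1)
            logk = -0.5 * d * np.log(2 * np.pi * bandwidth**2) - sq / (2 * bandwidth**2)
            m = logk.max()
            return m + np.log(np.exp(logk - m).mean())  # stable log-mean-exp
        return logp
    return {label: make_logp(X[y == label]) for label in np.unique(y)}


X = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
y = np.array([0, 0, 1])
models = fit_per_context_kdes(X, y, bandwidth=0.5)

# Log-density of a point under each class-conditional model.
point = np.array([0.05, 0.0])
print(models[0](point) > models[1](point))  # True: the point lies near class 0
```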

sample_and_log_proba

sample_and_log_proba(n_samples, context=None)

Sample from KDE and return log probabilities.

Source code in counterfactuals/models/generative/kde.py
def sample_and_log_proba(self, n_samples: int, context: Optional[np.ndarray] = None) -> tuple:
    """Sample from KDE and return log probabilities."""
    raise NotImplementedError("Sampling from KDE is not implemented")

auto_reshape

auto_reshape(fn)

Decorator which flattens image inputs and reshapes them before returning.

This is used to enable non-convolutional models to transparently work on images.

Source code in counterfactuals/models/generative/kde.py
def auto_reshape(fn):
    """Decorator which flattens image inputs and reshapes them before returning.

    This is used to enable non-convolutional models to transparently work on images.
    """

    def wrapped_fn(self, x, *args, **kwargs):
        original_shape = x.shape
        x = x.view(original_shape[0], -1)
        y = fn(self, x, *args, **kwargs)
        return y.view(original_shape)

    return wrapped_fn
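The same decorator pattern can be shown in a self-contained NumPy analogue (the library's version above uses torch's .view on methods; this sketch drops self and uses reshape, so it is an illustration of the pattern rather than the library's decorator):

```python
import numpy as np


def auto_reshape(fn):
    """Flatten image batches to 2-D before calling fn, then restore
    the original shape (NumPy analogue of the torch decorator above)."""
    def wrapped(x, *args, **kwargs):
        original_shape = x.shape
        flat = x.reshape(original_shape[0], -1)     # (batch, features)
        return fn(flat, *args, **kwargs).reshape(original_shape)
    return wrapped


@auto_reshape
def scale(x):
    # A "non-convolutional" op that only understands (batch, features).
    return x * 2.0


images = np.ones((4, 3, 8, 8))  # (batch, channels, height, width)
out = scale(images)
print(out.shape)                # the image shape is restored
```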

For method-specific API documentation, see: