DiCoFlex
Diverse Counterfactuals with Flexible Constraints
DiCoFlex generates diverse counterfactual explanations while respecting flexible feature constraints.
Overview
DiCoFlex extends standard counterfactual generation by:
- Producing multiple diverse counterfactuals
- Supporting flexible actionability constraints
- Balancing diversity with validity
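A common way to trade diversity against validity among sampled candidates is to drop the ones the classifier does not assign to the target class, then greedily pick survivors that are far apart from each other. The sketch below illustrates that general idea only. It is not DiCoFlex's selection procedure. `select_diverse`, the 0.5 threshold, and the farthest-point heuristic are all hypothetical choices.

```python
import numpy as np


def select_diverse(candidates: np.ndarray, target_probs: np.ndarray,
                   k: int, threshold: float = 0.5) -> np.ndarray:
    """Hypothetical helper: keep valid candidates, then pick up to k diverse ones.

    Validity: the target-class probability is at least `threshold` (assumed rule).
    Diversity: greedy farthest-point selection under Euclidean distance.
    """
    keep = target_probs >= threshold
    valid = candidates[keep]
    if len(valid) == 0:
        return valid  # nothing valid to choose from
    # Seed with the valid candidate the classifier is most confident about.
    chosen = [int(np.argmax(target_probs[keep]))]
    while len(chosen) < min(k, len(valid)):
        # Distance from every valid candidate to each already-chosen candidate.
        dists = np.linalg.norm(valid[:, None, :] - valid[chosen][None, :, :], axis=-1)
        nearest = dists.min(axis=1)   # distance to the closest chosen point
        nearest[chosen] = -np.inf     # never re-pick a chosen point
        chosen.append(int(np.argmax(nearest)))
    return valid[chosen]
```

With candidates `[[0, 0], [0.1, 0], [5, 5], [9, 0]]`, target probabilities `[0.9, 0.8, 0.7, 0.2]`, and `k=2`, the last candidate is discarded as invalid. `[0, 0]` is picked first, and `[5, 5]` is picked over the near-duplicate `[0.1, 0]`.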
Usage
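```python
from counterfactuals.cf_methods.local_methods.DiCoFlex import DiCoFlex

# gen_model: trained conditional generative model (GenerativePytorchMixin)
# classifier: trained classifier (PytorchBase)
# params: a DiCoFlexParams instance; params.target_class must be a key of class_to_index
method = DiCoFlex(
    gen_model=gen_model,
    disc_model=classifier,
    class_to_index=class_to_index,
    mask_vectors=mask_vectors,
    params=params,
    device="cuda",
)

result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,  # optional; not used by explain() in the source below
    y_train=y_train,  # optional; not used by explain() in the source below
)
counterfactuals = result.x_cfs
```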
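The constructor also needs `class_to_index` and `mask_vectors`. The following is a minimal NumPy sketch for preparing them. It makes two assumptions: `class_to_index` maps each class label seen in `y_train` to a contiguous index, and a mask vector has one entry per feature. The all-ones mask is only a placeholder, because what 1 and 0 mean depends on how the generative model was trained. The sketch also shows why scalar labels such as `y_origin=0` work: `explain()` reshapes labels with `np.asarray(...).reshape(-1, 1)`.

```python
import numpy as np

# Hypothetical training labels and a single 4-feature instance.
y_train = np.array([0, 1, 1, 0, 1])
instance = np.array([[0.2, 1.5, -0.3, 4.0]], dtype=np.float32)

# Assumed convention: each observed class label -> contiguous index.
class_to_index = {int(c): i for i, c in enumerate(np.unique(y_train))}

# Assumed convention: one entry per feature; all ones is a placeholder mask.
mask_vectors = [np.ones(instance.shape[1], dtype=np.float32)]

# explain() reshapes labels like this, so a scalar becomes a (1, 1) column.
y_origin_vec = np.asarray(0).reshape(-1, 1)
y_target_vec = np.asarray(1).reshape(-1, 1)
```

The constructor raises `ValueError` unless `params.target_class` is a key of `class_to_index`, so the mapping must include the class you are targeting.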
Parameters
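| Parameter | Type | Default | Description |
|---|---|---|---|
| `gen_model` | `GenerativePytorchMixin` | required | Trained generative model (conditional flow) |
| `disc_model` | `PytorchBase` | required | Trained classifier |
| `class_to_index` | `Dict[int, int]` | required | Mapping over class labels observed in training; must contain `params.target_class` |
| `mask_vectors` | `list[np.ndarray]` | required | Available mask vectors; at least one is required |
| `params` | `DiCoFlexParams` | required | Method settings; `params.mask_index` selects an entry of `mask_vectors` and must be smaller than its length |
| `device` | `Optional[str]` | `None` | Device for both models; falls back to `"cpu"` |
| `n_counterfactuals` | `int` | 5 | Number of CFs to generate (not an `__init__` argument in the signature below) |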
API Reference
DiCoFlex
DiCoFlex(gen_model, disc_model, class_to_index, mask_vectors, params, device=None)
Bases: BaseCounterfactualMethod, LocalCounterfactualMixin
Sample-based counterfactual generator backed by a conditional flow.
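The sample-based idea can be illustrated with stand-ins. The sketch draws many candidates from a generator conditioned on the target class, scores them with a classifier, and records each candidate's target-class probability and a validity flag. Those are the quantities that `target_probs` and `valid_mask` summarize in `explain()`. The Gaussian sampler, the logistic classifier, and the 0.5 threshold are hypothetical. They are not the conditional flow or the classifier that DiCoFlex actually uses.

```python
import numpy as np


def toy_sampler(x, target, n_samples, rng):
    # Hypothetical stand-in for a conditional flow: Gaussian noise around x,
    # shifted in the direction associated with the target class.
    shift = 1.0 if target == 1 else -1.0
    return x + shift + rng.normal(scale=0.5, size=(n_samples, x.shape[-1]))


def toy_predict_proba(X):
    # Hypothetical binary classifier: logistic function of the feature sum.
    p1 = 1.0 / (1.0 + np.exp(-X.sum(axis=1)))
    return np.stack([1.0 - p1, p1], axis=1)


def sample_and_filter(x, target, n_samples=64, threshold=0.5, seed=0):
    rng = np.random.default_rng(seed)
    candidates = toy_sampler(x, target, n_samples, rng)
    target_probs = toy_predict_proba(candidates)[:, target]
    valid_mask = target_probs >= threshold  # assumed validity rule
    return candidates, target_probs, valid_mask
```

For example, `sample_and_filter(np.array([-1.0, -0.5]), target=1)` returns 64 candidates, each with its target-class probability and a boolean validity flag. `valid_mask.mean()` corresponds to the `sampling/valid_ratio` statistic.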
Source code in counterfactuals/cf_methods/local_methods/DiCoFlex/method.py
```python
def __init__(
    self,
    gen_model: GenerativePytorchMixin,
    disc_model: PytorchBase,
    class_to_index: Dict[int, int],
    mask_vectors: list[np.ndarray],
    params: DiCoFlexParams,
    device: Optional[str] = None,
) -> None:
    super().__init__(
        gen_model=gen_model,
        disc_model=disc_model,
        device=device,
    )
    if not mask_vectors:
        raise ValueError("At least one mask vector must be supplied for DiCoFlex.")
    if params.mask_index >= len(mask_vectors):
        raise ValueError("mask_index exceeds available mask vectors.")
    if params.target_class not in class_to_index:
        raise ValueError(
            f"Target class {params.target_class} not observed in the training data."
        )
    self.class_to_index = class_to_index
    self.mask_vectors = mask_vectors
    self.params = params
    self.device = device or "cpu"
    self.gen_model.to(self.device)
    self.disc_model.to(self.device)
```
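Before moving both models to the device, the constructor rejects three misconfigurations. The standalone sketch below mirrors those checks, so they can be run before a `DiCoFlex` instance is built. `check_dicoflex_inputs` and the `Params` dataclass are hypothetical. `Params` exposes only the two fields that `__init__` reads from `DiCoFlexParams` (`target_class` and `mask_index`).

```python
from dataclasses import dataclass
from typing import Dict, List

import numpy as np


@dataclass
class Params:
    # Hypothetical stand-in for DiCoFlexParams with only the fields __init__ reads.
    target_class: int
    mask_index: int


def check_dicoflex_inputs(class_to_index: Dict[int, int],
                          mask_vectors: List[np.ndarray],
                          params: Params) -> None:
    """Raise ValueError under the same conditions as DiCoFlex.__init__."""
    if not mask_vectors:
        raise ValueError("At least one mask vector must be supplied for DiCoFlex.")
    if params.mask_index >= len(mask_vectors):
        raise ValueError("mask_index exceeds available mask vectors.")
    if params.target_class not in class_to_index:
        raise ValueError(
            f"Target class {params.target_class} not observed in the training data."
        )
```

For example, with one mask vector, `Params(target_class=1, mask_index=0)` passes. `Params(target_class=1, mask_index=1)` raises the `mask_index` error.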
explain
explain(X, y_origin, y_target, X_train=None, y_train=None, **kwargs)
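Generate counterfactuals for the provided samples. Returns an ExplanationResult with the following contents:

- the flattened counterfactuals (x_cfs) and their target labels (y_cf_targets)
- the matching original samples and labels (x_origs, y_origs)
- cf_group_ids, which groups the counterfactuals
- sampling statistics in logs (mean target probability, valid ratio, mean log-probability, the validity mask, and the group ids)

In the source shown, X_train, y_train, and extra keyword arguments are accepted but not used.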
Source code in counterfactuals/cf_methods/local_methods/DiCoFlex/method.py
```python
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray,
    y_target: np.ndarray,
    X_train: Optional[np.ndarray] = None,
    y_train: Optional[np.ndarray] = None,
    **kwargs,
) -> ExplanationResult:
    """Generate counterfactuals for the provided samples."""
    x_np = np.asarray(X, dtype=np.float32)
    y_origin_vec = np.asarray(y_origin).reshape(-1, 1)
    y_target_vec = np.asarray(y_target).reshape(-1, 1)
    (
        cf_batch,
        y_target_flat,
        x_orig_flat,
        y_origin_flat,
        target_probs,
        valid_mask,
        log_probs,
        group_ids,
    ) = self._sample_counterfactuals(x_np, y_origin_vec, y_target_vec)
    logs = {
        "sampling/mean_target_probability": float(target_probs.mean()),
        "sampling/valid_ratio": float(valid_mask.mean()),
        "sampling/log_prob_mean": float(log_probs.mean()),
        "model_returned_mask": valid_mask.tolist(),
        "cf_group_ids": group_ids.tolist(),
    }
    return ExplanationResult(
        x_cfs=cf_batch,
        y_cf_targets=y_target_flat,
        x_origs=x_orig_flat,
        y_origs=y_origin_flat,
        logs=logs,
        cf_group_ids=group_ids,
    )
```
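`explain()` returns all counterfactuals as one flat batch and writes summary statistics into `logs`. The sketch below reproduces those three statistics from hypothetical arrays, then regroups the flat batch by original sample. It assumes two things. First, `group_ids[i]` identifies the original sample that counterfactual `i` belongs to; the source above does not spell this out, so verify it against `_sample_counterfactuals`. Second, the validity rule is assumed for illustration; in the real method, `valid_mask` comes from sampling.

```python
import numpy as np

# Hypothetical outputs: 2 original samples x 3 counterfactuals each, 2 features.
x_cfs = np.arange(12, dtype=np.float32).reshape(6, 2)
group_ids = np.array([0, 0, 0, 1, 1, 1])
target_probs = np.array([0.9, 0.4, 0.7, 0.6, 0.2, 0.8])
valid_mask = target_probs >= 0.5  # assumed validity rule, for illustration only
log_probs = np.array([-1.0, -2.0, -1.5, -0.5, -3.0, -1.0])

# The same summaries explain() writes into `logs`.
logs = {
    "sampling/mean_target_probability": float(target_probs.mean()),
    "sampling/valid_ratio": float(valid_mask.mean()),
    "sampling/log_prob_mean": float(log_probs.mean()),
}

# Regroup: keep only the valid counterfactuals of each original sample.
valid_cfs_per_sample = {
    int(g): x_cfs[(group_ids == g) & valid_mask] for g in np.unique(group_ids)
}
```

Here four of the six candidates are valid, so the valid ratio is 4/6. Each original sample keeps two counterfactuals.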
explain_dataloader
explain_dataloader(dataloader, epochs, lr, patience_eps=1e-05, **kwargs)
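Adapter around explain() for DataLoader inputs. It proceeds in four steps:

- collects every (batch_x, batch_y) pair from the loader into one feature array and one label array
- sets every target label to params.target_class
- forwards the result to explain()
- converts each batch with .numpy(), so the loader must yield CPU tensors that do not require gradients

In the source shown, epochs, lr, patience_eps, and extra keyword arguments are accepted but not used.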
Source code in counterfactuals/cf_methods/local_methods/DiCoFlex/method.py
```python
def explain_dataloader(
    self,
    dataloader,
    epochs: int,
    lr: float,
    patience_eps: float = 1e-5,
    **kwargs,
) -> ExplanationResult:
    """Adapter around explain() for DataLoader inputs."""
    xs, ys = [], []
    for batch_x, batch_y in dataloader:
        xs.append(batch_x.numpy())
        ys.append(batch_y.numpy())
    X = np.vstack(xs)
    y_origin = np.concatenate(ys)
    y_target = np.full_like(y_origin, fill_value=self.params.target_class)
    return self.explain(
        X=X,
        y_origin=y_origin,
        y_target=y_target,
    )
```
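The adapter's data handling can be checked without PyTorch. The sketch below replays its three NumPy steps: stack the features, concatenate the labels, and fill every target with a single class. It runs on a hypothetical list of `(batch_x, batch_y)` NumPy pairs standing in for a `DataLoader`, so the `.numpy()` calls are dropped. `collect_batches` is a hypothetical name, and its `target_class` argument plays the role of `params.target_class`.

```python
import numpy as np


def collect_batches(batches, target_class):
    """Mirror explain_dataloader's stacking on (batch_x, batch_y) NumPy pairs."""
    xs, ys = [], []
    for batch_x, batch_y in batches:
        xs.append(batch_x)
        ys.append(batch_y)
    X = np.vstack(xs)                # (total_rows, n_features)
    y_origin = np.concatenate(ys)    # (total_rows,)
    # Every sample gets the same target; dtype is copied from y_origin.
    y_target = np.full_like(y_origin, fill_value=target_class)
    return X, y_origin, y_target
```

`np.full_like` inherits the dtype of `y_origin`, so float-typed labels from a loader produce float targets as well. Because every sample receives the same target, this adapter suits a single `target_class` per run.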