Generating Counterfactuals
This page covers the core workflow for generating counterfactual explanations.
Multiple Methods Available
CEL provides 17+ counterfactual methods. The example below demonstrates PPCEF,
but the same explain() interface works for all local methods. See Local Methods
for a complete list and comparison.
Basic Usage
This example uses PPCEF (Probabilistically Plausible Counterfactual Explanations with Flows):
import torch

from counterfactuals.cf_methods.local_methods import PPCEF

# Initialize method
# (flow and classifier are assumed to be already-trained models)
method = PPCEF(
    gen_model=flow,
    disc_model=classifier,
    disc_model_criterion=torch.nn.CrossEntropyLoss(),
    device="cuda",
)

# Generate counterfactual
result = method.explain(
    X=instance,        # Instance to explain
    y_origin=0,        # Current prediction
    y_target=1,        # Desired prediction
    X_train=X_train,   # Training data (for some methods)
    y_train=y_train,
)
Using Different Methods
All local methods share the same explain() interface. Change the import and the constructor arguments, which differ between methods:
# Option 1: PPCEF (used in example above)
from counterfactuals.cf_methods.local_methods import PPCEF
method = PPCEF(gen_model=flow, disc_model=classifier, ...)
# Option 2: DICE (diverse counterfactuals)
from counterfactuals.cf_methods.local_methods import DICE
method = DICE(model=classifier, ...)
# Option 3: WACH (gradient-based)
from counterfactuals.cf_methods.local_methods import WACH
method = WACH(disc_model=classifier, ...)
# Option 4: CEM (contrastive)
from counterfactuals.cf_methods.local_methods import CEM
method = CEM(model=classifier, ...)
# All methods use the same explain() interface:
result = method.explain(X=instance, y_origin=0, y_target=1, ...)
Understanding ExplanationResult
from counterfactuals.cf_methods import ExplanationResult
# Result structure
result.x_cfs # Generated counterfactuals
result.y_cf_targets # Target labels
result.x_origs # Original instances
result.y_origs # Original labels
result.logs # Training logs (optional)
result.cf_group_ids # Group assignments (for group methods)
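As a rough illustration of how these fields line up, the sketch below builds a stand-in object with the same attribute names and computes simple per-instance change statistics. The `SimpleNamespace` stand-in and the toy arrays are assumptions made for illustration; the real `ExplanationResult` may store tensors or use different shape conventions.

```python
from types import SimpleNamespace

import numpy as np

# Stand-in for an ExplanationResult: only the attribute names come from the
# listing above; the array shapes and values here are assumptions.
result = SimpleNamespace(
    x_origs=np.array([[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]),
    x_cfs=np.array([[1.0, 2.5, 3.0], [0.0, 0.5, 1.5]]),
    y_origs=np.array([0, 0]),
    y_cf_targets=np.array([1, 1]),
)

# Row i of x_cfs is assumed to be the counterfactual for row i of x_origs.
delta = result.x_cfs - result.x_origs

# L1 distance: how far each counterfactual moved from its original.
l1_distance = np.abs(delta).sum(axis=1)

# Sparsity: how many features changed for each instance.
n_changed = (np.abs(delta) > 1e-8).sum(axis=1)

for i, (dist, changed) in enumerate(zip(l1_distance, n_changed)):
    print(f"instance {i}: L1 distance = {dist:.2f}, features changed = {changed}")
```

Small distances and few changed features usually indicate counterfactuals that are easier to act on, which is why these quantities are a common first check.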
Batch Processing
from counterfactuals.datasets import TorchDataLoader

# Create dataloader
loader = TorchDataLoader(X_test, y_test, batch_size=32)

# Generate for multiple instances
result = method.explain_dataloader(
    dataloader=loader,
    epochs=100,
    lr=0.01,
)
Common Parameters
| Parameter | Description |
|---|---|
| epochs | Number of optimization iterations |
| lr | Learning rate |
| alpha | Validity loss weight |
| beta | Proximity loss weight |
| K | Number of counterfactuals per instance |
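
To make the roles of `epochs`, `lr`, `alpha`, and `beta` concrete, here is a minimal sketch of a generic gradient-based counterfactual search on a toy logistic model. It is not CEL's implementation: the model weights, the squared-distance proximity term, and the `find_counterfactual` helper are all assumptions chosen for illustration.

```python
import math

# Toy logistic classifier p(y=1 | x) = sigmoid(w . x + b); weights are made up.
W = [2.0, -1.0]
B = -1.0


def predict_proba(x):
    """Probability of class 1 under the toy model."""
    z = sum(w_i * x_i for w_i, x_i in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-z))


def find_counterfactual(x_orig, alpha=1.0, beta=0.1, lr=0.1, epochs=200):
    """Minimise alpha * validity + beta * proximity by gradient descent.

    validity  = -log p(y=1 | x_cf)     (pushes towards target class 1)
    proximity = ||x_cf - x_orig||^2    (keeps x_cf close to the original)
    """
    x_cf = list(x_orig)
    for _ in range(epochs):
        p = predict_proba(x_cf)  # evaluated once per step, before any update
        for j in range(len(x_cf)):
            grad_validity = -(1.0 - p) * W[j]  # d/dx_j of -log sigmoid(w.x + b)
            grad_proximity = 2.0 * (x_cf[j] - x_orig[j])
            x_cf[j] -= lr * (alpha * grad_validity + beta * grad_proximity)
    return x_cf


x_orig = [0.0, 1.0]
x_cf = find_counterfactual(x_orig)
print("p(target) before:", predict_proba(x_orig))
print("p(target) after: ", predict_proba(x_cf))
```

In this toy setup, raising `beta` relative to `alpha` keeps the result closer to the original instance but can leave it short of the decision boundary, which is the trade-off the two weights control.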
Next Steps
- Evaluating Results - Assess counterfactual quality