DICE
Diverse Counterfactual Explanations
DICE generates diverse counterfactual explanations for a trained classifier and is integrated via the dice-ml library.
Overview
DICE generates multiple counterfactuals for a query instance by jointly optimizing for validity (each counterfactual reaches the target class) and diversity (the counterfactuals differ from one another).
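To make the two objectives concrete, the sketch below scores a set of candidate counterfactuals by validity (the fraction predicted as the target class) and diversity (the mean pairwise L1 distance). It is illustrative only: `validity`, `diversity`, and the toy `predict` function are hypothetical names that are not part of this package or of dice-ml, and the actual optimizer may use different loss terms.

```python
from itertools import combinations


def validity(cfs, predict, target):
    """Fraction of counterfactuals that the classifier assigns to `target`."""
    return sum(predict(cf) == target for cf in cfs) / len(cfs)


def diversity(cfs):
    """Mean pairwise L1 distance between counterfactuals (0.0 for fewer than two)."""
    pairs = list(combinations(cfs, 2))
    if not pairs:
        return 0.0
    total = sum(sum(abs(a - b) for a, b in zip(u, v)) for u, v in pairs)
    return total / len(pairs)


# Toy classifier (assumption): class 1 when the feature sum exceeds 1.0.
def predict(x):
    return int(sum(x) > 1.0)


# Three hypothetical counterfactual candidates with two features each.
cfs = [(0.9, 0.4), (0.2, 1.1), (0.5, 0.3)]

print("validity:", validity(cfs, predict, target=1))
print("diversity:", diversity(cfs))
```

A set of near-identical counterfactuals can score perfectly on validity while scoring close to zero on diversity, which is why the two terms are optimized together rather than one after the other.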
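Usage
from counterfactuals.cf_methods.local_methods import DICE

# gen_model, classifier, and criterion are assumed to be created beforehand.
method = DICE(
    gen_model=gen_model,
    disc_model=classifier,
    disc_model_criterion=criterion,
    device="cuda",
)

# Explain a single instance: move it from class 0 to class 1.
result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train,
)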
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| disc_model | BaseClassifier | required | Trained classifier |
| n_counterfactuals | int | 5 | Number of counterfactuals to generate |
References
- Mothilal et al., "Explaining Machine Learning Classifiers through Diverse Counterfactual Explanations"
API Reference
DICE
Bases: BaseCounterfactualMethod, LocalCounterfactualMixin
An interface class to different DiCE implementations.
Source code in counterfactuals/cf_methods/local_methods/dice/dice.py
explain
explain(Xs, ys, total_CFs=1, desired_class='opposite', desired_range=None, permitted_range=None, features_to_vary='all', stopping_threshold=0.5, posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm='linear', verbose=False, **kwargs)
General method for generating counterfactuals.
:param query_instances: Input point(s) for which counterfactuals are to be generated. This can be a dataframe with one or more rows.
:param total_CFs: Total number of counterfactuals required.
:param desired_class: Desired counterfactual class; can take 0 or 1. The default, "opposite", selects the class opposite to the outcome class of query_instance for binary classification.
:param desired_range: For regression problems. The outcome range in which to generate counterfactuals, given as a list of two numbers in ascending order.
:param permitted_range: Dictionary with feature names as keys and permitted ranges (as lists) as values. Defaults to the range inferred from training data. If None, uses the parameters initialized in data_interface.
:param features_to_vary: Either the string "all" or a list of feature names to vary.
:param stopping_threshold: Minimum threshold for the counterfactual's target-class probability.
:param proximity_weight: A positive float. The larger this weight, the closer the counterfactuals are to the query_instance. Used by the ['genetic', 'gradientdescent'] methods; ignored by ['random', 'kdtree'].
:param sparsity_weight: A positive float. The larger this weight, the fewer features are changed from the query_instance. Used by the ['genetic', 'kdtree'] methods; ignored by ['random', 'gradientdescent'].
:param diversity_weight: A positive float. The larger this weight, the more diverse the counterfactuals are. Used by the ['genetic', 'gradientdescent'] methods; ignored by ['random', 'kdtree'].
:param categorical_penalty: A positive float. A weight to ensure that all levels of a categorical variable sum to 1. Used by the ['genetic', 'gradientdescent'] methods; ignored by ['random', 'kdtree'].
:param posthoc_sparsity_param: Parameter for the post-hoc operation on continuous features to enhance sparsity.
:param posthoc_sparsity_algorithm: Whether to perform linear or binary search. Takes "linear" or "binary". Prefer binary search when a feature range is large (for instance, income varying from 10k to 1000k) and only if the features share a monotonic relationship with the predicted outcome in the model.
:param verbose: Whether to output detailed messages.
:param sample_size: Sampling size.
:param random_seed: Random seed for reproducibility.
:param kwargs: Other parameters accepted by the specific explanation method.
:returns: A CounterfactualExplanations object that contains the list of counterfactual examples per query_instance as one of its attributes.
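The post-hoc sparsity step described for posthoc_sparsity_algorithm can be pictured as walking each changed feature back toward its original value for as long as the counterfactual stays valid. The following is a simplified sketch of the "linear" variant under stated assumptions: `sparsify_linear`, its `step` argument, and the toy `predict_proba` function are hypothetical and do not reproduce the dice-ml implementation.

```python
def sparsify_linear(query, cf, predict_proba, threshold=0.5, step=0.25):
    """Move each feature of `cf` back toward `query` in fixed steps.

    A step is kept only while the target-class probability stays at or
    above `threshold` (compare stopping_threshold above).
    """
    cf = list(cf)
    for i, original in enumerate(query):
        while cf[i] != original:
            diff = original - cf[i]
            if abs(diff) <= step:
                new_value = original  # final step lands exactly on the original
            else:
                new_value = cf[i] + (step if diff > 0 else -step)
            candidate = cf[:i] + [new_value] + cf[i + 1:]
            if predict_proba(candidate) < threshold:
                break  # one more step would invalidate the counterfactual
            cf = candidate
    return cf


# Toy model (assumption): target-class probability depends only on feature 0.
def predict_proba(x):
    return min(1.0, max(0.0, x[0]))


query = [0.0, 3.0]
counterfactual = [1.0, 5.0]
sparse_cf = sparsify_linear(query, counterfactual, predict_proba)
print(sparse_cf)
```

In this toy setup the second feature has no effect on the prediction, so it is restored to its original value, while the first feature stops at the smallest change that keeps the prediction valid. A binary-search variant would halve the remaining distance instead of taking fixed steps, which is only safe under the monotonicity condition stated above.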
Source code in counterfactuals/cf_methods/local_methods/dice/dice.py
explain_dataloader
explain_dataloader(dataloader, total_CFs=1, desired_class='opposite', desired_range=None, permitted_range=None, features_to_vary='all', stopping_threshold=0.5, posthoc_sparsity_param=0.1, posthoc_sparsity_algorithm='linear', verbose=False, **kwargs)
General method for generating counterfactuals.
:param query_instances: Input point(s) for which counterfactuals are to be generated. This can be a dataframe with one or more rows.
:param total_CFs: Total number of counterfactuals required.
:param desired_class: Desired counterfactual class; can take 0 or 1. The default, "opposite", selects the class opposite to the outcome class of query_instance for binary classification.
:param desired_range: For regression problems. The outcome range in which to generate counterfactuals, given as a list of two numbers in ascending order.
:param permitted_range: Dictionary with feature names as keys and permitted ranges (as lists) as values. Defaults to the range inferred from training data. If None, uses the parameters initialized in data_interface.
:param features_to_vary: Either the string "all" or a list of feature names to vary.
:param stopping_threshold: Minimum threshold for the counterfactual's target-class probability.
:param proximity_weight: A positive float. The larger this weight, the closer the counterfactuals are to the query_instance. Used by the ['genetic', 'gradientdescent'] methods; ignored by ['random', 'kdtree'].
:param sparsity_weight: A positive float. The larger this weight, the fewer features are changed from the query_instance. Used by the ['genetic', 'kdtree'] methods; ignored by ['random', 'gradientdescent'].
:param diversity_weight: A positive float. The larger this weight, the more diverse the counterfactuals are. Used by the ['genetic', 'gradientdescent'] methods; ignored by ['random', 'kdtree'].
:param categorical_penalty: A positive float. A weight to ensure that all levels of a categorical variable sum to 1. Used by the ['genetic', 'gradientdescent'] methods; ignored by ['random', 'kdtree'].
:param posthoc_sparsity_param: Parameter for the post-hoc operation on continuous features to enhance sparsity.
:param posthoc_sparsity_algorithm: Whether to perform linear or binary search. Takes "linear" or "binary". Prefer binary search when a feature range is large (for instance, income varying from 10k to 1000k) and only if the features share a monotonic relationship with the predicted outcome in the model.
:param verbose: Whether to output detailed messages.
:param sample_size: Sampling size.
:param random_seed: Random seed for reproducibility.
:param kwargs: Other parameters accepted by the specific explanation method.
:returns: A CounterfactualExplanations object that contains the list of counterfactual examples per query_instance as one of its attributes.
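The features_to_vary and permitted_range parameters restrict which counterfactuals are acceptable. As a rough illustration of the intended semantics, the sketch below checks a candidate against both constraints. The helper `respects_constraints` and the dict-based feature representation are assumptions made for this example; they are not part of this package or of dice-ml.

```python
def respects_constraints(query, cf, features_to_vary="all", permitted_range=None):
    """Return True if `cf` changes only allowed features, within allowed ranges.

    `query` and `cf` map feature name -> numeric value. A feature that keeps
    its original value is always accepted.
    """
    permitted_range = permitted_range or {}
    for name, original in query.items():
        value = cf[name]
        if value == original:
            continue  # unchanged features cannot violate a constraint
        if features_to_vary != "all" and name not in features_to_vary:
            return False  # this feature was not allowed to change
        if name in permitted_range:
            low, high = permitted_range[name]
            if not (low <= value <= high):
                return False  # changed value falls outside the permitted range
    return True


query = {"age": 30, "income": 40}
allowed = {"income": [30, 60]}

ok = respects_constraints(
    query, {"age": 30, "income": 55}, ["income"], allowed
)
changed_frozen = respects_constraints(
    query, {"age": 35, "income": 55}, ["income"], allowed
)
out_of_range = respects_constraints(
    query, {"age": 30, "income": 80}, ["income"], allowed
)
print(ok, changed_frozen, out_of_range)
```

A check of this kind can be useful as a sanity test on returned counterfactuals, independent of which generation method produced them.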