Skip to content

CET

Counterfactual Explanation Trees

CET uses tree structures to generate interpretable counterfactual explanations.

Overview

CET builds decision trees that guide the counterfactual generation process.

Usage

from counterfactuals.cf_methods.local_methods.cet import CounterfactualExplanationTree

method = CounterfactualExplanationTree(
    mdl=classifier,
    X=X_train,
    Y=y_train,
    device="cuda"
)

result = method.explain(
    X=instance,
    y_origin=0,
    y_target=1,
    X_train=X_train,
    y_train=y_train
)

API Reference

CounterfactualExplanationTree

CounterfactualExplanationTree(mdl, X, Y=[], max_iteration=1000, max_depth=3, min_samples_leaf=1, remain_redundant_leaf=False, max_candidates=50, tol=1e-06, use_mined_rules=False, minsup=0.5, discretization_bins=5, lime_approximation=False, n_samples=10000, alpha=1.0, feature_names=[], feature_types=[], feature_categories=[], feature_constraints=[], target_name='Output', target_labels=['Good', 'Bad'], device=None, **kwargs)

Bases: BaseCounterfactualMethod, LocalCounterfactualMixin

Source code in counterfactuals/cf_methods/local_methods/cet/cet.py
def __init__(
    self,
    mdl,
    X,
    Y=[],
    max_iteration=1000,
    max_depth=3,
    min_samples_leaf=1,
    remain_redundant_leaf=False,
    max_candidates=50,
    tol=1e-6,
    use_mined_rules=False,
    minsup=0.5,
    discretization_bins=5,
    lime_approximation=False,
    n_samples=10000,
    alpha=1.0,
    feature_names=[],
    feature_types=[],
    feature_categories=[],
    feature_constraints=[],
    target_name="Output",
    target_labels=["Good", "Bad"],
    device: str | None = None,
    **kwargs,
):
    """Initialize the counterfactual-explanation-tree explainer.

    Wires together the action extractor, the cost model, an optional LIME
    estimator, and a rule discretizer around the model ``mdl`` and the
    reference data ``X``/``Y``.

    Parameters
    ----------
    mdl : object
        Model to be explained; stored as ``self.mdl_`` and forwarded to the
        action extractor (and the LIME estimator when enabled).
    X : array-like of shape (n_samples, n_features)
        Reference data used to fit the extractor, cost model and discretizer.
    Y : array-like, default=[]
        Labels for ``X``; forwarded to the extractor and cost model.
    max_iteration : int, default=1000
        Stored as ``self.max_iteration_`` (iteration budget consumed by
        other methods of this class).
    max_depth, min_samples_leaf, remain_redundant_leaf
        Tree-growth settings; stored on the instance for later use.
    max_candidates : int, default=50
        Forwarded to the action extractor and cost model.
    tol : float, default=1e-6
        Numerical tolerance forwarded to the extractor and cost model.
    use_mined_rules : bool, default=False
        If True, discretize via frequent-rule mining (``FrequentRuleMiner``);
        otherwise use per-feature binning (``FeatureDiscretizer``).
    minsup : float, default=0.5
        Minimum support for the rule miner (only used when mining rules).
    discretization_bins : int, default=5
        Number of bins used by either discretizer.
    lime_approximation : bool, default=False
        If True, build a ``LimeEstimator`` surrogate over ``mdl``.
    n_samples : int, default=10000
        Sample count forwarded to the extractor / LIME estimator.
    alpha : float, default=1.0
        Forwarded to the extractor / LIME estimator (exact semantics are
        defined by those components).
    feature_names, feature_types, feature_categories, feature_constraints
        Per-feature metadata. Lists whose length does not match the number
        of columns in ``X`` are replaced with generic defaults below.
    target_name : str, default="Output"
        Display name of the target variable.
    target_labels : list, default=["Good", "Bad"]
        Display labels for the target classes.
    device : str or None, default=None
        Forwarded to the base-class initializer.
    **kwargs
        Ignored here; accepted for interface compatibility.

    NOTE(review): the list-typed defaults (``Y=[]``, ``feature_names=[]``,
    ...) are shared mutable default arguments — safe only as long as no
    callee mutates them in place; confirm downstream components treat them
    as read-only.
    """
    # Initialize base/mixin behavior (forwards the model and device).
    super().__init__(disc_model=mdl, device=device)
    self.mdl_ = mdl
    # Component that proposes candidate actions (feature changes) for mdl.
    self.extractor_ = ActionExtractor(
        mdl,
        X,
        Y=Y,
        feature_names=feature_names,
        feature_types=feature_types,
        feature_categories=feature_categories,
        feature_constraints=feature_constraints,
        max_candidates=max_candidates,
        tol=tol,
        target_name=target_name,
        target_labels=target_labels,
        lime_approximation=lime_approximation,
        n_samples=n_samples,
        alpha=alpha,
    )
    # Cost model used to score actions (consumed by get_counterfactuals).
    self.cost_ = Cost(
        X,
        Y,
        feature_types=feature_types,
        feature_categories=feature_categories,
        feature_constraints=feature_constraints,
        max_candidates=max_candidates,
        tol=tol,
    )

    self.lime_approximation_ = lime_approximation
    if lime_approximation:
        # Local linear surrogate of mdl, fit on n_samples perturbations.
        self.lime_ = LimeEstimator(
            mdl,
            X,
            n_samples=n_samples,
            feature_types=feature_types,
            feature_categories=feature_categories,
            alpha=alpha,
        )

    self.max_iteration_ = max_iteration
    self.max_depth_ = max_depth
    self.min_samples_leaf_ = min_samples_leaf
    self.remain_redundant_leaf_ = remain_redundant_leaf
    # Fall back to generic metadata whenever the supplied lists do not
    # match the number of features in X.
    self.feature_names_ = (
        feature_names
        if len(feature_names) == X.shape[1]
        else ["x_{}".format(d) for d in range(X.shape[1])]
    )
    # "C" presumably marks continuous features ("B" = binary below) — TODO confirm.
    self.feature_types_ = (
        feature_types if len(feature_types) == X.shape[1] else ["C" for d in range(X.shape[1])]
    )
    self.feature_categories_ = feature_categories
    self.feature_categories_flatten_ = flatten(feature_categories)
    self.feature_constraints_ = (
        feature_constraints
        if len(feature_constraints) == X.shape[1]
        else ["" for d in range(X.shape[1])]
    )
    self.target_name_ = target_name
    self.target_labels_ = target_labels
    self.tol_ = tol
    self.infeasible_ = False
    # Inverse lookup: for each binary ("B") feature d, the index of the
    # category group in feature_categories_ containing d, or -1 when the
    # feature is not binary or belongs to no group.
    self.feature_categories_inv_ = []
    for d in range(X.shape[1]):
        g = -1
        if self.feature_types_[d] == "B":
            for i, cat in enumerate(self.feature_categories_):
                if d in cat:
                    g = i
                    break
        self.feature_categories_inv_.append(g)

    if use_mined_rules:
        # Discretize via frequent rule mining; rules can span several features.
        self.discretizer_ = FrequentRuleMiner(minsup=minsup, discretization=True)
        self.discretizer_ = self.discretizer_.fit(
            X,
            feature_names=feature_names,
            feature_types=feature_types,
            discretization_bins=discretization_bins,
        )
        self.rule_names_ = self.discretizer_.rule_names_
        self.R_ = len(self.rule_names_)
        self.rule_length_ = self.discretizer_.L_
    else:
        # Simple per-feature binning; every "rule" has length 1.
        self.discretizer_ = FeatureDiscretizer(bins=discretization_bins, onehot=False)
        self.discretizer_ = self.discretizer_.fit(
            X, feature_names=feature_names, feature_types=feature_types
        )
        self.rule_names_ = self.discretizer_.feature_names
        self.R_ = len(self.rule_names_)
        self.rule_length_ = np.ones(self.R_)
    # Sampling distribution over rules: inversely proportional to rule
    # length, normalized to sum to 1.
    self.rule_probability_ = (1 / self.rule_length_) / (1 / self.rule_length_).sum()

explain

explain(X, y_origin=None, y_target=None, X_train=None, **kwargs)

Wrapper to produce ExplanationResult for compatibility.

This uses the existing get_counterfactuals method and returns an ExplanationResult dataclass.

Source code in counterfactuals/cf_methods/local_methods/cet/cet.py
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray | None = None,
    y_target: np.ndarray | None = None,
    X_train: np.ndarray | None = None,
    **kwargs,
) -> ExplanationResult:
    """Produce an ExplanationResult for interface compatibility.

    Delegates to the existing get_counterfactuals method and packages the
    inputs and outputs into an ExplanationResult dataclass. When X_train
    is given, the explainer is (re)fitted first, preserving the older
    fit-then-explain workflow.
    """
    # Optional refit keeps backward compatibility with callers that pass
    # training data directly to explain().
    if X_train is not None:
        self.fit(X_train)

    counterfactuals = self.get_counterfactuals(X)

    # Missing origin/target labels default to all-zero integer arrays,
    # one entry per explained instance.
    n_instances = X.shape[0]
    if y_target is None:
        targets = np.zeros((n_instances,), dtype=int)
    else:
        targets = np.array(y_target)
    if y_origin is None:
        origins = np.zeros((n_instances,), dtype=int)
    else:
        origins = np.array(y_origin)

    return ExplanationResult(
        x_cfs=np.array(counterfactuals),
        y_cf_targets=targets,
        x_origs=np.array(X),
        y_origs=origins,
        logs=None,
    )

get_counterfactuals

get_counterfactuals(X, return_costs=False)

Get counterfactual examples for input instances.

Parameters

X : array-like of shape (n_samples, n_features) — input instances to generate counterfactuals for.
return_costs : bool, default=False — whether to return the costs associated with each counterfactual.

Returns

counterfactuals : array-like of shape (n_samples, n_features) — counterfactual examples.
costs : array-like of shape (n_samples,), optional — costs associated with each counterfactual, returned only if return_costs=True.

Source code in counterfactuals/cf_methods/local_methods/cet/cet.py
def get_counterfactuals(self, X, return_costs=False):
    """
    Generate counterfactual examples for the given instances.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Instances to generate counterfactuals for.
    return_costs : bool, default=False
        If True, also return the per-instance action costs.

    Returns
    -------
    counterfactuals : array-like of shape (n_samples, n_features)
        The counterfactual examples (``X`` shifted by the predicted actions).
    costs : ndarray of shape (n_samples,), optional
        Cost of each counterfactual; only returned when ``return_costs=True``.
    """
    suggested_actions = self.predict(X)
    cf_examples = X + suggested_actions

    if not return_costs:
        return cf_examples

    # NOTE(review): self.cost_type_ is not assigned in the __init__ shown
    # in this file — confirm it is set elsewhere before relying on the
    # return_costs=True path.
    per_instance_costs = np.array(
        [
            self.cost_.compute(instance, action, cost_type=self.cost_type_)
            for instance, action in zip(X, suggested_actions)
        ]
    )
    return cf_examples, per_instance_costs