ReViCE
Regional Variant of PPCEF (Group PPCEF)
ReViCE generates counterfactuals for groups of similar instances.
Overview
ReViCE extends PPCEF to handle groups, finding counterfactual transformations that work for clusters of similar instances.
Usage
from counterfactuals.cf_methods.group_methods.group_ppcef import RPPCEF
method = RPPCEF(
gen_model=gen_model,
disc_model=classifier,
disc_model_criterion=criterion,
device="cuda"
)
result = method.explain(
X=X_test,
y_origin=y_test,
y_target=target_class,
X_train=X_train,
y_train=y_train,
n_groups=5
)
API Reference
RPPCEF
RPPCEF(cf_method_type, gen_model, disc_model, disc_model_criterion, init_cf_method_from_kmeans=False, K=None, X=None, device=None, actionable_features=None, **kwargs)
Bases: BaseCounterfactualMethod, GroupCounterfactualMixin
Source code in counterfactuals/cf_methods/group_methods/group_ppcef/rppcef.py
def __init__(
    self,
    cf_method_type: str,
    gen_model: GenerativePytorchMixin,
    disc_model: PytorchBase,
    disc_model_criterion: torch.nn.modules.loss._Loss,
    init_cf_method_from_kmeans: bool = False,
    K: int = None,
    X: np.ndarray = None,
    device: str = None,
    # TODO: improve this parameter's name
    actionable_features: list = None,
    **kwargs,
):
    """Set up the group (regional) PPCEF counterfactual method.

    Args:
        cf_method_type: Identifier of the concrete delta parameterization;
            forwarded verbatim to ``self._init_cf_method``.
        gen_model: Trained generative (flow) model; held by the base class.
        disc_model: Trained discriminative model being explained.
        disc_model_criterion: Loss applied to the discriminator's output.
        init_cf_method_from_kmeans: If True, the delta module is initialized
            from ``X`` (presumably via K-means — confirm in ``_init_cf_method``).
        K: Number of groups/clusters used by the delta method
            (assumed — TODO confirm against ``_init_cf_method``).
        X: Data used for the K-means-style initialization, if enabled.
        device: Torch device string (e.g. ``"cuda"``), forwarded to the base.
        actionable_features: Features allowed to change; stored as-is and
            not otherwise used in this constructor.
        **kwargs: Forwarded to the base class initializer.
    """
    # Initialize mixins / base to set up models and device
    super().__init__(
        gen_model=gen_model,
        disc_model=disc_model,
        disc_model_criterion=disc_model_criterion,
        device=device,
        **kwargs,
    )
    self.actionable_features = actionable_features
    # initialize delta after we know device/other attributes
    self.delta = self._init_cf_method(cf_method_type, K, init_cf_method_from_kmeans, X)
    # Per-component loss history; (re)populated by explain_dataloader.
    self.loss_components_logs = {}
explain
explain(X, y_origin, y_target, X_train, y_train)
Explains the model's prediction for a given input. Not implemented for this class — calling it always raises `NotImplementedError`; use `explain_dataloader` instead.
Source code in counterfactuals/cf_methods/group_methods/group_ppcef/rppcef.py
def explain(
    self,
    X: np.ndarray,
    y_origin: np.ndarray,
    y_target: np.ndarray,
    X_train: np.ndarray,
    y_train: np.ndarray,
):
    """Single-call explanation is not supported by this group method.

    RPPCEF only operates batch-wise through ``explain_dataloader``; this
    entry point exists to satisfy the base interface and always fails.

    Raises:
        NotImplementedError: unconditionally, on every call.
    """
    raise NotImplementedError("This method is not implemented for this class.")
explain_dataloader
explain_dataloader(dataloader, alpha, alpha_s, alpha_k, log_prob_threshold, epochs=1000, lr=0.0005, patience=100, patience_eps=0.001)
Optimizes the counterfactual perturbation (`self.delta`) over the dataloader's batches for up to `epochs` steps each; the generative and discriminative models themselves are frozen during this search.
Source code in counterfactuals/cf_methods/group_methods/group_ppcef/rppcef.py
def explain_dataloader(
    self,
    dataloader: DataLoader,
    alpha: int,
    alpha_s: int,
    alpha_k: int,
    log_prob_threshold: float,
    epochs: int = 1000,
    lr: float = 0.0005,
    patience: int = 100,
    # annotation fixed: the default 1e-3 is a float, not an int
    patience_eps: float = 1e-3,
):
    """Optimize the group counterfactual perturbation over the dataloader.

    The generative and discriminative models are frozen (eval mode,
    ``requires_grad=False``); only the parameters of ``self.delta`` are
    trained with Adam, for up to ``epochs`` steps per batch, with a
    two-stage early-stopping scheme (see notes inline).

    Args:
        dataloader: Yields ``(xs_origin, contexts_origin)`` batches.
        alpha: Weight for a loss term; forwarded to ``_search_step``.
        alpha_s: Weight for a loss term; forwarded to ``_search_step``.
        alpha_k: Weight for a loss term; forwarded to ``_search_step``.
        log_prob_threshold: Plausibility threshold forwarded to ``_search_step``.
        epochs: Maximum optimization steps per batch.
        lr: Adam learning rate.
        patience: Steps without sufficient improvement before (staged) stopping.
        patience_eps: Minimum loss decrease counted as an improvement.

    Returns:
        Tuple ``(delta, x_origs, y_origs, y_target)``: the trained delta
        module plus the concatenated original inputs, original labels and
        target labels as numpy arrays. Counterfactuals themselves are not
        materialized here (see the commented-out line at the end).
    """
    self.loss_components_logs = {}
    # Freeze both models: only self.delta is optimized below.
    self.gen_model.eval()
    for param in self.gen_model.parameters():
        param.requires_grad = False
    if self.disc_model:
        self.disc_model.eval()
        for param in self.disc_model.parameters():
            param.requires_grad = False
    target_class = []
    original = []
    original_class = []
    for xs_origin, contexts_origin in dataloader:
        xs_origin = xs_origin.to(self.device)
        contexts_origin = contexts_origin.to(self.device)
        contexts_origin = contexts_origin.reshape(-1, 1)
        # NOTE(review): 1 - label assumes binary {0, 1} labels; the target
        # class is simply the opposite label — confirm for multi-class use.
        contexts_target = torch.abs(1 - contexts_origin)
        xs_origin = torch.as_tensor(xs_origin)
        xs_origin.requires_grad = False
        # A fresh optimizer per batch: self.delta's Adam state is reset
        # each time a new batch is processed.
        optimizer = torch.optim.Adam(self.delta.parameters(), lr=lr)
        min_loss = float("inf")
        dist_flag = False
        for epoch in (epoch_pbar := tqdm(range(epochs), dynamic_ncols=True)):
            optimizer.zero_grad()
            loss_components = self._search_step(
                self.delta,
                xs_origin,
                contexts_origin,
                contexts_target,
                alpha=alpha,
                alpha_s=alpha_s,
                alpha_k=alpha_k,
                log_prob_threshold=log_prob_threshold,
            )
            mean_loss = loss_components["loss"].mean()
            mean_loss.backward()
            optimizer.step()
            self._log_loss_components(loss_components)
            loss = loss_components["loss"].detach().cpu().mean().item()
            # Progress bar description
            epoch_pbar.set_description(
                ", ".join(
                    [
                        f"{k}: {v.detach().cpu().mean().item():.4f}"
                        for k, v in loss_components.items()
                    ]
                )
            )
            # Early stopping handling
            # NOTE(review): "epoch < 1000" forces the improvement branch for
            # the first 1000 steps, which equals the default ``epochs`` —
            # early stopping only ever triggers when epochs > 1000. Confirm
            # this hard-coded warmup is intended (vs. e.g. ``epochs // 2``).
            if (loss < (min_loss - patience_eps)) or (epoch < 1000):
                min_loss = loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter > patience:
                    # Two-stage stop: the first time patience runs out the
                    # counter is reset (one extra grace period); the second
                    # time the batch's optimization loop terminates.
                    if not dist_flag:
                        patience_counter = 0
                        dist_flag = True
                    else:
                        break
        original.append(xs_origin.detach().cpu().numpy())
        original_class.append(contexts_origin.detach().cpu().numpy())
        target_class.append(contexts_target.detach().cpu().numpy())
    x_origs = np.concatenate(original, axis=0)
    y_origs = np.concatenate(original_class, axis=0)
    y_target = np.concatenate(target_class, axis=0)
    # x_cfs = x_origs + self.delta().detach().numpy()
    return self.delta, x_origs, y_origs, y_target