Training Models¶
Learn how to train discriminative and generative models for counterfactual generation.
Discriminative Models (Classifiers)¶
MLP Classifier¶
from counterfactuals.models.classifiers import MLPClassifier
classifier = MLPClassifier(
    input_dim=n_features,
    hidden_dims=[128, 64],
    output_dim=n_classes
)
classifier.fit(
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=100,
    lr=0.001
)
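The snippets on this page use `train_loader`, `test_loader`, `n_features`, and `n_classes` without defining them. The sketch below shows one way to build them, assuming the models accept standard PyTorch `DataLoader` objects that yield `(features, labels)` batches. That is an assumption, not something this page confirms. `make_loaders` is a hypothetical helper, not part of the library, and the data is synthetic.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loaders(X, y, batch_size=64, test_fraction=0.2, seed=0):
    """Hypothetical helper: split (X, y) into train and test DataLoaders."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(X.shape[0], generator=generator)
    n_test = int(X.shape[0] * test_fraction)
    test_idx, train_idx = perm[:n_test], perm[n_test:]

    train_ds = TensorDataset(X[train_idx], y[train_idx])
    test_ds = TensorDataset(X[test_idx], y[test_idx])

    # Shuffle only the training data; keep the test order fixed.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


# Synthetic stand-in data: 500 samples, 4 features, 2 classes.
X = torch.randn(500, 4)
y = (X[:, 0] > 0).long()

train_loader, test_loader = make_loaders(X, y)
n_features = X.shape[1]
n_classes = int(y.max().item()) + 1
```

The dtypes the models expect (for example `float32` features and integer class labels) are also assumed here. Adjust them if `fit` reports a type mismatch.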
Logistic Regression¶
from counterfactuals.models.classifiers import LogisticRegression
classifier = LogisticRegression(input_dim=n_features, output_dim=n_classes)
classifier.fit(train_loader, test_loader, epochs=50)
Generative Models (Flows)¶
Masked Autoregressive Flow (MAF)¶
from counterfactuals.models.generators import MaskedAutoregressiveFlow
flow = MaskedAutoregressiveFlow(
    input_dim=n_features,
    hidden_dims=[128, 128],
    n_layers=5
)
flow.fit(
    train_loader=train_loader,
    test_loader=test_loader,
    epochs=200,
    lr=0.0001
)
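For intuition about what a masked autoregressive flow computes, the sketch below evaluates the log-density of a point under a single autoregressive affine layer with a standard normal base distribution. It is a plain-Python illustration, not the library's implementation. `conditioner` is a toy stand-in for the masked network, and whether `fit` optimizes exactly this likelihood is an assumption (maximum likelihood is the usual training objective for normalizing flows).

```python
import math


def conditioner(prefix):
    """Toy stand-in for the masked network: maps x_{<i} to (mu_i, alpha_i)."""
    total = sum(prefix)
    return 0.5 * total, 0.1 * math.tanh(total)


def maf_layer_forward(x):
    """Map data x to noise z and return (z, log p(x)).

    Each coordinate is transformed as z_i = (x_i - mu_i) * exp(-alpha_i),
    where mu_i and alpha_i depend only on the earlier coordinates x_{<i}.
    """
    z, log_prob = [], 0.0
    for i, x_i in enumerate(x):
        mu, alpha = conditioner(x[:i])
        z_i = (x_i - mu) * math.exp(-alpha)
        z.append(z_i)
        # Standard normal log-density of z_i plus the log-Jacobian term -alpha_i.
        log_prob += -0.5 * (z_i ** 2 + math.log(2 * math.pi)) - alpha
    return z, log_prob


def maf_layer_inverse(z):
    """Map noise z back to data x, one coordinate at a time."""
    x = []
    for z_i in z:
        mu, alpha = conditioner(x)  # x currently holds the earlier coordinates
        x.append(z_i * math.exp(alpha) + mu)
    return x


x = [0.4, -0.7, 1.5]
z, log_prob = maf_layer_forward(x)
x_reconstructed = maf_layer_inverse(z)
```

Density evaluation handles all coordinates from a single input, while sampling through the inverse has to proceed coordinate by coordinate.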
Other Flows¶
- RealNVP: Affine coupling layers
- NICE: Additive coupling layers (volume-preserving)
- CNF: Continuous normalizing flows (for regression)
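The difference between the two coupling-based flows in the list above can be sketched in a few lines. This is a plain-Python illustration, not the library's code. `shift` and `log_scale` are toy stand-ins for learned networks, and the input is assumed to have an even number of dimensions.

```python
import math


def shift(x_a):
    """Toy stand-in for a learned shift network t(x_a)."""
    return [math.tanh(v) for v in x_a]


def log_scale(x_a):
    """Toy stand-in for a learned log-scale network s(x_a)."""
    return [0.5 * math.sin(v) for v in x_a]


def additive_coupling(x):
    """NICE-style layer: y_b = x_b + t(x_a). Returns (y, log|det J|)."""
    half = len(x) // 2
    x_a, x_b = x[:half], x[half:]
    y_b = [b + t_i for b, t_i in zip(x_b, shift(x_a))]
    return x_a + y_b, 0.0  # a pure shift leaves volume unchanged


def affine_coupling(x):
    """RealNVP-style layer: y_b = x_b * exp(s(x_a)) + t(x_a)."""
    half = len(x) // 2
    x_a, x_b = x[:half], x[half:]
    s, t = log_scale(x_a), shift(x_a)
    y_b = [b * math.exp(s_i) + t_i for b, s_i, t_i in zip(x_b, s, t)]
    return x_a + y_b, sum(s)  # log|det J| is the sum of the log-scales


def affine_coupling_inverse(y):
    """Invert the affine layer; s and t are recomputed from the untouched half."""
    half = len(y) // 2
    y_a, y_b = y[:half], y[half:]
    s, t = log_scale(y_a), shift(y_a)
    x_b = [(b - t_i) * math.exp(-s_i) for b, s_i, t_i in zip(y_b, s, t)]
    return y_a + x_b


x = [0.3, -1.2, 0.8, 2.0]
y_add, logdet_add = additive_coupling(x)
y_aff, logdet_aff = affine_coupling(x)
x_recovered = affine_coupling_inverse(y_aff)
```

Both layers pass the first half of the input through unchanged, which is what makes the inverse cheap to compute.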
Saving and Loading Models¶
# Save
classifier.save("models/classifier.pt")
flow.save("models/flow.pt")
# Load into existing model instances (construct them first, as shown above)
classifier.load("models/classifier.pt")
flow.load("models/flow.pt")
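This page does not say whether `save` creates missing directories such as `models/`. A cautious approach is to create the parent directory before saving. `prepare_checkpoint_path` below is a hypothetical helper built only on the standard library.

```python
from pathlib import Path


def prepare_checkpoint_path(path):
    """Hypothetical helper: ensure the parent directory exists, return the path as str."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    return str(path)


# Intended usage with the models from this page:
#   classifier.save(prepare_checkpoint_path("models/classifier.pt"))
#   flow.save(prepare_checkpoint_path("models/flow.pt"))
```

Because `mkdir` is called with `exist_ok=True`, calling the helper repeatedly for the same directory is harmless.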