Plot confusion matrix
This example illustrates how to use the imbalanced_ensemble.visualizer module to plot confusion matrices for imbalanced_ensemble.ensemble classifiers.
# Authors: Zhining Liu <zhining.liu@outlook.com>
# License: MIT
print(__doc__)
from time import time
# Import imbalanced_ensemble
import imbalanced_ensemble as imbens
# Import utilities from sklearn
import sklearn.metrics  # also binds the name sklearn; the metrics submodule is used below
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
RANDOM_STATE = 42
# sphinx_gallery_thumbnail_number = 4
Prepare data
Make a toy 3-class imbalanced classification task.
# make dataset
X, y = make_classification(n_classes=3, class_sep=2,
    weights=[0.1, 0.3, 0.6], n_informative=3, n_redundant=1, flip_y=0,
    n_features=20, n_clusters_per_class=2, n_samples=2000, random_state=0)
# train/validation split
X_train, X_valid, y_train, y_valid = train_test_split(
    X, y, test_size=0.5, stratify=y, random_state=RANDOM_STATE)
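To see how imbalanced this toy task is, the class sizes implied by weights and n_samples can be worked out by hand. The plain-Python sketch below assumes each class receives about n_samples * weight samples (flip_y=0 means no labels are randomly flipped); expected_class_counts is a hypothetical helper, not part of scikit-learn or imbalanced_ensemble.

```python
# Hypothetical helper: approximate class sizes implied by make_classification's
# `weights` argument (assumes each class gets about n_samples * weight samples).
def expected_class_counts(n_samples, weights):
    return [round(n_samples * w) for w in weights]


counts = expected_class_counts(2000, [0.1, 0.3, 0.6])
print(counts)  # [200, 600, 1200]

# Imbalance ratio: majority class size divided by minority class size.
print(max(counts) / min(counts))  # 6.0

# A stratified 50/50 split (test_size=0.5, stratify=y) keeps these
# proportions, so each half gets roughly half of every class.
print([c // 2 for c in counts])  # [100, 300, 600]
```

So the minority class is about six times smaller than the majority class, in both the training and the validation half.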
Train ensemble classifiers
Four different imbalanced_ensemble.ensemble classifiers are trained, each with 50 base estimators and the same random state.
init_kwargs = {'n_estimators': 50, 'random_state': RANDOM_STATE}
fit_kwargs = {'X': X_train, 'y': y_train}
# imbalanced_ensemble.ensemble classifiers
ensemble_dict = {
    'SPE': imbens.ensemble.SelfPacedEnsembleClassifier(**init_kwargs),
    'EasyEns': imbens.ensemble.EasyEnsembleClassifier(**init_kwargs),
    'BalanceForest': imbens.ensemble.BalancedRandomForestClassifier(**init_kwargs),
    'SMOTEBagging': imbens.ensemble.SMOTEBaggingClassifier(**init_kwargs),
}
# Train all ensemble classifiers and store them in fitted_ensembles
fitted_ensembles = {}
for clf_name, clf in ensemble_dict.items():
    start_time = time()
    clf.fit(**fit_kwargs)
    fit_time = time() - start_time
    fitted_ensembles[clf_name] = clf
    # type(clf).__name__ works for any estimator class, including sklearn ones
    print('Training {:^30s} | Time used: {:.3f}s'.format(type(clf).__name__, fit_time))
Out:
Training SelfPacedEnsembleClassifier | Time used: 0.219s
Training EasyEnsembleClassifier | Time used: 0.759s
Training BalancedRandomForestClassifier | Time used: 0.104s
Training SMOTEBaggingClassifier | Time used: 0.898s
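The training loop times each fit with time.time(). For short durations, the standard library's time.perf_counter() is a monotonic, higher-resolution clock and is usually the better choice. The sketch below shows the same fit-and-time pattern with it; MajorityClassifier and fit_and_time are hypothetical stand-ins so the snippet runs without imbalanced_ensemble installed.

```python
from time import perf_counter


class MajorityClassifier:
    """Hypothetical stand-in estimator: always predicts the most frequent label."""

    def fit(self, X, y):
        # Trailing-underscore attribute marks the estimator as fitted (sklearn convention).
        self.majority_ = max(set(y), key=list(y).count)
        return self

    def predict(self, X):
        return [self.majority_ for _ in X]


def fit_and_time(estimators, X, y):
    """Fit every estimator in a {name: estimator} dict and record fit times."""
    fitted, timings = {}, {}
    for name, est in estimators.items():
        start = perf_counter()
        est.fit(X, y)
        timings[name] = perf_counter() - start
        fitted[name] = est
        print('Training {:^30s} | Time used: {:.3f}s'.format(type(est).__name__, timings[name]))
    return fitted, timings


X = [[0.0], [1.0], [2.0], [3.0]]
y = [0, 1, 1, 1]
fitted, timings = fit_and_time({'A': MajorityClassifier(), 'B': MajorityClassifier()}, X, y)
print(fitted['A'].predict(X))  # [1, 1, 1, 1]
```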
Fit an ImbalancedEnsembleVisualizer
The visualizer is fitted on a dictionary of the form {…, ensemble_name: ensemble_classifier, …}.
The keys should be strings giving the ensemble names, and the values should be fitted imbalanced_ensemble.ensemble or sklearn.ensemble estimator objects.
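A quick sanity check on this dictionary before fitting the visualizer can catch unfitted models early. The sketch below assumes the scikit-learn convention that fitted estimators carry attributes ending in a single trailing underscore (similar to the default heuristic of sklearn.utils.validation.check_is_fitted); looks_fitted and validate_ensemble_dict are hypothetical helpers, not part of imbalanced_ensemble.

```python
def looks_fitted(estimator):
    # Heuristic (assumption): fitted sklearn-style estimators expose attributes
    # such as classes_ or estimators_ that end with a trailing underscore.
    return any(k.endswith('_') and not k.startswith('__') for k in vars(estimator))


def validate_ensemble_dict(ensembles):
    """Hypothetical check: keys must be str names, values must look fitted."""
    for name, est in ensembles.items():
        if not isinstance(name, str):
            raise TypeError(f'ensemble name {name!r} is not a string')
        if not looks_fitted(est):
            raise ValueError(f'ensemble {name!r} does not look fitted')
    return True


class Dummy:
    pass


fitted = Dummy()
fitted.classes_ = [0, 1, 2]  # simulate an attribute set during fit()
unfitted = Dummy()

print(validate_ensemble_dict({'SPE': fitted}))  # True
try:
    validate_ensemble_dict({'SPE': fitted, 'EasyEns': unfitted})
except ValueError as err:
    print(err)  # ensemble 'EasyEns' does not look fitted
```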
# Initialize visualizer
visualizer = imbens.visualizer.ImbalancedEnsembleVisualizer(
    eval_datasets={
        'training': (X_train, y_train),
        'validation': (X_valid, y_valid),
    },
    eval_metrics={
        'acc': (sklearn.metrics.accuracy_score, {}),
        'balanced_acc': (sklearn.metrics.balanced_accuracy_score, {}),
        'weighted_f1': (sklearn.metrics.f1_score, {'average': 'weighted'}),
    },
)
# Fit visualizer
visualizer.fit(fitted_ensembles)
Out:
Visualizer evaluating model SPE on dataset training :: 100%|#############| 50/50 [00:00<00:00, 1355.91it/s]
Visualizer evaluating model SPE on dataset validation :: 100%|#############| 50/50 [00:00<00:00, 1273.97it/s]
Visualizer evaluating model EasyEns on dataset training :: 100%|##############| 50/50 [00:00<00:00, 134.68it/s]
Visualizer evaluating model EasyEns on dataset validation :: 100%|##############| 50/50 [00:00<00:00, 133.35it/s]
Visualizer evaluating model BalanceForest on dataset training :: 100%|#############| 50/50 [00:00<00:00, 1565.04it/s]
Visualizer evaluating model BalanceForest on dataset validation :: 100%|#############| 50/50 [00:00<00:00, 1566.59it/s]
Visualizer evaluating model SMOTEBagging on dataset training :: 100%|#############| 50/50 [00:00<00:00, 1222.78it/s]
Visualizer evaluating model SMOTEBagging on dataset validation :: 100%|#############| 50/50 [00:00<00:00, 1258.26it/s]
Visualizer computing confusion matrices........ Finished!
<imbalanced_ensemble.visualizer.visualizer.ImbalancedEnsembleVisualizer object at 0x000001E199A1D9D0>
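Each eval_metrics entry pairs a metric function with a dict of extra keyword arguments. A natural reading is that every entry is evaluated as func(y_true, y_pred, **kwargs); the plain-Python sketch below applies that pattern, using hypothetical re-implementations of the three metrics so it runs without scikit-learn. The definitions follow scikit-learn's documented ones: balanced accuracy is the mean per-class recall, and average='weighted' F1 weights each class's F1 score by its support.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_counts(y_true, y_pred):
    """Return true-positive, predicted and actual (support) counts per label."""
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return tp, Counter(y_pred), Counter(y_true)


def balanced_accuracy(y_true, y_pred):
    # Mean of per-class recall (like sklearn's balanced_accuracy_score, adjusted=False).
    tp, _, support = per_class_counts(y_true, y_pred)
    return sum(tp[c] / support[c] for c in support) / len(support)


def f1(y_true, y_pred, average='weighted'):
    # Only the 'weighted' average used in this example is sketched here.
    if average != 'weighted':
        raise NotImplementedError(average)
    tp, predicted, support = per_class_counts(y_true, y_pred)
    total = 0.0
    for c in support:
        precision = tp[c] / predicted[c] if predicted[c] else 0.0
        recall = tp[c] / support[c]
        score = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        total += score * support[c]  # weight each class's F1 by its support
    return total / len(y_true)


# Same {name: (func, kwargs)} layout as eval_metrics above.
eval_metrics = {
    'acc': (accuracy, {}),
    'balanced_acc': (balanced_accuracy, {}),
    'weighted_f1': (f1, {'average': 'weighted'}),
}

y_true = [0, 0, 1, 1, 1, 2, 2, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 2, 2, 2, 0]
scores = {name: func(y_true, y_pred, **kw) for name, (func, kw) in eval_metrics.items()}
print({k: round(v, 3) for k, v in scores.items()})
# {'acc': 0.7, 'balanced_acc': 0.656, 'weighted_f1': 0.7}
```

Balanced accuracy comes out lower than plain accuracy here because the small class 0 has the worst recall (0.5), which plain accuracy largely hides.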
Plot confusion matrices
fig, axes = visualizer.confusion_matrix_heatmap()
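Each heatmap shows one confusion matrix. Assuming the usual scikit-learn convention (as in sklearn.metrics.confusion_matrix), entry [i][j] counts samples whose true label is the i-th class and whose predicted label is the j-th class, so correct predictions sit on the diagonal. A minimal plain-Python sketch of that count, with confusion_matrix_counts as a hypothetical helper name:

```python
def confusion_matrix_counts(y_true, y_pred, labels=None):
    """Rows index the true label, columns the predicted label."""
    if labels is None:
        labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


y_true = [0, 0, 1, 1, 1, 2, 2, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 2, 2, 2, 2, 0]
cm = confusion_matrix_counts(y_true, y_pred)
for row in cm:
    print(row)
# [1, 1, 0]
# [0, 2, 1]
# [1, 0, 4]
```

Row sums equal the class supports (2, 3, 5), and the diagonal sum divided by the total (7 / 10) is the accuracy.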

False predictions only
(parameter false_pred_only: bool)
fig, axes = visualizer.confusion_matrix_heatmap(
    false_pred_only=True,
)
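One way to picture false_pred_only=True, assuming the correct predictions on the diagonal are simply hidden, is to mask the diagonal of each matrix. The sketch below does that in plain Python, using None for masked cells (similar in spirit to the mask argument of seaborn.heatmap()); mask_correct_predictions is hypothetical and not imbalanced_ensemble's implementation.

```python
def mask_correct_predictions(matrix):
    """Hypothetical: hide diagonal (correct) cells, keep off-diagonal errors."""
    return [
        [None if i == j else value for j, value in enumerate(row)]
        for i, row in enumerate(matrix)
    ]


cm = [[1, 1, 0],
      [0, 2, 1],
      [1, 0, 4]]
errors_only = mask_correct_predictions(cm)
for row in errors_only:
    print(row)
# [None, 1, 0]
# [0, None, 1]
# [1, 0, None]

# Total number of misclassified samples across all classes.
print(sum(v for row in errors_only for v in row if v is not None))  # 3
```

In an imbalanced task the diagonal counts of the majority class dominate the color scale, so hiding them lets the much smaller error counts stand out.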

Select results for visualization
(parameters on_ensembles: list of ensemble names, on_datasets: list of dataset names)
Here only the ensembles ‘SPE’ and ‘BalanceForest’ are shown, on the ‘validation’ dataset.
fig, axes = visualizer.confusion_matrix_heatmap(
    on_ensembles=['SPE', 'BalanceForest'],
    on_datasets=['validation'],
)
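Conceptually, on_ensembles and on_datasets pick a subset of the (ensemble, dataset) results to draw. The sketch below filters a nested {ensemble: {dataset: result}} dict in plain Python; the nested layout, the select_results name and treating None as "use all" are assumptions for illustration, not imbalanced_ensemble's internals.

```python
def select_results(results, on_ensembles=None, on_datasets=None):
    """Hypothetical filter over {ensemble_name: {dataset_name: result}}."""
    ensembles = list(results) if on_ensembles is None else on_ensembles
    selected = {}
    for ens in ensembles:
        if ens not in results:
            raise KeyError(f'unknown ensemble {ens!r}')
        datasets = list(results[ens]) if on_datasets is None else on_datasets
        selected[ens] = {ds: results[ens][ds] for ds in datasets}
    return selected


# Placeholder strings stand in for the confusion matrices.
results = {
    name: {'training': f'{name}-train-cm', 'validation': f'{name}-valid-cm'}
    for name in ['SPE', 'EasyEns', 'BalanceForest', 'SMOTEBagging']
}
subset = select_results(results, on_ensembles=['SPE', 'BalanceForest'],
                        on_datasets=['validation'])
print(subset)
# {'SPE': {'validation': 'SPE-valid-cm'}, 'BalanceForest': {'validation': 'BalanceForest-valid-cm'}}
```

With two ensembles and one dataset selected, two heatmaps remain to be drawn instead of eight.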

Customize visual appearance
(parameters sub_figsize: tuple, sup_title: bool or string, plus any keyword arguments accepted by seaborn.heatmap())
fig, axes = visualizer.confusion_matrix_heatmap(
    on_ensembles=['SPE', 'BalanceForest'],
    on_datasets=['training', 'validation'],
    # Customize visual appearance
    sub_figsize=(4, 3.3),
    sup_title='My Suptitle',
    # arguments passed down to seaborn.heatmap()
    cmap='YlOrRd',
    cbar=True,
    linewidths=10,
    vmax=20,
)
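This call mixes the visualizer's own layout options with keyword arguments meant for seaborn.heatmap() (cmap, cbar, linewidths and vmax are real seaborn.heatmap() parameters; vmax=20 caps the top of the color scale at 20). The usual Python pattern for this is to name the function's own parameters explicitly and collect everything else in **kwargs to forward unchanged. In the sketch below, fake_heatmap and draw_heatmaps are hypothetical stand-ins, and computing the figure size as a grid of sub_figsize-sized panels is an assumption about how sub_figsize is used.

```python
def fake_heatmap(data, **kwargs):
    """Hypothetical stand-in for seaborn.heatmap(): just records what it receives."""
    return {'data': data, **kwargs}


def draw_heatmaps(matrices, sub_figsize=(4, 3.3), sup_title=True, **heatmap_kwargs):
    """Hypothetical: consume layout options, forward the rest to the heatmap call."""
    n_rows, n_cols = len(matrices), len(matrices[0])
    # Assumption: the whole figure is a grid of sub_figsize-sized panels.
    figsize = (n_cols * sub_figsize[0], n_rows * sub_figsize[1])
    panels = [[fake_heatmap(m, **heatmap_kwargs) for m in row] for row in matrices]
    return figsize, sup_title, panels


grid = [['cm11', 'cm12'], ['cm21', 'cm22']]  # 2 x 2 grid of placeholder matrices
figsize, title, panels = draw_heatmaps(
    grid, sub_figsize=(4, 3.3), sup_title='My Suptitle',
    cmap='YlOrRd', cbar=True, linewidths=10, vmax=20,
)
print(figsize)       # (8, 6.6)
print(panels[0][0])  # {'data': 'cm11', 'cmap': 'YlOrRd', 'cbar': True, 'linewidths': 10, 'vmax': 20}
```

Because sub_figsize and sup_title are named parameters, they never leak into the heatmap call; only the remaining keyword arguments are forwarded.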

Total running time of the script: (0 minutes 49.354 seconds)
Estimated memory usage: 34 MB