
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_pruning.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_pruning.py>`
        to download the full example code or to run this example in your browser via JupyterLite.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_pruning.py:


============
Data pruning
============

.. currentmodule:: fastcan

This example shows how to prune a dataset with :func:`minibatch`
based on :class:`FastCan`.
The method is compared to random data pruning.

.. GENERATED FROM PYTHON SOURCE LINES 12-16

.. code-block:: Python


    # Authors: The fastcan developers
    # SPDX-License-Identifier: MIT


.. GENERATED FROM PYTHON SOURCE LINES 17-26

Load data and prepare baseline
------------------------------
We use the ``iris`` dataset and a logistic regression model to demonstrate data pruning.
The baseline model is a logistic regression model trained on the entire dataset.
Here, 110 samples are used as the training data, which is intentionally made
imbalanced, to test the data pruning methods.
The coefficients of the model trained on the pruned dataset are compared to those of
the baseline model using the R-squared score.
The higher the R-squared score, the better the pruning.

.. GENERATED FROM PYTHON SOURCE LINES 26-35

.. code-block:: Python


    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    iris = load_iris()
    # Baseline: logistic regression fitted on the entire dataset.
    baseline_lr = LogisticRegression(max_iter=1000).fit(iris["data"], iris["target"])
    # Training data: samples 10 to 119 (110 samples).
    X_train = iris["data"][10:120]
    y_train = iris["target"][10:120]


.. GENERATED FROM PYTHON SOURCE LINES 36-41

Random data pruning
-------------------
There are 110 samples in the training dataset.
The random pruning method selects samples uniformly at random, without replacement,
from all training data.

.. GENERATED FROM PYTHON SOURCE LINES 41-52

.. code-block:: Python


    import numpy as np


    def _random_pruning(X, y, n_samples_to_select: int, random_state: int):
        rng = np.random.default_rng(random_state)
        # Draw distinct sample indices with equal probability.
        ids_random = rng.choice(y.size, n_samples_to_select, replace=False)
        pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_random], y[ids_random])
        return pruned_lr.coef_, pruned_lr.intercept_


.. GENERATED FROM PYTHON SOURCE LINES 53-60

FastCan data pruning
--------------------
Pruning the data with :class:`FastCan` takes two steps:

#. Learn the atoms by dictionary learning (here we use ``KMeans``).
#. Select the samples by :func:`minibatch` according to the multiple correlation
   between each atom and the batch of samples.

.. GENERATED FROM PYTHON SOURCE LINES 60-86

.. code-block:: Python


    from sklearn.cluster import KMeans

    from fastcan import minibatch


    def _fastcan_pruning(
        X,
        y,
        n_samples_to_select: int,
        random_state: int,
        n_atoms: int,
        batch_size: int,
    ):
        # Step 1: learn the atoms as k-means cluster centers.
        kmeans = KMeans(
            n_clusters=n_atoms,
            random_state=random_state,
        ).fit(X)
        atoms = kmeans.cluster_centers_
        # Step 2: select sample indices with minibatch.
        ids_fastcan = minibatch(
            X.T, atoms.T, n_samples_to_select, batch_size=batch_size, verbose=0
        )
        pruned_lr = LogisticRegression(max_iter=1000).fit(X[ids_fastcan], y[ids_fastcan])
        return pruned_lr.coef_, pruned_lr.intercept_


.. GENERATED FROM PYTHON SOURCE LINES 87-99

Visualize selected samples
--------------------------
Use principal component analysis (PCA) to visualize the distribution of the samples,
and to compare the selections made by ``Random`` pruning and ``FastCan`` pruning.
For clearer viewing of the selection, only 10 samples are selected from the training
data by the pruning methods.
The results show that ``FastCan`` selects 3 setosa, 4 versicolor, and 3 virginica,
while ``Random`` selects 6, 2, and 2, respectively.
The imbalanced selection of ``Random`` is caused by the imbalanced training data,
while ``FastCan``, benefiting from the dictionary learning (k-means), can overcome
the imbalance issue.

.. GENERATED FROM PYTHON SOURCE LINES 99-138

.. code-block:: Python


    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA


    def plot_pca(X, y, target_names, n_samples_to_select, random_state):
        # Project all samples onto the first two principal components.
        pca = PCA(2).fit(X)
        pcs_all = pca.transform(X)

        kmeans = KMeans(
            n_clusters=10,
            random_state=random_state,
        ).fit(X)
        atoms = kmeans.cluster_centers_
        pcs_atoms = pca.transform(atoms)

        ids_fastcan = minibatch(X.T, atoms.T, n_samples_to_select, batch_size=1, verbose=0)
        pcs_fastcan = pca.transform(X[ids_fastcan])

        rng = np.random.default_rng(random_state)
        ids_random = rng.choice(X.shape[0], n_samples_to_select, replace=False)
        pcs_random = pca.transform(X[ids_random])

        plt.scatter(pcs_fastcan[:, 0], pcs_fastcan[:, 1], s=50, marker="o", label="FastCan")
        plt.scatter(pcs_random[:, 0], pcs_random[:, 1], s=50, marker="*", label="Random")
        plt.scatter(pcs_atoms[:, 0], pcs_atoms[:, 1], s=100, marker="+", label="Atoms")
        cmap = plt.get_cmap("Dark2")
        for i, label in enumerate(target_names):
            mask = y == i
            plt.scatter(
                pcs_all[mask, 0], pcs_all[mask, 1], s=5, label=label, color=cmap(i + 2)
            )
        plt.xlabel("The First Principal Component")
        plt.ylabel("The Second Principal Component")
        plt.legend(ncol=2)


    plot_pca(X_train, y_train, iris.target_names, 10, 123)



.. image-sg:: /auto_examples/images/sphx_glr_plot_pruning_001.png
   :alt: plot pruning
   :srcset: /auto_examples/images/sphx_glr_plot_pruning_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 139-144

Compare pruning methods
-----------------------
80 samples are selected from the 110 training samples with ``Random`` pruning and
``FastCan`` pruning. The results show that ``FastCan`` pruning gives a higher
median R-squared and a lower standard deviation than ``Random`` pruning.

.. GENERATED FROM PYTHON SOURCE LINES 144-171

.. code-block:: Python


    from sklearn.metrics import r2_score


    def plot_box(X, y, baseline, n_samples_to_select: int, n_random: int):
        r2_fastcan = np.zeros(n_random)
        r2_random = np.zeros(n_random)
        # Repeat both pruning methods with a different random state per run.
        for i in range(n_random):
            coef, intercept = _fastcan_pruning(
                X, y, n_samples_to_select, i, n_atoms=40, batch_size=2
            )
            r2_fastcan[i] = r2_score(
                np.c_[coef, intercept], np.c_[baseline.coef_, baseline.intercept_]
            )

            coef, intercept = _random_pruning(X, y, n_samples_to_select, i)
            r2_random[i] = r2_score(
                np.c_[coef, intercept], np.c_[baseline.coef_, baseline.intercept_]
            )

        plt.boxplot(np.c_[r2_fastcan, r2_random])
        plt.ylabel("R2")
        plt.xticks(ticks=[1, 2], labels=["FastCan", "Random"])
        plt.show()


    plot_box(X_train, y_train, baseline_lr, n_samples_to_select=80, n_random=100)



.. image-sg:: /auto_examples/images/sphx_glr_plot_pruning_002.png
   :alt: plot pruning
   :srcset: /auto_examples/images/sphx_glr_plot_pruning_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 2.392 seconds)


.. _sphx_glr_download_auto_examples_plot_pruning.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: lite-badge

      .. image:: images/jupyterlite_badge_logo.svg
        :target: ../lite/lab/index.html?path=auto_examples/plot_pruning.ipynb
        :alt: Launch JupyterLite
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_pruning.ipynb <plot_pruning.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_pruning.py <plot_pruning.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_pruning.zip <plot_pruning.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
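
As a supplementary check, the imbalance of the training split used in this example
(samples 10 to 119) can be counted with a few lines of standard-library Python.
This is a minimal sketch and not part of the original example: it assumes that the
150 ``iris`` samples are ordered by class with 50 samples per class, and the names
``labels``, ``train_labels``, and ``train_counts`` are hypothetical.

.. code-block:: Python

    from collections import Counter

    # Assumption: iris is ordered by class with 50 samples each
    # (0 = setosa, 1 = versicolor, 2 = virginica).
    labels = [0] * 50 + [1] * 50 + [2] * 50

    # Same slice as X_train / y_train in the example: samples 10 to 119.
    train_labels = labels[10:120]
    train_counts = Counter(train_labels)

    # Print the number of training samples per class.
    for class_id, name in enumerate(["setosa", "versicolor", "virginica"]):
        print(f"{name}: {train_counts[class_id]}")

Under this assumption the split holds 40 setosa, 50 versicolor, and 20 virginica
samples, which is the class skew that the pruning methods have to cope with.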