Fisher’s criterion in LDA

In this example, we demonstrate that the canonical correlation coefficients between the features X and the one-hot encoded target y are equivalent to Fisher’s criterion in LDA (Linear Discriminant Analysis): each squared canonical correlation coefficient r² is related to the corresponding Fisher’s criterion λ by r² = λ / (1 + λ).
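The relationship follows from a short derivation. Write Sw, Sb and St = Sw + Sb for the within-class, between-class and total covariance matrices of X. Fisher’s criteria λ are the generalized eigenvalues of the pair (Sb, Sw). The squared canonical correlations r² between X and the centered one-hot target are the eigenvalues of St⁻¹ Sb, because projecting the centered features onto the span of the class indicators replaces every sample by its class mean, which yields exactly Sb. Substituting Sw = St − Sb into the LDA eigenproblem gives:

```latex
\begin{aligned}
S_b v &= \lambda S_w v = \lambda \left(S_t - S_b\right) v \\
\Longrightarrow \quad (1 + \lambda)\, S_b v &= \lambda\, S_t v \\
\Longrightarrow \quad S_t^{-1} S_b v &= \frac{\lambda}{1 + \lambda}\, v,
\qquad \text{so} \qquad r^2 = \frac{\lambda}{1 + \lambda}.
\end{aligned}
```

Because λ ↦ λ / (1 + λ) is increasing, the ordering of the eigenvalues is preserved, and the sum of squared canonical correlation coefficients equals the sum of λ_i / (1 + λ_i) over the nonzero Fisher’s criteria.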

# Authors: The fastcan developers
# SPDX-License-Identifier: MIT

Prepare data

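We use the iris dataset and transform its multiclass target into a multilabel target by one-hot encoding. Here, drop="first" is necessary: the one-hot columns sum to one in every row, so after mean-centering, which canonical correlation analysis relies on, they become linearly dependent and the transformed target would not have full column rank.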

from sklearn import datasets
from sklearn.preprocessing import OneHotEncoder

X, y = datasets.load_iris(return_X_y=True)
# drop="first" is necessary; otherwise, the transformed target is not full column rank
y_enc = OneHotEncoder(
    drop="first",
    sparse_output=False,
).fit_transform(y.reshape(-1, 1))

Compute Fisher’s criterion

When solver="eigen", LinearDiscriminantAnalysis in sklearn computes Fisher’s criterion as an intermediate result: the generalized eigenvalues of the between-class and within-class scatter matrices. However, it does not expose these eigenvalues directly, so we reproduce them manually.

import numpy as np
from scipy import linalg
from sklearn.covariance import empirical_covariance
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

clf = LinearDiscriminantAnalysis(solver="eigen").fit(X, y)
Sw = clf.covariance_  # within-class scatter
St = empirical_covariance(X)  # total scatter
Sb = St - Sw  # between-class scatter
fishers_criterion, _ = linalg.eigh(Sb, Sw)

fishers_criterion = np.sort(fishers_criterion)[::-1]
# Sb has rank at most n_classes - 1, so at most
# min(n_features, n_classes - 1) eigenvalues are nonzero
n_nonzero = min(X.shape[1], clf.classes_.shape[0] - 1)
# keep those and drop the eigenvalues that are numerically zero
fishers_criterion = fishers_criterion[:n_nonzero]
# convert Fisher's criteria to squared canonical correlation coefficients
r2 = fishers_criterion / (1 + fishers_criterion)
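As a sanity check on the decomposition St = Sw + Sb used above, the sketch below builds the between-class covariance directly from the class means and compares it with St − Sw. It assumes the settings used in this example: class priors equal to the class frequencies (the LinearDiscriminantAnalysis default) and no shrinkage.

```python
import numpy as np
from sklearn import datasets
from sklearn.covariance import empirical_covariance
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

X, y = datasets.load_iris(return_X_y=True)
clf = LinearDiscriminantAnalysis(solver="eigen").fit(X, y)

# Between-class covariance built directly from the class means,
# weighting each class by its share of the samples
mu = X.mean(axis=0)
Sb_direct = np.zeros((X.shape[1], X.shape[1]))
for k in np.unique(y):
    Xk = X[y == k]
    diff = Xk.mean(axis=0) - mu
    Sb_direct += len(Xk) / len(X) * np.outer(diff, diff)

# Between-class covariance as computed in the example: total minus within
Sb_example = empirical_covariance(X) - clf.covariance_

print(np.allclose(Sb_direct, Sb_example))  # True
```

Both matrices use the biased (divide-by-n) normalization, so the identity holds exactly rather than only up to a scale factor.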

Compute SSC

Compute the sum of squared canonical correlation coefficients (SSC) with FastCan. The result obtained by FastCan, i.e., by CCA (Canonical Correlation Analysis), is the same as the one derived from LDA.

from fastcan import FastCan

# select all 4 iris features so that the SSC covers the whole feature set
ssc = FastCan(4, verbose=0).fit(X, y_enc).scores_.sum()

print(f"SSC from LDA: {r2.sum():5f}")
print(f"SSC from CCA: {ssc:5f}")
SSC from LDA: 1.191899
SSC from CCA: 1.191899

