Linear Discriminant Analysis With Python

Author: Jason Brownlee

Linear Discriminant Analysis is a linear classification machine learning algorithm.

The algorithm involves developing a probabilistic model per class based on the specific distribution of observations for each input variable. A new example is then classified by calculating the conditional probability of it belonging to each class and selecting the class with the highest probability.

As such, it is a relatively simple probabilistic classification model that makes strong assumptions about the distribution of each input variable, although it can make effective predictions even when these expectations are violated (i.e. it fails gracefully).

In this tutorial, you will discover the Linear Discriminant Analysis classification machine learning algorithm in Python.

After completing this tutorial, you will know:

  • Linear Discriminant Analysis is a simple linear machine learning algorithm for classification.
  • How to fit, evaluate, and make predictions with the Linear Discriminant Analysis model with Scikit-Learn.
  • How to tune the hyperparameters of the Linear Discriminant Analysis algorithm on a given dataset.

Let’s get started.

Photo by Mihai Lucîț, some rights reserved.

Tutorial Overview

This tutorial is divided into three parts; they are:

  1. Linear Discriminant Analysis
  2. Linear Discriminant Analysis With scikit-learn
  3. Tune LDA Hyperparameters

Linear Discriminant Analysis

Linear Discriminant Analysis, or LDA for short, is a classification machine learning algorithm.

It works by calculating summary statistics for the input features by class label, such as the mean and standard deviation. These statistics represent the model learned from the training data. In practice, linear algebra operations are used to calculate the required quantities efficiently via matrix decomposition.

Predictions are made by estimating the probability that a new example belongs to each class label based on the values of each input feature. The class that results in the largest probability is then assigned to the example. As such, LDA may be considered a simple application of Bayes Theorem for classification.
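
To make this concrete, below is a minimal sketch of the LDA scoring rule written out with NumPy. It is for illustration only, under the assumptions described above, and is not how scikit-learn implements the method internally.

# illustration only: the LDA scoring rule written out with NumPy;
# scikit-learn's implementation is more efficient and more robust
import numpy as np

def fit_lda(X, y):
	classes = np.unique(y)
	# per-class mean vectors and class prior probabilities
	means = {c: X[y == c].mean(axis=0) for c in classes}
	priors = {c: np.mean(y == c) for c in classes}
	# pooled covariance shared by all classes (the common-variance assumption)
	pooled = sum(np.cov(X[y == c], rowvar=False) * (np.sum(y == c) - 1) for c in classes) / (len(X) - len(classes))
	return classes, means, priors, np.linalg.inv(pooled)

def predict_lda(x, classes, means, priors, inv_cov):
	# linear discriminant score per class; assign the class with the largest
	scores = [x @ inv_cov @ means[c] - 0.5 * means[c] @ inv_cov @ means[c] + np.log(priors[c]) for c in classes]
	return classes[int(np.argmax(scores))]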

LDA assumes that the input variables are numeric and normally distributed and that they have the same variance (spread). If this is not the case, it may be desirable to transform the data to have a Gaussian distribution and standardize or normalize the data prior to modeling.

… the LDA classifier results from assuming that the observations within each class come from a normal distribution with a class-specific mean vector and a common variance

— Page 142, An Introduction to Statistical Learning with Applications in R, 2014.

It also assumes that the input variables are not correlated; if they are, a PCA transform may be helpful to remove the linear dependence.

… practitioners should be particularly rigorous in pre-processing data before using LDA. We recommend that predictors be centered and scaled and that near-zero variance predictors be removed.

— Page 293, Applied Predictive Modeling, 2013.
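
Taken together, these recommendations can be captured in a modeling pipeline. The sketch below, using the scikit-learn classes covered later in this tutorial, is one reasonable arrangement under those assumptions, not the only option.

# sketch of pre-processing ahead of LDA: make inputs more Gaussian,
# center and scale them, and remove linear dependence with PCA
# (near-zero variance predictors could also be dropped first)
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PowerTransformer
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
steps = [('power', PowerTransformer()), # also centers and scales by default
	('pca', PCA()),
	('lda', LinearDiscriminantAnalysis())]
pipeline = Pipeline(steps=steps)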

Nevertheless, the model can perform well, even when violating these expectations.

The LDA model is naturally multi-class. This means that it supports two-class classification problems and extends to more than two classes (multi-class classification) without modification or augmentation.
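
For example, as a quick preview of the scikit-learn class used below, the same code fits a three-class dataset with no changes (the dataset configuration here is arbitrary).

# LDA applied to a three-class problem without modification
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, n_classes=3, random_state=1)
model = LinearDiscriminantAnalysis()
model.fit(X, y)
print(model.predict(X[:5])) # labels drawn from {0, 1, 2}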

It is a linear classification algorithm, like logistic regression. This means that classes are separated in the feature space by lines or hyperplanes. Extensions of the method can be used that allow other shapes, like Quadratic Discriminant Analysis (QDA), which allows curved shapes in the decision boundary.

… unlike LDA, QDA assumes that each class has its own covariance matrix.

— Page 149, An Introduction to Statistical Learning with Applications in R, 2014.
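
For reference, QDA is available in scikit-learn alongside LDA and shares the same interface; a minimal sketch:

# QDA estimates one covariance matrix per class, allowing curved
# decision boundaries, but is used the same way as LDA
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
model = QuadraticDiscriminantAnalysis()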

Now that we are familiar with LDA, let’s look at how to fit and evaluate models using the scikit-learn library.

Linear Discriminant Analysis With scikit-learn

Linear Discriminant Analysis is available in the scikit-learn Python machine learning library via the LinearDiscriminantAnalysis class.

The method can be used directly without configuration, although the implementation does offer arguments for customization, such as the choice of solver and the use of a penalty.

...
# create the lda model
model = LinearDiscriminantAnalysis()

We can demonstrate the Linear Discriminant Analysis method with a worked example.

First, let’s define a synthetic classification dataset.

We will use the make_classification() function to create a dataset with 1,000 examples, each with 10 input variables.

The example creates and summarizes the dataset.

# test classification dataset
from sklearn.datasets import make_classification
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# summarize the dataset
print(X.shape, y.shape)

Running the example creates the dataset and confirms the number of rows and columns of the dataset.

(1000, 10) (1000,)

We can fit and evaluate a Linear Discriminant Analysis model using repeated stratified k-fold cross-validation via the RepeatedStratifiedKFold class. We will use 10 folds and three repeats in the test harness.

The complete example of evaluating the Linear Discriminant Analysis model for the synthetic binary classification task is listed below.

# evaluate a lda model on the dataset
from numpy import mean
from numpy import std
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis()
# define model evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# evaluate model
scores = cross_val_score(model, X, y, scoring='accuracy', cv=cv, n_jobs=-1)
# summarize result
print('Mean Accuracy: %.3f (%.3f)' % (mean(scores), std(scores)))

Running the example evaluates the Linear Discriminant Analysis algorithm on the synthetic dataset and reports the average accuracy across the three repeats of 10-fold cross-validation.

Your specific results may vary given the stochastic nature of the evaluation procedure. Consider running the example a few times.

In this case, we can see that the model achieved a mean accuracy of about 89.3 percent.

Mean Accuracy: 0.893 (0.033)

We may decide to use the Linear Discriminant Analysis as our final model and make predictions on new data.

This can be achieved by fitting the model on all available data and calling the predict() function, passing in a new row of data.

We can demonstrate this with a complete example listed below.

# make a prediction with a lda model on the dataset
from sklearn.datasets import make_classification
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis()
# fit model
model.fit(X, y)
# define new data
row = [0.12777556,-3.64400522,-2.23268854,-1.82114386,1.75466361,0.1243966,1.03397657,2.35822076,1.01001752,0.56768485]
# make a prediction
yhat = model.predict([row])
# summarize prediction
print('Predicted Class: %d' % yhat[0])

Running the example fits the model and makes a class label prediction for a new row of data.

Predicted Class: 1
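
If class membership probabilities are needed rather than a hard label, the fitted model also provides the predict_proba() function. For example, continuing from the model and row above:

# class membership probabilities for the same new row of data
probs = model.predict_proba([row])
print('Class Probabilities: %s' % probs[0])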

Next, we can look at configuring the model hyperparameters.

Tune LDA Hyperparameters

The hyperparameters for the Linear Discriminant Analysis method must be configured for your specific dataset.

An important hyperparameter is the solver, which defaults to 'svd' but can also be set to 'lsqr' or 'eigen', values that support the shrinkage capability.

The example below demonstrates this using the GridSearchCV class with a grid of different solver values.

# grid search solver for lda
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis()
# define model evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# define grid
grid = dict()
grid['solver'] = ['svd', 'lsqr', 'eigen']
# define search
search = GridSearchCV(model, grid, scoring='accuracy', cv=cv, n_jobs=-1)
# perform the search
results = search.fit(X, y)
# summarize
print('Mean Accuracy: %.3f' % results.best_score_)
print('Config: %s' % results.best_params_)

Running the example will evaluate each combination of configurations using repeated cross-validation.

Your specific results may vary given the stochastic nature of the evaluation procedure. Try running the example a few times.

In this case, we can see that the default SVD solver performs best of the built-in solvers.

Mean Accuracy: 0.893
Config: {'solver': 'svd'}

Next, we can explore whether using shrinkage with the model improves performance.

Shrinkage adds a penalty to the estimate of the covariance matrix, acting as a type of regularizer that reduces the complexity of the model.

Regularization reduces the variance associated with the sample based estimate at the expense of potentially increased bias. This bias variance trade-off is generally regulated by one or more (degree-of-belief) parameters that control the strength of the biasing towards the “plausible” set of (population) parameter values.

— Regularized Discriminant Analysis, 1989.

This is configured via the 'shrinkage' argument, which takes a value between 0 and 1. We will test values on a grid with a spacing of 0.01.

In order to use the penalty, a solver must be chosen that supports this capability, such as 'eigen' or 'lsqr'. We will use the latter in this case.

The complete example of tuning the shrinkage hyperparameter is listed below.

# grid search shrinkage for lda
from numpy import arange
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# define dataset
X, y = make_classification(n_samples=1000, n_features=10, n_informative=10, n_redundant=0, random_state=1)
# define model
model = LinearDiscriminantAnalysis(solver='lsqr')
# define model evaluation method
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
# define grid
grid = dict()
grid['shrinkage'] = arange(0, 1, 0.01)
# define search
search = GridSearchCV(model, grid, scoring='accuracy', cv=cv, n_jobs=-1)
# perform the search
results = search.fit(X, y)
# summarize
print('Mean Accuracy: %.3f' % results.best_score_)
print('Config: %s' % results.best_params_)

Running the example will evaluate each combination of configurations using repeated cross-validation.

Your specific results may vary given the stochastic nature of the evaluation procedure. Try running the example a few times.

In this case, we can see that using shrinkage offers a slight lift in performance from about 89.3 percent to about 89.4 percent, with a shrinkage value of 0.02.

Mean Accuracy: 0.894
Config: {'shrinkage': 0.02}
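
Note that scikit-learn can also set the shrinkage intensity automatically by passing shrinkage='auto', which uses the Ledoit-Wolf lemma rather than a grid-searched value; a one-line sketch follows.

# let scikit-learn choose the shrinkage intensity automatically
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
model = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')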

Further Reading

This section provides more resources on the topic if you are looking to go deeper.

Papers

  • Regularized Discriminant Analysis, 1989.

Books

  • An Introduction to Statistical Learning with Applications in R, 2014.
  • Applied Predictive Modeling, 2013.

APIs

  • sklearn.discriminant_analysis.LinearDiscriminantAnalysis API.

Summary

In this tutorial, you discovered the Linear Discriminant Analysis classification machine learning algorithm in Python.

Specifically, you learned:

  • Linear Discriminant Analysis is a simple linear machine learning algorithm for classification.
  • How to fit, evaluate, and make predictions with the Linear Discriminant Analysis model with Scikit-Learn.
  • How to tune the hyperparameters of the Linear Discriminant Analysis algorithm on a given dataset.

Do you have any questions?
Ask your questions in the comments below and I will do my best to answer.
