EM Algorithm for Gaussian Mixture Models


Author: Anil Gurbuz

Last Updated: 5 MAY 2019


Gaussian Mixture Models

Gaussian mixture models are used to model one- or multi-dimensional data without using labels. The main idea is to estimate the underlying distribution that the data is (hopefully) coming from. These models assume a multivariate Gaussian distribution for each underlying data-generating process and find the mean $\mu$ and covariance matrix $\Sigma$ of each such process, i.e. of each cluster. Estimating the multivariate Gaussian parameters follows the Maximum Likelihood approach: we want the parameters that maximise the likelihood function below.

$$ \mathcal{L}( \mu_{1},\dots,\mu_{K},\Sigma_{1},\dots,\Sigma_{K} \mid X_1, X_2, \dots, X_n ) = \prod_{i=1}^{n} \sum_{k=1}^{K} \gamma_{k} \, N(x_i;\mu_{k},\Sigma_{k})$$

The problem with the above expression is the latent structure it hides: we do not observe which cluster each point was generated from, and the mixing coefficients $\gamma_{k}$ (the proportion of the data in cluster $k$) are unknown as well. As a result, there is no closed-form expression for the Maximum Likelihood Estimates (MLE) of the model parameters. Instead, we will maximise the likelihood function with the Expectation Maximisation (EM) algorithm.

Expectation Maximisation Algorithm

The Expectation Maximisation algorithm consists of two main steps and iterates between them until convergence.

In the Expectation step:

  • We use the current values of the model parameters to compute the posterior probability of each point's cluster assignment (the responsibility; see the formula just below).
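For a Gaussian mixture, this posterior (the responsibility of cluster $k$ for point $x_i$, written here as $\tau_{ik}$; the code below stores it in the matrix post) follows from Bayes' rule:

$$ \tau_{ik} = \frac{\gamma_{k} \, N(x_i;\mu_{k},\Sigma_{k})}{\sum_{j=1}^{K} \gamma_{j} \, N(x_i;\mu_{j},\Sigma_{j})} $$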

In the Maximisation step:

  • We use these responsibilities to update the mixing coefficients and the model parameters $\mu$ and $\Sigma$ (the update equations and a small illustrative sketch follow below).
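Writing $N_k = \sum_{i=1}^{n} \tau_{ik}$ for the effective number of points assigned to cluster $k$, the standard Maximisation-step updates are:

$$ \gamma_{k} = \frac{N_k}{n}, \qquad \mu_{k} = \frac{1}{N_k}\sum_{i=1}^{n} \tau_{ik}\, x_i, \qquad \Sigma_{k} = \frac{1}{N_k}\sum_{i=1}^{n} \tau_{ik}\,(x_i-\mu_{k})(x_i-\mu_{k})^{T} $$

To make the two steps concrete before moving on to the text-clustering version implemented below, here is a minimal illustrative sketch of the EM loop for a two-component, one-dimensional Gaussian mixture using dmvnorm() from the mvtnorm package; the synthetic data, starting values and fixed iteration count are assumptions made purely for this example.

library(mvtnorm)

set.seed(1)
x <- matrix(c(rnorm(50, -2), rnorm(50, 3)), ncol = 1)  # synthetic 1-D data from two clusters
K <- 2
N <- nrow(x)
gamma <- rep(1/K, K)   # mixing coefficients
mu    <- c(-1, 1)      # component means (starting guesses)
sig2  <- c(1, 1)       # component variances (starting guesses)

for (iter in 1:50) {
    # E-step: responsibility tau[i,k] is proportional to gamma_k * N(x_i; mu_k, sig2_k)
    tau <- sapply(1:K, function(k) gamma[k] * dmvnorm(x, mu[k], matrix(sig2[k])))
    tau <- tau / rowSums(tau)   # normalise each row to sum to 1

    # M-step: re-estimate mixing coefficients, means and variances
    Nk    <- colSums(tau)
    gamma <- Nk / N
    mu    <- sapply(1:K, function(k) sum(tau[, k] * x) / Nk[k])
    sig2  <- sapply(1:K, function(k) sum(tau[, k] * (x - mu[k])^2) / Nk[k])
}

gamma; mu; sig2   # fitted mixing coefficients, means and variances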
In [1]:
# Load libraries
library(mvtnorm) # generates multivariate Gaussian samples and calculates their densities
library(ggplot2) # plotting
library(reshape2) # data wrangling
library(clusterGeneration) # generates the covariance matrices needed for producing synthetic data
library(tm) # text mining: corpus handling and preprocessing
library(SnowballC) # stemming support used by tm

In [2]:
## read the file (each line of the text file is one document)
text=readLines('Corpus.txt')

## the terms before '\t' are the labels (the newsgroup names) and the remaining text after '\t' is the actual document
docs <- strsplit(text, '\t')
rm(text) # just free some memory!

# store the labels for evaluation
labels <-  unlist(lapply(docs, function(x) x[1]))
docs <- data.frame(unlist(lapply(docs, function(x) x[2])))
# store the unlabeled texts   
docs[,2]=1:nrow(docs)
colnames(docs)=c('text','doc_id')
docs=docs[,c(2,1)]
                                                                  
# create a corpus
docs <- DataframeSource(docs)
docs <- Corpus(docs)

# Preprocessing:
docs <- tm_map(docs, removeWords, stopwords("english")) # remove stop words (the most common words in a language, found in almost any document)
docs <- tm_map(docs, removePunctuation) # remove punctuation
docs <- tm_map(docs, stemDocument) # perform stemming (reducing inflected and derived words to their root form)
docs <- tm_map(docs, removeNumbers) # remove all numbers
docs <- tm_map(docs, stripWhitespace) # remove redundant spaces 

# Create a matrix whose rows are the documents and whose columns are the words.
## Each number in the Document Term Matrix shows the frequency of a word (column header) in a particular document (row title)
dtm <- DocumentTermMatrix(docs)

## reduce the sparsity of our dtm
dtm <- removeSparseTerms(dtm, 0.90)

## convert dtm to a matrix
m = as.matrix(dtm)
rownames(m) <- 1:nrow(m)
# Rename the matrix
counts=m
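
# Optional sanity check of the resulting matrix (the exact numbers depend on the corpus):
dim(counts)        # number of documents x number of retained terms
counts[1:3, 1:5]   # word frequencies for the first few documents and terms
labels[1:3]        # the corresponding newsgroup labels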

Implementation of Hard-EM and Soft-EM

In [3]:
# log-sum-exp trick: computes log(sum(exp(v))) while preventing numerical overflow/underflow
logSum <- function(v) {
   m = max(v)
   return ( m + log(sum(exp(v-m))))
}
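A quick illustrative check of why this matters: with very negative log-probabilities the naive computation underflows to zero, while logSum stays finite (the input values here are arbitrary).

v <- c(-1000, -1001, -1002)
log(sum(exp(v)))   # -Inf, because exp() underflows to 0
logSum(v)          # about -999.59, computed stably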
In [4]:
soft_EM = function(counts, K, eta.max){
    
    # Initialisations and parameter settings
    N = nrow(counts)                            # number of documents
    W = ncol(counts)                            # number of unique terms
    terminate = FALSE
    eta = 1                                     # iteration counter
    rho <- matrix(1/K,nrow = K, ncol=1)         # mixing coefficients, initialised uniformly
    mu <- matrix(runif(K*W),nrow = K, ncol = W) # word distribution of each cluster, initialised randomly
    mu <- prop.table(mu, margin = 1)            # normalise each row to sum to 1
    colnames(mu) = colnames(counts)
    rownames(mu) = 1:K
    post = matrix(0, nrow = N, ncol = K)        # posterior (responsibility) of each cluster for each document
    

    
    while (!terminate){
        # Expectation
        for (n in 1:N){
            for (k in 1:K){
                ## calculate the numerator of the posterior (in log space)
                post[n,k]= log(rho[k])+sum(counts[n,]*log(mu[k,]))
            }
            # Normalise: subtract the log of the denominator (computed stably with logSum)
            logZ = logSum(post[n,])
            post[n,] = post[n,] - logZ
        }
        # Get rid of log
        post= exp(post)
        
        # Maximisation
        rho = colSums(post)/N
        for (k in 1:K){
            doc_weighted=0
            for (n in 1:N){
                doc_weighted = doc_weighted + post[n,k]*counts[n,]
            }
            mu[k,] = doc_weighted
        }
        #normalise every row
        mu=mu+1e-10
        mu=mu/rowSums(mu)
    
        # increase counter
        eta = eta +1
        # Check termination criteria 
        terminate <- (eta > eta.max )
    }

    return(list("post"=post, "rho"=rho, "mu"=mu))
    
}
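The hard_EM function in the next cell is identical except for one extra step: after the Expectation step, each document's posterior is collapsed to a one-hot assignment (1 for its most probable cluster, 0 elsewhere), so the Maximisation step uses hard cluster memberships instead of fractional responsibilities.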
In [5]:
hard_EM = function(counts, K, eta.max){
    
    # Initialisations and parameter settings
    N = nrow(counts)                            # number of documents
    W = ncol(counts)                            # number of unique terms
    terminate = FALSE
    eta = 1                                     # iteration counter
    rho <- matrix(1/K,nrow = K, ncol=1)         # mixing coefficients, initialised uniformly
    mu <- matrix(runif(K*W),nrow = K, ncol = W) # word distribution of each cluster, initialised randomly
    mu <- prop.table(mu, margin = 1)            # normalise each row to sum to 1
    colnames(mu) = colnames(counts)
    rownames(mu) = 1:K
    post = matrix(0, nrow = N, ncol = K)        # posterior (responsibility) of each cluster for each document
    

    while (!terminate){
        # Expectation
        for (n in 1:N){
            for (k in 1:K){
                post[n,k]= log(rho[k])+sum(counts[n,]*log(mu[k,]))
            }
            logZ = logSum(post[n,])
            post[n,] = post[n,] - logZ
        }
        post= exp(post)
        
        # Hard EM assignments
        max.prob <- post==apply(post, 1, max)
        post[max.prob] = 1
        post[!max.prob] = 0
        
        # Maximisation
        rho = colSums(post)/N
        for (k in 1:K){
            doc_weighted=0
            for (n in 1:N){
                doc_weighted = doc_weighted + post[n,k]*counts[n,]
            }
            mu[k,] = doc_weighted
        }
        #normalise every row
        mu=mu+1e-10
        mu=mu/rowSums(mu)
    
        # increase counter
        eta = eta +1
        # Check termination criteria
        terminate <- (eta > eta.max )
    }

    return(list("post"=post, "rho"=rho, "mu"=mu))
    
}
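Both functions stop after a fixed number of iterations (eta.max) rather than testing for convergence. One possible diagnostic, sketched below as an illustration rather than as part of the functions above, is the observed-data log-likelihood computed with the same logSum trick; for the soft version, successive EM iterations should never decrease it, so one could stop once the change falls below a small tolerance.

log_likelihood <- function(counts, rho, mu) {
    N <- nrow(counts)
    K <- length(rho)
    ll <- 0
    for (n in 1:N) {
        # log of the mixture density of document n, summed over clusters in log space
        terms <- sapply(1:K, function(k) log(rho[k]) + sum(counts[n,] * log(mu[k,])))
        ll <- ll + logSum(terms)
    }
    return(ll)
}

# e.g. after running the next cell: log_likelihood(counts, soft$rho, soft$mu)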

Apply EM Algorithm - PCA - Visualise

In [6]:
# Run
soft = soft_EM(counts, 4, 100)
hard = hard_EM(counts, 4, 100)

# Take predictions
soft_pred=soft$post
hard_pred=hard$post

# Convert the final soft EM posterior probabilities into hard (one-hot) cluster assignments
max.prob <- soft_pred==apply(soft_pred, 1, max)
soft_pred[max.prob] = 1
soft_pred[!max.prob] = 0

# To store predictions
soft_predictions=rep(0,nrow(soft_pred))
hard_predictions=rep(0,nrow(hard_pred))

# Take the columns with 1 and assign group predictions
soft_predictions[which(soft_pred[,1]==1)]=1
soft_predictions[which(soft_pred[,2]==1)]=2
soft_predictions[which(soft_pred[,3]==1)]=3
soft_predictions[which(soft_pred[,4]==1)]=4

hard_predictions[which(hard_pred[,1]==1)]=1
hard_predictions[which(hard_pred[,2]==1)]=2
hard_predictions[which(hard_pred[,3]==1)]=3
hard_predictions[which(hard_pred[,4]==1)]=4

# Map the newsgroup label strings to the corresponding cluster numbers
labels[which(labels=='sci.crypt')]=1
labels[which(labels=='sci.electronics')]=2
labels[which(labels=='sci.med')]=3
labels[which(labels=='sci.space')]=4

# Visualise
p.comp <- prcomp(counts) 
plot(p.comp$x, col=adjustcolor(soft_predictions, alpha=0.5), pch=16,  main='Soft EM Clusters')
plot(p.comp$x, col=adjustcolor(hard_predictions, alpha=0.5), pch=16,  main='Hard EM Clusters')
plot(p.comp$x, col=adjustcolor(factor(labels), alpha=0.5), pch=16,  main='True Labels')
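
Because the true newsgroup labels are available, a simple additional check (illustrative, beyond the plots above) is a contingency table between labels and cluster assignments; since cluster numbers are arbitrary, good agreement shows up as one dominant count per row rather than a clean diagonal.

table(true = labels, soft = soft_predictions)
table(true = labels, hard = hard_predictions)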