### We propose a model for a factored covariance matrix
#
# The model is as follows:
#  - we have observations X_1, \ldots, X_n which follow a zero-mean Gaussian with covariance \Sigma
#  - we model \Sigma = W G W^T + I for orthonormal W of rank k and dense G (k by k)
#  - we propose to learn W and G via score matching, with some constraints on W
#  - we introduce non-negativity and orthonormality constraints on W in order to ensure the
#    model is identifiable. We note that directly optimizing over the set of non-negative,
#    orthonormal matrices is too difficult (it depends very strongly on the initial choice
#    of W!). As a result, we employ augmented Lagrange multipliers and enforce orthonormality
#    only in the limit (non-negativity is enforced at each iteration).
#
# The function nonNegativeCovFactor_LagrangeMult implements the decomposition described above
# and returns a non-negative, orthonormal W together with a latent connectivity matrix, G,
# for each subject.

library(MASS) # for ginv(), used as a fallback in AupdateNonDiag

nonNegativeCovFactor_LagrangeMult = function(Shat, k=2, diagG=FALSE, lagParam=1,
                                             tol=.01, alphaArmijo=0.5, maxIter=1000){
  # decompose a covariance into non-negative, orthonormal W and diagonal or dense G (see above)
  #
  # Orthonormality of W is enforced using augmented Lagrangians as detailed here:
  # https://link.springer.com/content/pdf/10.1007%2Fs10915-013-9740-x.pdf
  # The parameter lagParam denotes the r in equation (6) on page 432.
  #
  # INPUT:
  #  - Shat: list of empirical covariance matrices across subjects, one entry per subject.
  #    IMPORTANT: even in the case of a single class/subject, Shat must be a list!
  #  - k: rank of the approximation (by default this is 2)
  #  - diagG: should we estimate a diagonal or full G matrix? Default is FALSE (i.e. estimate dense G)
  #  - lagParam: coefficient for the augmented Lagrangian
  #  - tol: convergence tolerance on successive W iterates
  #  - alphaArmijo: initial stepsize for the Armijo backtracking rule in the W updates
  #  - maxIter: max number of iterations
  #
  # OUTPUT:
  #  - W: non-negative, orthonormal loading matrix (after projecting each row onto at most
  #    one non-zero entry)
  #  - Wraw: the loading matrix before this final projection
  #  - G: list of latent connectivities for each class/subject
  #  - iter: number of iterations
  #  - logLik: log-likelihood over the *training data*, one entry per subject

  # define some parameters
  p = ncol(Shat[[1]])
  nSub = length(Shat)
  ShatMean = Reduce('+', Shat) / nSub
  LagMult = matrix(0, ncol=k, nrow=k) # Lagrange multiplier matrix, will enforce orthonormality on W

  # initialize W to the leading eigenvectors of ShatMean:
  eigenS = eigen(ShatMean)
  W = eigenS$vectors[, 1:k]
  # check if we should flip the signs on each column of W:
  W = apply(W, 2, FUN=function(x){
    if (sum(x) < 0){
      return(-x)
    } else {
      return(x)
    }
  })
  Wold = W # stored to check convergence
  # projecting onto "at most 1 non-zero per row" here is too harsh - and we get that for
  # free from non-negativity + orthogonality!
  W = ProjectNonNegative(W)
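  # A sketch of the objective (our reading of the updates below, stated here for
  # reference): writing A = G (G + I)^{-1}, orthonormality of W gives the precision
  # matrix P = \Sigma^{-1} = I - W A W^T, and the Gaussian score-matching objective
  # 0.5 * tr(P^2 Shat) - tr(P) reduces, up to additive constants, to
  #     J(W, A) = tr( (W^T Shat W) (0.5 A^2 - A) ) + tr(A).
  # For fixed W, this is minimized in closed form by A = I - (W^T Shat W)^{-1}, which is
  # exactly what AupdateDiag / AupdateNonDiag compute; Atilde = 0.5 A^2 - A below is the
  # corresponding weight matrix appearing in the W gradient.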
  # initialize A ( = G (G + I)^{-1} ):
  if (diagG){
    A = lapply(Shat, FUN=function(x){ AupdateDiag(W, x) })
  } else {
    A = lapply(Shat, FUN=function(x){ AupdateNonDiag(W, x) })
  }

  cArmijo = 0.01
  for (iter in 1:maxIter){
    # update W:
    AtildeAll = lapply(A, FUN=function(x){ 0.5 * x %*% x - x })
    Atilde = Reduce('+', AtildeAll) / nSub
    Wgrad = matrix(0, ncol=k, nrow=p)
    for (s in 1:nSub){
      Wgrad = Wgrad + (Shat[[s]] %*% W %*% AtildeAll[[s]]) / nSub
    }
    # add the augmented Lagrangian and multiplier terms for the orthonormality constraint:
    Wgrad = Wgrad + lagParam * (W %*% t(W) %*% W - W) + W %*% LagMult
    # apply the Armijo rule update:
    W = armijoUpdateW_MultiSubject_penalized(W=W, Wgrad=Wgrad, Gtilde=AtildeAll, Shat=Shat,
                                             alpha=alphaArmijo, nonNeg=TRUE, c=cArmijo,
                                             useStiefel=FALSE)
    W = ProjectNonNegative(W) # to ensure non-negativity
    W = normalizeColumns(W)

    # update A:
    if (diagG){
      A = lapply(Shat, FUN=function(x){ AupdateDiag(normalizeColumns(ProjectMax1(W)), x) })
    } else {
      A = lapply(Shat, FUN=function(x){ AupdateNonDiag(normalizeColumns(ProjectMax1(W)), x) })
    }

    # update the Lagrange multiplier:
    LagMult = LagMult + lagParam * (t(W) %*% W - diag(k))

    if (sum(abs(W - Wold)) < tol){
      break
    } else {
      #cat('Error: ', sum(abs(W-Wold)), '\n')
      Wold = W
    }
  }

  Wmax1 = normalizeColumns(ProjectMax1(W)) # apply the "at most 1 non-zero per row" constraint here
  # and recover G = W^T Shat W - I:
  if (diagG){
    G = lapply(Shat, FUN=function(x){ diag(diag(t(Wmax1) %*% x %*% Wmax1)) - diag(k) })
  } else {
    G = lapply(Shat, FUN=function(x){ t(Wmax1) %*% x %*% Wmax1 - diag(k) })
  }

  # finally, we compute the data log-likelihood (up to additive constants), using the
  # identity (I + W G W^T)^{-1} = I - W G (G + I)^{-1} W^T for orthonormal W:
  PresHat = vector("list", nSub)
  logLik = rep(0, nSub)
  for (i in 1:nSub){
    PresHat[[i]] = diag(p) - Wmax1 %*% G[[i]] %*% solve(G[[i]] + diag(k)) %*% t(Wmax1)
    logLik[i] = 0.5 * log(det(PresHat[[i]])) - 0.5 * sum(diag(Shat[[i]] %*% PresHat[[i]]))
  }

  return(list('G'=G, 'W'=Wmax1, 'Wraw'=W, 'iter'=iter, 'logLik'=logLik))
}

## helper functions

ProjectNonNegative = function(W){
  # project a matrix onto the non-negative orthant (elementwise)
  Wnew = pmax(W, 0)
  return(Wnew)
}

ProjectMax1 = function(W){
  # project onto the set of matrices with at most 1 non-zero entry per row
  # (the same constraint as used in MCF)
  Wnew = apply(W, 1, FUN=function(x){
    ii = which.max(x)
    x[-ii] = 0
    x[ii] = max(0, x[ii])
    return(x)
  })
  return(t(Wnew))
}

normalizeColumns = function(W){
  # normalize the columns of W to (approximately) unit norm
  Wnew = apply(W, 2, FUN=function(x){
    if (sum(abs(x)) == 0){
      return(x + .001)
    } else {
      return(x / sqrt(.0001 + sum(x^2)))
    }
  })
  return(Wnew)
}

AupdateDiag = function(W, Shat){
  # closed-form update for a diagonal A matrix
  targets = diag(t(W) %*% Shat %*% W)
  Anew = diag(1 - 1/targets, nrow=length(targets))
  return(Anew)
}

AupdateNonDiag = function(W, Shat){
  # closed-form update for A without assuming diagonality - i.e. for the full G/A matrix case!
  k = ncol(W)
  invMat = try(solve(t(W) %*% Shat %*% W), silent=TRUE)
  if (inherits(invMat, "try-error")){
    invMat = ginv(t(W) %*% Shat %*% W) # pseudo-inverse fallback (MASS)
  }
  Anew = diag(k) - invMat
  # shift the spectrum if necessary so that A remains positive definite:
  evalues = eigen(Anew)$values
  if (min(evalues) <= 0){
    Anew = Anew + diag(k) * (abs(min(evalues)) + .001)
  }
  return(Anew)
}
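# A small consistency check (illustrative; checkAupdate is our hypothetical helper, not
# part of the method): the closed-form update A = I - (W^T S W)^{-1} returned by
# AupdateNonDiag implies G = A (I - A)^{-1} = W^T S W - I, i.e. the same formula used to
# recover G after convergence above.
checkAupdate = function(W, Shat){
  A = AupdateNonDiag(W, Shat)
  G1 = A %*% solve(diag(ncol(W)) - A)       # G implied by A = G (G + I)^{-1}
  G2 = t(W) %*% Shat %*% W - diag(ncol(W))  # direct recovery formula used above
  # should be ~0, provided the eigenvalue shift inside AupdateNonDiag did not fire:
  max(abs(G1 - G2))
}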
armijoUpdateW_MultiSubject_penalized = function(W, Wgrad, Gtilde, Shat, alpha=.5, c=0.001,
                                                nonNeg=TRUE, useStiefel=TRUE, maxIter=1000){
  # update W according to Armijo backtracking
  #
  # Gtilde is the matrix 0.5 * A %*% A - A, where A = G (G + I)^{-1}
  # we only project onto the non-negative orthant here, not onto the set of orthogonal,
  # non-negative matrices, as that is too restrictive!
  # note that both Gtilde and Shat are lists (one entry per subject)!

  nSub = length(Shat)
  stopBackTrack = FALSE
  # normalize the gradient columns to stabilize the stepsize:
  Wgrad = apply(Wgrad, 2, FUN=function(x){ x / (sqrt(sum(x^2)) + .001) })
  iterCount = 0
  while (stopBackTrack == FALSE){
    if (useStiefel){
      # we proceed along the Stiefel manifold: Wgrad - W %*% t(Wgrad) %*% W is the
      # Riemannian gradient under the canonical metric on {W : W^T W = I}
      Wnew = W - alpha * (Wgrad - W %*% t(Wgrad) %*% W)
    } else {
      # we just take a gradient descent step:
      Wnew = W - alpha * Wgrad
    }
    #Wnew = ProjectMax1( Wnew )
    Wnew = ProjectNonNegative(Wnew)
    Wnew = normalizeColumns(Wnew)

    # evaluate the objective at the current and the proposed W:
    currObj = 0
    newObj = 0
    for (i in 1:nSub){
      currObj = currObj + sum(diag(t(W) %*% Shat[[i]] %*% W %*% Gtilde[[i]]))       # this is f(W)
      newObj = newObj + sum(diag(t(Wnew) %*% Shat[[i]] %*% Wnew %*% Gtilde[[i]]))   # this is f(Wnew)
    }
    # Armijo sufficient-decrease condition (with a small slack term):
    if (newObj <= currObj + c * alpha * sum(diag(t(Wgrad) %*% (Wnew - W))) + 0.001){
      stopBackTrack = TRUE
    } else {
      alpha = alpha / 2
      iterCount = iterCount + 1
      if (iterCount > maxIter){
        stopBackTrack = TRUE
      }
    }
  }
  return(Wnew)
}
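
# Example usage: a minimal sketch on simulated data. The settings below (p, k, n, the
# block structure of Wtrue and the diagonal Gtrue matrices) are illustrative choices,
# not part of the method, and exampleRun is our hypothetical helper.
exampleRun = function(){
  set.seed(1)
  p = 10; k = 2; n = 5000
  # ground-truth loadings: non-negative, orthonormal, one non-zero entry per row
  Wtrue = matrix(0, nrow=p, ncol=k)
  Wtrue[1:5, 1] = 1 / sqrt(5)
  Wtrue[6:10, 2] = 1 / sqrt(5)
  # subject-specific latent connectivities:
  Gtrue = list(diag(c(4, 2)), diag(c(2, 4)))
  # empirical covariance for each subject, sampled from \Sigma_s = W G_s W^T + I:
  Shat = lapply(Gtrue, FUN=function(g){
    Sigma = Wtrue %*% g %*% t(Wtrue) + diag(p)
    X = matrix(rnorm(n * p), ncol=p) %*% chol(Sigma)
    cov(X)
  })
  fit = nonNegativeCovFactor_LagrangeMult(Shat, k=2, diagG=TRUE)
  print(round(fit$W, 2)) # should recover the block structure of Wtrue (up to column order)
  print(fit$logLik)
  invisible(fit)
}
# exampleRun()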