import numpy as np


def m_step(X, ES, ESS):
    """Compute the M-step parameter updates from the E-step expectations.

    Usage: ``mu, sigma, pie = m_step(X, ES, ESS)``

    Parameters
    ----------
    X : ndarray, shape (N, D)
        Data matrix.
    ES : ndarray, shape (N, K)
        E_q[s] for every data point.
    ESS : ndarray, shape (K, K) or (N, K, K)
        Sum over data points of E_q[ss']. If the per-point expectations
        of shape (N, K, K) are given instead, they are summed over the
        first axis here.

    Returns
    -------
    mu : ndarray, shape (D, K)
        Matrix of means in p(y|{s_i}, mu, sigma).
    sigma : scalar
        Standard deviation in the same distribution.
    pie : ndarray, shape (1, K)
        Parameters specifying the generative distribution for s
        (column means of ``ES``).

    Raises
    ------
    TypeError
        If ``ES`` does not have N rows, or ``ESS`` is neither (K, K)
        nor (N, K, K).
    numpy.linalg.LinAlgError
        If ``ESS`` is singular.
    """
    N, D = X.shape
    if ES.shape[0] != N:
        raise TypeError('ES must have the same number of rows as X')
    K = ES.shape[1]
    if ESS.shape == (N, K, K):
        ESS = np.sum(ESS, axis=0)
    if ESS.shape != (K, K):
        raise TypeError('ESS must be square and have the same number of columns as ES')

    # ES' X is needed for both mu and sigma; compute it once. Shape (K, D).
    cross = np.dot(ES.T, X)

    # Solve ESS @ mu' = ES' X directly instead of forming inv(ESS):
    # fewer operations and better conditioned.
    mu = np.linalg.solve(ESS, cross).T

    # trace(A @ B) == sum(A * B.T), so none of the three traces below needs
    # the full matrix product (in particular the D x D matrix X' X).
    sq_norm = np.sum(X * X)                       # trace(X' X)
    quad = np.sum(np.dot(mu.T, mu) * ESS.T)       # trace(mu' mu ESS)
    lin = np.sum(cross * mu.T)                    # trace(ES' X mu)
    resid = sq_norm + quad - 2 * lin

    # For a (near-)perfect fit, rounding can push the residual slightly
    # below zero; clamp so sqrt returns 0 rather than NaN.
    sigma = np.sqrt(np.maximum(resid, 0.0) / (N * D))

    pie = np.mean(ES, axis=0, keepdims=True)
    return mu, sigma, pie