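function dy_dlogX = convertsymgrad(dy_dX,X)
% function dy_dlogX = convertsymgrad(dy_dX,X)
%
% Convert a gradient of a scalar with respect to an unconstrained square matrix
% X into a gradient with respect to log(X), and add the constraint that X and
% log(X) are symmetric. log(X) is the matrix logarithm (logm(X) in Matlab). We
% do not assume that dy_dX and X commute (if they did, the answer would just be
% (dy_dX)'*X). Unfortunately this uses eig, so the cost is O(D^3).
%
% X should be symmetric positive definite. The closed-form terms below divide
% by the eigenvalues of log(X) and by their pairwise differences, so they give
% Inf or NaN if X has an eigenvalue equal to 1 or repeated eigenvalues.
%
% For the symmetric stuff and more on matrix exponentials in general, see the
% section on the symmetry operator in:
%   Najfeld I, Havel TF
%   Derivatives of the Matrix Exponential and Their Computation
%   Advances in Applied Mathematics 16, 321-375 (1995).
%
% Iain Murray 27 May 2004. Many thanks to Peter Latham.
% This was put together in test.m; see that file for more on where this came from.
% See also convertgrad, which doesn't do the symmetrising on the last line.

D=length(X);
[E,L]=eig(X);          % X = E*L*E', with E orthonormal because X is symmetric
L=log(diag(L));        % eigenvalues of log(X), so log(X) = E*diag(L)*E'
[a,b]=meshgrid(L,L);   % a(p,q)=L(q), b(p,q)=L(p)

% T1(p,q) = sum_{i>=3} sum_{j=1}^{i-2} L(p)^j*L(q)^(i-1-j)/i!, the weights that
% make G1 an elementwise product in the eigen-basis. Off-diagonal entries use
% the closed form for distinct eigenvalues...
ix=find(~eye(D));
T1=zeros(D);
T1(ix)= a(ix)./(b(ix).*(b(ix)-a(ix))).*(exp(b(ix))-1-b(ix)-0.5*b(ix).^2)...
    +b(ix)./(a(ix).*(a(ix)-b(ix))).*(exp(a(ix))-1-a(ix)-0.5*a(ix).^2);
% ...and diagonal entries (p==q) use the limit sum_{i>=3} (i-2)*L(p)^(i-1)/i!.
ix=find(eye(D));
T1(ix)=(exp(L)-1-L)-2*(exp(L)-1-L-L.^2/2)./L;
G1=E*((E'*(dy_dX'*E)).*T1)*E';

% T2(p) = sum_{i>=2} L(p)^(i-1)/i!
T2=(exp(L)-1-L)./L; % used for both G2 and G3
G2=E*((E'*dy_dX').*repmat(T2,1,D));       % = E*diag(T2)*E'*dy_dX'
G3=(dy_dX'*E)*(E.*repmat(T2',D,1))';      % = dy_dX'*E*diag(T2)*E'

% Combine the series terms with the leading term dy_dX'.
dy_dlogX=G1'+G2'+G3'+dy_dX';
% Symmetrise: each off-diagonal element of log(X) appears twice, so its
% gradient is the sum of the (i,j) and (j,i) entries.
dy_dlogX=dy_dlogX+dy_dlogX'-diag(diag(dy_dlogX));

% Some numerical checks
%----------------------
%
% This also serves to explain what I'm doing.
%
% In all of this lX=logm(X) and A=dy_dX.
%
% The gradient (before the final symmetrising line) is naively an infinite
% sum; the first 200 terms are:
%   dy_dlogX=0;
%   for i=1:200
%       for j=0:i
%           dy_dlogX=dy_dlogX+lX^j*A'*lX^(i-j)/factorial(i+1);
%       end
%   end
%   dy_dlogX=dy_dlogX'+A'
% The various quantities I compute are bits of this sum, split into parts that
% each become a simple elementwise expression after transforming lX into its
% eigen-basis.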
%   G1=0;
%   for i=3:50
%       for j=1:(i-2)
%           G1=G1+lX^j*A'*lX^(i-1-j)/factorial(i);
%       end
%   end
%   G1
%
%   G2=0;
%   for i=2:50
%       G2=G2+lX^(i-1)*A'/factorial(i);
%   end
%   G2
%
%   G3=0;
%   for i=2:100
%       G3=G3+A'*lX^(i-1)/factorial(i);
%   end
%   G3
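%
% A finite-difference check of the final symmetrised gradient. This is a
% sketch, not taken from test.m: the test objective y=sum(sum(A.*X)) (so that
% dy_dX=A) and the names S, g, h, P and gfd are hypothetical, chosen for
% illustration. Perturbing lX(i,j) and lX(j,i) together matches the symmetric
% parameterisation assumed by the last line of the function.
%   D=4; A=randn(D); S=randn(D); lX=(S+S')/2;
%   X=expm(lX); X=(X+X')/2;   % make X exactly symmetric so eig returns orthonormal E
%   g=convertsymgrad(A,X);
%   h=1e-6; gfd=zeros(D);
%   for i=1:D
%       for j=i:D
%           P=zeros(D); P(i,j)=1; P(j,i)=1;   % symmetric perturbation direction
%           gfd(i,j)=(sum(sum(A.*expm(lX+h*P)))-sum(sum(A.*expm(lX-h*P))))/(2*h);
%           gfd(j,i)=gfd(i,j);
%       end
%   end
%   max(abs(g(:)-gfd(:)))   % should be small relative to max(abs(g(:)))
%
% The commuting case from the header can be checked the same way (B, gB and R
% are also hypothetical names): B=X^2 commutes with X, so the result should
% match (dy_dX)'*X symmetrised as on the last line of the function.
%   B=X^2; gB=convertsymgrad(B,X);
%   R=B'*X; R=R+R'-diag(diag(R));
%   max(abs(gB(:)-R(:)))   % should be near zero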