function [D,Z] = my_ksvd(X,initdict,epsilon,smax,maxiter)
%%%%%%%%%%%%%%%%%%%%% K-SVD dictionary learning %%%%%%%%%%%%%%%%%%%%%%%%%%%
% Runs the K-SVD dictionary learning algorithm on the specified set of
% signals X, returning the trained dictionary D and the sparse coding Z
% such that X ~ D*Z.
%
% More precisely, we intend to solve:
%
% min  |X - D*Z|_F^2  s.t.  for all i : |Z_i|_0 <= smax
% D,Z                 or    for all i : |X_i - D*Z_i|_2 <= epsilon
%
% where X is the set of training signals, X_i and Z_i are the i-th columns
% of X and Z, smax is the target sparsity and epsilon is the target error.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%% INPUT %%%%%%
% - X (N*T) : Set of T N-dimensional training signals
% - initdict (N*K) : Initial N-dimensional dictionary with K atoms
% - epsilon (double) : Target reconstruction error per signal
% - smax (int) : Maximum number of nonzero elements in each column of Z
% - maxiter (int) : Maximum number of iterations
%
%%%%%% OUTPUT %%%%%
% - D (N*K) : Output N-dimensional dictionary with K atoms
% - Z (K*T) : Output sparse coding
%
%%%%%% TP Telecom Strasbourg 2021
%%%%%% Antoine Deleforge (antoine.deleforge@inria.fr)
%%%%%% Code inspired from the implementation of Ron Rubinstein (2009)
%%%%% References:
% [1] M. Aharon, M. Elad, and A.M. Bruckstein, "K-SVD: An Algorithm
% for Designing Overcomplete Dictionaries for Sparse
% Representation", IEEE Trans. on Signal Processing, Vol. 54, no.
% 11, pp. 4311-4322, November 2006.
% [2] M. Elad, R. Rubinstein, and M. Zibulevsky, "Efficient Implementation
% of the K-SVD Algorithm using Batch Orthogonal Matching Pursuit",
% Technical Report - CS, Technion, April 2008.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%% Initialize variables %%%%%%%%%%%%%%%%%
% Initialize D with initdict and normalize its columns to unit norm
% Tip: use normcols
D = initdict;
D = normcols(D); % normalize the dictionary
% Initialize Z by a zero matrix of size K x T
K = size(D,2);
T = size(X,2);
Z = zeros(K,T);
err = zeros(1,maxiter+1);
err(1)=1e11;
converged = false;
iter=1;
%%%%%%%%%%%%%%%%% main loop %%%%%%%%%%%%%%%%%
while(~converged)
%%%%% sparse coding %%%%%
% For t = 1...T, sparsely code the signal X(:,t) using dictionary D and
% place the code in Z(:,t)
for t=1:T
Z(:,t) = sparse_coding(D, X(:,t),epsilon, smax);
end
%%%%% dictionary update %%%%%
% For k = 1...K, update the atom d_k = D(:,k)
used_sig = false(1,size(X,2)); % Track signals which have been used
% for replacement
for k = 1:K
d_k = D(:,k);
% Compute idx, the set of signal indices in {1,...,T} that use
% atom d_k
idx = find(Z(k,:));
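% Case of unused atom (idx is empty)
% - Easy option : ignore it
% => Advanced option (used here) : replace d_k by the normalized signal,
% among those not yet used for a replacement, that is currently the
% most poorly reconstructed
% NOTE(review): i indexes the columns of Xunused, not of X, so once a
% signal has been marked in used_sig, used_sig(i) and the printed index
% no longer refer to the intended column of X.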
if isempty(idx)
Xunused = X(:,~used_sig);
E = sum(abs(Xunused- D*Z(:,~used_sig)).^2);
[~,i] = max(E);
d_k = Xunused(:,i);
d_k = d_k./norm(d_k);
z_k = 0;
used_sig(i)=true;
fprintf(1,'\n Replacing Atom %d by data point %d\n',k,i);
else
% Compute Zbis and Xbis, the submatrices of Z and X corresponding
% to the signals of indices idx
Zbis = Z(:,idx);
Xbis = X(:,idx);
% Extract z_k, the row vector containing the nonzero weights
% that these signals give to atom d_k
z_k = Z(k,idx);
% Compute R, the current residual when reconstructing Xbis
% *without* using d_k and z_k.
R = Xbis - D*Zbis + d_k*z_k;
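% Update z_k and d_k such that d_k*z_k is the best rank-1
% approximation of R. Tip: use "svds"
% NOTE(review): svds gives R ~ U*S*V', so the optimal row is s*V'; the
% code stores s*V, which coincides with s*V' only for real-valued data.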
[d_k,s,z_k] = svds(R, 1);
z_k = s*z_k;
end
Z(k,idx) = z_k;
D(:,k) = d_k;
end
%%%%% Compute ||X - D*Z||_F^2 / ||X||_F^2, the relative squared
%%%%% reconstruction error on this iteration (printed as "RMSE")
err(iter+1) = sum(sum(abs(X - D*Z).^2))./sum(sum(abs(X).^2));
fprintf(1,'Iteration %d / %d complete, RMSE = %.4g\n',...
iter,maxiter,err(iter+1));
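%%%%% Stopping criterion: maxiter iterations reached, or early stopping
%%%%% when the relative change of the error between two consecutive
%%%%% iterations is small (compared with TOLERANCE)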
TOLERANCE = 1e-03;
converged = (iter>=maxiter) | ...
(abs(err(iter+1)-err(iter))/err(iter)