% Fit an exponentially decaying sinusoid to a noisy signal using basic techniques
% start from a clean session: empty console, empty figure, empty workspace
clc
clf
clear



%% generate input signal
% (no example signal was provided, so a test signal is synthesized here)
fs = 1000; % sampling rate; also used as the number of samples per segment

% parameters of the test signal
toneFreq  = 6;            % frequency of sinusoid (cycles over the 0..2*pi timing vector)
gains     = [3000 80];    % peak amplitudes: [sinusoid, noise]
decayRate = -0.8;         % exponential coefficient
phase0    = rand(1)*2*pi; % random start phase

t = linspace(0,2*pi,fs); % timing vector

% decaying sinusoid = sinusoid multiplied by exponential envelope
decay = exp(decayRate*t);
tone  = sin(toneFreq*t + phase0) .* decay;

% Gaussian noise of the same length
hiss = randn(1,fs);

% scale both parts to their target peak amplitudes
tone = tone/max(abs(tone)) * gains(1);
hiss = hiss/max(abs(hiss)) * gains(2);

% noisy signal framed by the same noise segment before and after it
in = [hiss, tone + hiss, hiss];

% drop the generation parameters so the modelling part cannot peek at them
clear('toneFreq','gains','decayRate','phase0','tone','hiss','t')
%% model the input signal with exponential decaying sinusoid
% Step 1: find where the signal starts and ends with a simple level threshold.
alpha = exp(-1/(0.125*fs)); % coefficient of the recursive level estimate
level = zeros(size(in)); % preallocate, same shape as the input

% Recursive level estimate, starting from a previous level of 0:
%   level(n) = alpha*in(n) + (1-alpha)*level(n-1)
% NOTE(review): alpha is close to 1 and is applied to the *input*, so the
% estimate follows the input almost unsmoothed; a conventional 1st order
% low-pass would apply alpha to the previous level instead. The amplitude
% estimate further down reuses maxLevel_dB, so confirm before changing this.
prevLevel = 0;
for n = 1:numel(level)
    level(n) = alpha*in(n) + (1-alpha)*prevLevel;
    prevLevel = level(n);
end

level_dB = 20*log10(abs(level)); % level in decibel
maxLevel_dB = max(level_dB); % peak level

% every sample less than 30 dB below the peak counts as "signal"
% (in is a row vector, so linear indices equal column indices)
idx = find((maxLevel_dB-level_dB) < 30);
sig_start = idx(1); % first sample above the threshold
sig_end = idx(end); % last sample above the threshold
sig_length = sig_end - sig_start+1; % number of samples from start to end, inclusive

% estimate fundamental frequency using autocorrelation
% (basic implementation, can be greatly improved by interpolation)
freqsRange = [20 1]; % range of possible frequencies [max min] [Hz]
freqsLags = round(fs./freqsRange); % Hz to lag in samples: [shortest longest]
rxx = xcorr(in, 'coeff'); % normalized autocorrelation, lags -(N-1)..(N-1)
rxx = rxx(length(in):end); % keep non-negative lags only: rxx(k) holds lag k-1

% Candidate lags inside the allowed frequency range. The upper bound is
% clamped so that lag+1 never indexes past the end of rxx.
lagCandidates = freqsLags(1):min(freqsLags(2), length(in)-1);

% rxx is 1-based while lags are 0-based, hence the +1 when indexing.
% (The previous version searched rxx(freqsLags(1):freqsLags(2)) and used
% idx+freqsLags(1) as the lag, which is two samples too long and biased the
% frequency estimate low.)
[notUsed, idx] = max(rxx(lagCandidates + 1)); % strongest periodicity in range
freq_estimate = fs/lagCandidates(idx); % period in samples -> frequency in Hz

% synthesize a zero-phase reference sinusoid at the estimated frequency,
% covering exactly the detected signal length
x = linspace(0,2*pi,fs); % timing vector
out = sin(freq_estimate * x(1:sig_length));

% estimate phase shift
% Build a peak filter around the estimated fundamental: a narrow band-stop
% removes that component, and subtracting the band-stop output from the input
% leaves it (this suppresses the noise for a better phase estimate).
nyquist = 0.5*fs;
stopBand = [freq_estimate-0.1  freq_estimate+0.1]/nyquist; % normalized band edges
[z,p,k] = butter(12, stopBand, 'stop'); % butterworth design (zero/pole/gain)
[sos,g] = zp2sos(z,p,k); % second order sections
h = dfilt.df2sos(sos,g); % filter object
notched = filter(h, in); % input with the fundamental removed
in_filtered = in(sig_start:sig_end) - notched(sig_start:sig_end); % fundamental only, signal part

% cross-correlate the reference sinusoid with the filtered signal
rxy = xcorr(out,in_filtered, 'coeff');
[notused, idx] = max(rxy); % index of the best alignment
% Both inputs have sig_length samples, so zero lag sits at index sig_length
% of rxy; an in-phase reference would peak exactly there.
phase_shift = idx - sig_length; % phase shift in samples
phase = (freq_estimate/fs*phase_shift)*2*pi; % phase shift in radians

% resynthesize the sinusoid with the estimated phase
% (so model and input agree in both frequency and phase)
x = linspace(0,2*pi,fs); % timing vector
peakAmplitude = 10^(maxLevel_dB/20); % peak level converted from dB back to linear
out = sin(freq_estimate*x(1:sig_length) + phase);
out = out/max(abs(out)) * peakAmplitude; % normalize to 1, then scale to the peak level

% fit the exponential decay
x = linspace(0,2*pi,fs); % timing vector
t = x(1:sig_length); % time axis of the detected signal part
target = in(sig_start:sig_end); % input samples the model has to match

% sum of squared errors between the input and the sinusoid damped by exp(lambda*t)
errorfunc = @(lambda) sum((target - out.*exp(lambda*t)).^2);

% search for the lambda with the smallest squared error, starting from 0 (no decay)
options = optimset('TolX', 1.e-4, 'TolFun', 1.e-4 , 'MaxFunEvals', 1500, 'MaxIter', 500);
lambda_out = fminsearch(errorfunc,0, options);

% apply the fitted decay to the synthesized sinusoid
decay_out = exp(lambda_out * t);
out = out .* decay_out;

% overlay the detected part of the input and the fitted model
hold all
plot(in(sig_start:sig_end), 'linewidth',2) % measured signal
plot(out, 'linewidth',2) % synthesized model
xlim([0 850])
ylim([-3500 3500])