% Start from a clean command window and workspace and print numbers with
% full precision.
clc
clearvars
format long

%%%%%%%%%%%%%%%%%%%%%%
%% constant factors %%
%%%%%%%%%%%%%%%%%%%%%%
DEG2RAD = pi/180;   % multiply an angle in degrees by this to get radians
RAD2DEG = 180/pi;   % multiply an angle in radians by this to get degrees
dt      = 1e-3;     % sampling time (step of the simulation time vector)

%%%%%%%%%%%%%%%%%%%%%%%%
%% EKF initialization %%
%%%%%%%%%%%%%%%%%%%%%%%%

% model x_k = f(x_k-1, u, w)
% xprio: state prediction, before correction
% xpost: state correction, after prediction
% x(1 - 4): Quaternion current attitude, q.rot, q.x, q.y, q.z
% x(5 - 7): gx, gy, gz (carried over unchanged by f)
% u(1 - 4): Quaternion delta rotation, q.rot, q.x, q.y, q.z
%
% x(1 - 4) * u(1 - 4), quaternion multiplication :
%       http://de.mathworks.com/help/aeroblks/quaternionmultiplication.html
%
%             q.rot = (q1.rot*q2.rot - q1.x*q2.x - q1.y*q2.y - q1.z*q2.z);
%             q.x   = (q1.rot*q2.x + q1.x*q2.rot + q1.y*q2.z - q1.z*q2.y);
%             q.y   = (q1.rot*q2.y - q1.x*q2.z + q1.y*q2.rot + q1.z*q2.x);
%             q.z   = (q1.rot*q2.z + q1.x*q2.y - q1.y*q2.x + q1.z*q2.rot);

% State transition: the quaternion x(1:4) is multiplied on the right by the
% delta-rotation quaternion u (formula above, q1 = x, q2 = u); x(5:7) pass
% through unchanged. Returns a 7x1 column.
f = @(x,u) [ x(1)*u(1) - x(2)*u(2) - x(3)*u(3) - x(4)*u(4) ; % qrot
             x(1)*u(2) + x(2)*u(1) + x(3)*u(4) - x(4)*u(3) ; % qx
             x(1)*u(3) - x(2)*u(4) + x(3)*u(1) + x(4)*u(2) ; % qy
             x(1)*u(4) + x(2)*u(3) - x(3)*u(2) + x(4)*u(1) ; % qz
             x(5)     % gx
             x(6)     % gy
             x(7) ];  % gz

% measurement model
% h(x): three angles computed from the quaternion x(1:4); x(5:7) do not
% enter. For a unit quaternion |2*(x2*x3 + x4*x1)| <= 1 mathematically, but
% rounding can push it slightly past 1, where asin would return a complex
% number, so the argument is clamped to [-1, 1].
h = @(x) [ atan2(2*x(2)*x(1)-2*x(3)*x(4) , 1 - 2*x(2)*x(2) - 2*x(4)*x(4) ) ;          % delta gx
           asin( min(1, max(-1, 2*(x(2)*x(3) + x(4)*x(1)) )) ) ;                       % delta gy
           atan2(2*x(3)*x(1)-2*x(2)*x(4) , 1 - 2*x(3)*x(3) - 2*x(4)*x(4) ) ];          % delta gz

states = 7;         % number of states in the state vector x
                    % (i.e. number of rows returned by f)

Q = eye(states)*.00001; % process (system) noise covariance
R = eye(3)*.001;        % measurement noise covariance; 3x3 because h
                        % returns three measured quantities

% prio: prediction, before correction
% post: correction, after prediction

Pprio = eye(states); % state covariance P, prediction
Ppost = eye(states); % state covariance P, after correction

I = eye(states);     % identity matrix

% initial states
% x(1:4) starts as the quaternion [1 0 0 0] (rot = 1, no rotation) and
% x(5:7) start at zero; prior and posterior begin with the same values.
xprio = [1; zeros(6,1)];
xpost = xprio;

% Quaternion objects used by the filter loop
qCurrent = Quaternion(1, 0, 0, 0); % current attitude
qDelta   = Quaternion(1, 0, 0, 0); % attitude change per sample

%%%%%%%%%%%%%%%%
%% simulation %%
%%%%%%%%%%%%%%%%

t = 0:dt:1;         % time vector for the simulation
N = length(t);      % number of simulation samples

% angular rates, simulation data (constant over the whole run)
gx = 45*DEG2RAD*ones(1,N);  % [rad/s]  roll
gy = zeros(1,N);            % [rad/s]  pitch
gz = zeros(1,N);            % [rad/s]  yaw

% noise, simulation data
% NOTE(review): the filter uses Q and R as covariances (Pprio and S
% updates), but here they scale randn directly, so the generated noise has
% standard deviation Q(i,i) / R(i,i) rather than sqrt(Q(i,i)) / sqrt(R(i,i)).
% If cov(w) = Q and cov(v) = R is intended, scale by chol(Q,'lower') and
% chol(R,'lower') instead. Also, v is added to the rates before z is
% multiplied by dt, so the effective noise on z is v*dt -- confirm intent.
w = Q*randn(states,N);  % system noise (gaussian)
v = R*randn(3,N);       % measurement noise (gaussian)


%% Preallocating
z      = zeros(3,N);    % measurements + noise
angles = zeros(3,N);    % angles accumulated from the filtered quaternion
gx_    = zeros(1,N);    % state x(5) after each filter step (gx post filter)

%%%%%%%%%%%%%%%%%
%% go! EKF go! %%
%%%%%%%%%%%%%%%%%

% One EKF step per sample:
%   1. Form the measurement.
%   2. Build the delta-rotation input u from x(5:7).
%   3. Predict, then renormalize the attitude quaternion.
%   4. Linearize f and h.
%   5. Correct, then renormalize again.
%   6. Accumulate the resulting angles.
for k=2:N
    %%%%%%%%%%%%%%%%
    % measurements %
    %%%%%%%%%%%%%%%%
    % noisy rates, scaled by dt to per-sample increments
    z(:,k) = ([ gx(k)
                gy(k)
                gz(k) ] + v(:,k))*dt;

    %%% condition %%%
    % delta-rotation quaternion built from the states x(5:7)
    % NOTE(review): x(5:7) are passed to qFromEuler without a dt factor; if
    % they are rates (gx is in rad/s) the per-sample increment would be
    % x(5:7)*dt -- confirm units and qFromEuler's argument order.
    qDelta = Quaternion.qFromEuler(xpost(5),xpost(6),xpost(7));
    qDelta = qDelta.normalize();
    u = [qDelta.rot, qDelta.x, qDelta.y, qDelta.z];

    %%%%%%%%%%%%%%
    % prediction %
    %%%%%%%%%%%%%%
    % The mean is propagated deterministically (u: Quaternion qDelta).
    % Process noise is represented by Q in the covariance prediction below,
    % so no random sample is added to the estimate itself.
    xprio = f(xpost,u);

    % renormalize the predicted attitude quaternion
    qCurrent = Quaternion(xprio(1), xprio(2), xprio(3), xprio(4));
    qCurrent = qCurrent.normalize();
    xprio(1:4) = [qCurrent.rot; qCurrent.x; qCurrent.y; qCurrent.z];
    % qCurrent is now the new attitude

    % Jacobian F = df/dx. f is bilinear in x(1:4) and u, so the upper-left
    % block is the matrix of right-multiplication by u; x(5:7) pass through.
    % u is treated as a constant input (its dependence on x(5:7) through
    % qFromEuler is not linearized).
    JF = [ u(1), -u(2), -u(3), -u(4), 0, 0, 0
           u(2),  u(1),  u(4), -u(3), 0, 0, 0
           u(3), -u(4),  u(1),  u(2), 0, 0, 0
           u(4),  u(3), -u(2),  u(1), 0, 0, 0
              0,     0,     0,     0, 1, 0, 0
              0,     0,     0,     0, 0, 1, 0
              0,     0,     0,     0, 0, 0, 1 ];

    % Jacobian H = dh/dx at xprio (only x(1:4) enter h).
    % For atan2(a,b):  d/dq atan2(a,b) = (b*da/dq - a*db/dq) / (a^2 + b^2)
    % For asin(s):     d/dq asin(s)    = (ds/dq) / sqrt(1 - s^2)
    % Denominators are floored at eps to avoid division by zero at the
    % singular attitudes.
    q1 = xprio(1); q2 = xprio(2); q3 = xprio(3); q4 = xprio(4);

    a1  = 2*q1*q2 - 2*q3*q4;                  % row 1: atan2(a1, b1)
    b1  = 1 - 2*q2^2 - 2*q4^2;
    da1 = [ 2*q2,  2*q1, -2*q4, -2*q3 ];
    db1 = [ 0,    -4*q2,  0,    -4*q4 ];
    JH1 = (b1*da1 - a1*db1) / max(a1^2 + b1^2, eps);

    s2  = 2*(q2*q3 + q4*q1);                  % row 2: asin(s2)
    ds2 = [ 2*q4, 2*q3, 2*q2, 2*q1 ];
    JH2 = ds2 / sqrt(max(1 - s2^2, eps));

    a3  = 2*q3*q1 - 2*q2*q4;                  % row 3: atan2(a3, b3)
    b3  = 1 - 2*q3^2 - 2*q4^2;
    da3 = [ 2*q3, -2*q4,  2*q1, -2*q2 ];
    db3 = [ 0,     0,    -4*q3, -4*q4 ];
    JH3 = (b3*da3 - a3*db3) / max(a3^2 + b3^2, eps);

    JH = [ JH1, 0, 0, 0
           JH2, 0, 0, 0
           JH3, 0, 0, 0 ];

    Pprio = JF*Ppost*JF' + Q;

    %%%%%%%%%%%%%%
    % correction %
    %%%%%%%%%%%%%%
    S     = JH*Pprio*JH' + R;   % innovation covariance
    K     = Pprio*JH' / S;      % Kalman Gain
    y     = z(:,k) - h(xprio);  % innovation
    xpost = xprio + K*y;
    % Joseph-form update: keeps Ppost symmetric and positive semi-definite
    % under rounding, unlike (I-K*JH)*Pprio.
    IKH   = I - K*JH;
    Ppost = IKH*Pprio*IKH' + K*R*K';
    % NOTE(review): JH has zero columns for x(5:7) and JF is block-diagonal,
    % so with Ppost initialised to eye(states) the x(1:4)/x(5:7)
    % cross-covariance stays zero and K(5:7,:) is zero. The measurements
    % therefore never update x(5:7), and gx_ keeps its initial value.
    % Confirm the intended measurement model.

    % renormalize the corrected attitude quaternion
    qCurrent = Quaternion(xpost(1), xpost(2), xpost(3), xpost(4));
    qCurrent = qCurrent.normalize();
    xpost(1:4) = [qCurrent.rot; qCurrent.x; qCurrent.y; qCurrent.z];

    ang = Quaternion.eulerFromQ(qCurrent); %=[DEG]

    %%% data collecting %%%
    % angles is the running sum of the per-step angles of qCurrent
    angles(:,k) = angles(:,k-1) + [ang(1); ang(2); ang(3)];
    gx_(k)      = xpost(5);
end


%% data output
% Figure 1: the three accumulated angle rows over time.
% NOTE(review): the legend labels row 1 as 'yaw', while the first
% measurement row is gx (roll); the row order depends on
% Quaternion.eulerFromQ -- confirm the labels.
figure(1)
plot(t, angles.');
legend('yaw', 'pitch', 'roll','location','northwest');
grid on;

% Figure 2: first measurement row against the filtered state x(5).
figure(2)
plot(t, [z(1,:); gx_].');
legend('gx','gx_{EKF}','location','south');