function x = intlin(a, b)
% function x = intlin(a, b)
%
% PURPOSE: solving integer equation:
%
% x(1)*a(1) + x(2)*a(2) + ... + x(n)*a(n) = b
% x >=0
%
% Input: a (1 x n) integer (could be negative)
% b (1 x 1) integer (could be negative)
% Output: x (m x n) integer array, one solution per row, such that
% a*x(i,:)' = b, x(i,:)>=0 for every row i.
% x is empty (0 x n) when there is no solution, and is a single
% row of NaN when there are infinitely many solutions.
%
% Author: Bruno Luong
% Last update: 24/July/2008

% Default value, empty solution
x = zeros(0,length(a));

if ~isscalar(b)
    fprintf('b must be scalar\n');
    return
end

% cast to double, and reshape as long thin array
a = double(a(:));
b = double(b);

if ~all(mod(a,1)==0) || mod(b,1)~=0
    fprintf('a and b must be integers\n');
    return
end

% Degenerate case: all coefficients are zero (or a is empty).
% gcdn cannot be used here: it errors on a single zero and returns
% g = 0 for several zeros, which would divide by zero below.
if all(a==0)
    if b~=0
        fprintf('WARNING: there is no solution\n');
    elseif isempty(a)
        x = zeros(1,0); % the empty vector is the unique solution
    else
        % every x >= 0 satisfies 0 = 0
        fprintf('There is infinity of solutions\n');
        x = NaN(1,length(a));
    end
    return
end

% Bezout coefficients: d*a = g, where g > 0 is the gcd of the entries of a
[g d] = gcdn(a);

if mod(b,g)~=0
    fprintf('WARNING: there is no solution\n');
    return
end

an = a/g;
bn = b/g;

% Particular (possibly negative) solution: d*an = bn (equivalently d*a = b)
d = bn*d;

% General solution is x = d + k, with k such that
%   k*an = 0, and k >= -d (<=> x >= 0)
kmin = ceil(-d);
kmax = +inf(size(a));

% Find all integer k such that k.an = 0 within the bounds
k = allintlin0(an, kmin, kmax);

% Shift every homogeneous solution (row of k) by the particular solution
x = bsxfun(@plus, d, k);

end


function x = allintlin0(a, lower, upper)
%
% x, a, lower, upper are n-dimensional vectors.
% List all integer x such that
% x.a = 0
% lower <= x <= upper
%
% Output: x, one solution per row (empty if none). If the search box
% cannot be bounded (infinitely many solutions), a message is printed
% and x = NaN.
%

a = a(:);
n = length(a);

lower = lower(:);
upper = upper(:);

% Tighten the bounds of each component by LP: minimize, then maximize,
% x(k) subject to a'*x = 0 and the current bounds.
L = lower;
U = upper;
b = 0;
% Slack applied before ceil/floor so LP round-off does not exclude a
% valid integer bound.
epsilon = 1e-6;
for k=1:n
    cost = basis(k,n);
    % Beware, this is Bruno's linprog, not the MATLAB one in the
    % Optimization Toolbox; the result may be different.
    % NOTE(review): presumably sol is all Inf when the LP is unbounded
    % in that direction -- confirm against that linprog.
    sol = linprog(cost', zeros(size(a')), 1, a', b, L, U);
    if ~all(isinf(sol))
        L(k) = max(L(k),sol(k)-epsilon);
    end
    cost = -basis(k,n);
    sol = linprog(cost', zeros(size(a')), 1, a', b, L, U);
    if ~all(isinf(sol))
        U(k) = min(U(k),sol(k)+epsilon);
    end
end
L = ceil(L);
U = floor(U);

if any(isinf(L)) || any(isinf(U))
    % Unbounded search box: there are infinitely many solutions.
    % NOTE: limiting the listing with a finite maxcount (e.g. 100) does
    % not work as efficiently as expected, because the recursive engine
    % might spend much of its CPU time looking for invalid solutions.
    fprintf('There is infinity of solutions\n');
    x = NaN;
    return
end
maxcount = Inf;

% Call the engine
count = 0;
x = ilinengine(a, L, U, b, count, maxcount);
if maxcount<inf
    fprintf('WARNING: only %d solutions will be provided\n', maxcount);
    x = x(1:min(maxcount,end),:); % clipping
end

end

function v = basis(k,n)
% BASIS  Unit row vector: a 1-by-n array of zeros with a one at index k.
v = zeros(1,n);
v(1,k) = 1;
end

% Solver engine for integer x
% a*x = b
% lower <= x <= upper
% RESTRICTION:
% a must be a primitive array, i.e., the greatest common divisor of its
% entries is one.
%
% a, lower, upper: column arrays of equal length; b: scalar.
% count: number of solutions found so far; maxcount: cap on that number
% (Inf lists everything).
% Output: x, one solution per row (empty 0 x length(a) if none).
function x = ilinengine(a, lower, upper, b, count, maxcount)

% default value, empty result
x = zeros(0,length(a));

if count>maxcount
    return
end

% Trivial solution with one variable
if length(a)==1
    if mod(b,a(1))==0
        xtmp = b / a(1);
        if (xtmp >= lower) && (xtmp <=upper)
            x = xtmp;
        end
    end
    return
end

% Preliminary check whether the sum b is reachable within the bounds.
% This check is only a speed-up and does not affect the result.
as = bsxfun(@times,a,[lower upper]);
as = sum(sort(as,2),1);
if b<as(1) || b>as(2)
    return
end

% Greatest common divisor of the tail
g = gcdn(a(2:end));
% r1 is the inverse of a(1) modulo g: a(1)*r1 + g*v = uno (uno should be 1)
[uno r1] = gcd(a(1),g); %#ok<ASGLU>

% x(1) = r + g*k for integer k, where a(1)*r is congruent to b modulo g;
% the tail then has to satisfy (a(2:end)/g)*x(2:end) = s - k*a(1)
r = r1*mod(b,g);
s = (b - a(1)*r)/g;
a(2:end) = a(2:end)/g; % the tail is now primitive

% Range of k keeping lower(1) <= x(1) <= upper(1)
kmin = ceil((lower(1)-r)/g);
kmax = floor((upper(1)-r)/g);

    % Basic step: solve the tail for x(1) = r + g*k, append the solutions
    % found to x and add their number to the shared counter.
    % Returns the updated total number of solutions found so far.
    function newcount = kstep(k)
        % Recursive call
        xk = ilinengine(a(2:end), lower(2:end), upper(2:end), ...
                        s-k*a(1), count, maxcount);
        nxk = size(xk,1);
        x = [x; ...
             (r+g*k)+zeros(nxk,1) xk]; % append the new solutions
        % count is shared with the enclosing function: accumulate it so
        % the count>maxcount checks below and in recursive calls see it
        count = count + nxk;
        newcount = count;
    end

if isinf(kmin) && isinf(kmax) % k is unbounded
    for absk=0:inf
        for k=unique([-absk absk])
            if kstep(k)>maxcount
                break
            end
        end
        if count>maxcount
            break
        end
    end
elseif isinf(kmin) % k has upper bound, but no lower bound
    for k=kmax:-1:-inf
        if kstep(k)>maxcount
            break
        end
    end
else % k has lower bound
    for k=kmin:kmax
        if kstep(k)>maxcount
            break
        end
    end
end

end

function [g varargout] = gcdn(varargin)
% GCDN  Greatest common divisor of several integers, with Bezout coefficients.
%
%   g = gcdn(a1, a2, ..., an)
%   g = gcdn(a)                    (a is an array)
%       returns g, the greatest common divisor of all the values.
%
%   [g c1 c2 ... cn] = gcdn(...)   one coefficient per value, or
%   [g c] = gcdn(...)              all coefficients in a single array c,
%       such that a1*c1 + a2*c2 + ... + an*cn = g.
%
%   A single nonzero value a gives g = abs(a) with coefficient sign(a);
%   a single zero (or empty) value raises an error.

if nargin >= 2
    % Gather every argument into one array
    vals = cell2mat(varargin);
else
    vals = varargin{1};
    if length(vals) < 2
        if vals
            g = abs(vals);
            if nargout >= 2
                varargout{1} = sign(vals);
            end
            return
        end
        error('gcdn cannot compute for a = 0');
    end
    vals = reshape(vals, 1, []);
end

% Fold gcd over the values; each step rescales the earlier coefficients
% by the Bezout factor of the running gcd.
g = vals(1);
coef = zeros(size(vals));
coef(1) = 1;
for j = 2:length(vals)
    [g, scale, coef(j)] = gcd(g, vals(j));
    coef(1:j-1) = scale * coef(1:j-1);
end

nextra = nargout - 1;
if nextra >= 1
    if nextra == 1
        varargout{1} = coef;
    elseif nextra == length(coef)
        varargout = num2cell(coef);
    else
        error('number of output is incompatible with input');
    end
end

end
