% (c) 2013-2015 Miika Aittala, Jaakko Lehtinen, Tim Weyrich, Aalto 
% University, University College London. This code is released under the 
% Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 
% license (http://creativecommons.org/licenses/by-nc-sa/4.0/).

% A few words on the code. It is (very) slightly cleaned up research code, which
% shows traces of convoluted historical development and is not particularly 
% readable, reusable, flexible or efficient. A lot of notation is different
% from the one used in the paper, but it does ultimately perform the same
% computations. The code contains quite a few hardcoded assumptions about 
% our setup parameters, and things that make sense only after one has 
% carefully worked through all the little details of the method.
%
% It's a good idea to start a parallel pool before running (the parfor loops
% below use it), e.g. 'parpool(4)' ('matlabpool 4' in releases before R2015a).

function sols = optimizer(Data, path)
    % OPTIMIZER  Coarse-to-fine reflectance/normal reconstruction.
    %
    %   sols = optimizer(Data)        writes 'sols.mat' to the current folder
    %   sols = optimizer(Data, path)  prefixes output files with PATH
    %
    %   Runs opt_refl_global at increasing resolutions (seq), upsampling each
    %   result as the initial guess for the next level. The last level is
    %   upsampled to WW-by-WW, and the specular albedo channels (2,5,6) are
    %   divided by a Fresnel factor.
    %
    %   Returns a cell array of all intermediate solutions. The final
    %   WW-by-WW Fresnel-corrected solution is sols{end}. When called with no
    %   input, it prints a message and returns an empty cell array.
    if nargin == 0
        disp('Error: no input data');
        % Assign the output so callers requesting it do not hit an
        % "output argument not assigned" error.
        sols = {};
        return
    end

    if nargin < 2
        path = '';
    end

    clf;

    % Final resolution
    WW = 1024;
    % Intermediate resolutions
    seq = [32 64 128 512];
    % LM iteration budget per level. Only iterseq(1:numel(seq)-1) is used;
    % the last level (seq(end)) uses a hard-coded 5 iterations below.
    iterseq = [12 3 3 3 3 3];

    % Compute a solution at lowest resolution
    sols = opt_refl_global(Data,0,seq(1),iterseq(1),[]);

    % upsample is defined elsewhere. The subplot call presumably gives it
    % an axis to draw into (TODO confirm).
    subplot(2,4,6);
    % Upsample to second resolution
    x = upsample(Data,sols{end},seq(2));

    % Optimize at intermediate resolutions
    for s = 2:(numel(seq)-1)
        sols = [sols opt_refl_global(Data,0,seq(s),iterseq(s),x)];
        subplot(2,4,6);
        x = upsample(Data,sols{end},seq(s+1));
    end

    % Final optimized solution
    sols = [sols opt_refl_global(Data,0,seq(end),5,x)];

    % Final upsampling to final resolution
    subplot(2,4,6);
    x = upsample(Data,sols{end},WW);

    % Fresnel correction: divide the specular albedos (channels 2,5,6) by
    % the per-pixel Fresnel factor.
    fimg = fresnel_correction(Data, x, 0.06);
    x(:,:,[2 5 6]) = x(:,:,[2 5 6]) ./ repmat(fimg, [1 1 3]);

    % The solution is in sols{end}
    sols = [sols {x}];

    save(strcat(path,'sols.mat'),'sols');
    results_output(Data, sols{end}, path);
end

function [sols] = opt_refl_global(Data, optmode, w, iters, x, singlep)
    % OPT_REFL_GLOBAL  Levenberg-Marquardt fit of the per-pixel reflectance
    % model at one square working resolution w-by-w.
    %
    %   Data    - dataset struct. Fields read here include Z, DC, freqs, R, s,
    %             t, E, Rt, mon, mog_win, T_img_to_floor (and sol when
    %             optmode == 1).
    %   optmode - 0: run the optimization.
    %             1: single-point visualization of Data.sol at pixel singlep.
    %             2: evaluate the global `sols{end}` into the global
    %                `img_eval` (debugging).
    %   w       - working resolution; Z and DC are bilinearly resized to w-by-w.
    %   iters   - maximum number of accepted LM steps.
    %   x       - initial guess as a w-by-w-by-10 array (channels 7 and 10 in
    %             linear scale), or [] to build one from the data.
    %   singlep - [row col] pixel, only read when optmode == 1.
    %
    %   Returns a cell array of w-by-w-by-10 iterates (initial guess first).
    %   Channels 7 (glossiness) and 10 (kurtosis) are exp()'d back from the
    %   log parameterization used internally.
    %   NOTE(review): in optmode 1 `sols` is never assigned before the return,
    %   so requesting the output in that mode will error.
        
    % Frequencies used (indices)
    freqchoice = 1:8;

    % Read in the data, scale to desired low-res size
    D = Data;
    D.Z = imresize(D.Z(:,:,:,freqchoice,:), [w w], 'bilinear');
    %D.Z = imresize_sq(D.Z(:,:,:,freqchoice,:), w);
    newsize = size(Data.Z(:,:,:,freqchoice,:));
    newsize(1:2) = [w w];

    D.Z = double(reshape(D.Z, newsize));
    D.DC = double(imresize(D.DC, [w w], 'bilinear'));
    %D.DC = double(imresize_sq(D.DC, w));
    D.imgsize = [w w];

	freqs = D.freqs(freqchoice);
	D.freqs = D.freqs(freqchoice);
    nfreqs = numel(freqs);

    % The frequency used for some heuristic initial guess reasoning.
    % NOTE(review): hif is used both as an index into the frequency axis of
    % D.Z (init_zk/init_z0 below) and as a literal frequency value passed to
    % eval_model. These agree only if freqs(hif) == hif; confirm.
    hif = 5;
    
    % 2-by-(1+4*nfreqs) list of 2D frequencies, one per column, in the order
    % they appear in the flattened data vector [DC; Z(:)]: (0,0), then
    % (f,0), (f,f), (0,f), (-f,f) for each f in freqs.
    omega = [[0 freqs freqs zeros(1,nfreqs) -freqs];
                [0 zeros(1,nfreqs) freqs freqs freqs]];

	% Read in some geometric calibration
    geo = struct();
	geo.R = D.R;
    geo.s = D.s;
    geo.t = D.t;
    geo.E = D.E;
    geo.Rt = D.Rt;

    mon = Data.mon;
    
	par = struct();
    par.n_p = w^2;  % number of spatial points
    par.n_var = 10; % number of variables in optimization
    
    % weight of data fidelity term; notice that these are a bit different
    % from the paper, but actually turn out to give the exact same numbers
    % in the end...
    par.lh_mul = w*5; 

    % Get the window used in the dataset
    nw = numel(D.mog_win);
    mog_win = D.mog_win;
    
    % ... but we want it in another format (for historical reasons..):
    % stacked covariances (2x2xnw), means (2xnw) and weights (nwx1).
    win = struct;
    win.C_w = zeros(2,2,nw);
    win.mu_w = zeros(2,nw);
    win.z_w = zeros(nw,1);
    for i = 1:nw
        win.C_w(:,:,i) = mog_win{i}.sigma;
        win.mu_w(:,i) = mog_win{i}.mu;
    	win.z_w(i) = mog_win{i}.mult;
    end
    
    par.w = nw;
      
    % Store geometric terms for all points
    con_all = cell(D.imgsize(1), D.imgsize(2));
    
    %% This is here just for the single-point visualization!
    % Skip this section.
    if optmode == 1
        % Assume a solution is stored in the data
        img = Data.sol;
    
        nfreqs_hi = 512;
    	freqs_hi = 1:nfreqs_hi;
    
        omega_hi = [[0 freqs_hi freqs_hi zeros(1,nfreqs_hi) -freqs_hi];
                [0 zeros(1,nfreqs_hi) freqs_hi freqs_hi freqs_hi]];

        clf;
        subplot(2,3,1);
        imagec(D.DC);
        hold on;
        
        i = singlep(1);
        j = singlep(2);
        plot(j,i,'yo');
        
        % NOTE(review): this channel-2 Z is overwritten by the channel-1 Z
        % below before it is used.
        Z = D.Z(i,j,2,:,:);
        Z = double([D.DC(i,j,2); Z(:)]);

        % Pixel -> floor-plane point via the homography T_img_to_floor
        ix = ([i;j]-0.5)./(D.imgsize'-1);
        q = D.T_img_to_floor * [ix(2);ix(1);1];
        p = [q(1:2)/q(3);0];
        
        geo.p = p;   % sample point position
        
        con = opt_geometric(p, geo, par,mon,mog_win);
        con.p = p;
        %con = fm_geo_consts(p,geo, par);
        
        Z = D.Z(i,j,1,:,:);
        Z = double([D.DC(i,j,1); Z(:)]);

        vars = struct;
        
        X = squeeze(img(i,j,:));
        vars.sigma_k = (X(7));
        vars.n = [X(8);X(9)]*1;
        vars.k = X(10);

        % Model spectra at the data frequencies and at a dense 1..512 set
        fk = eval_model(vars,omega,geo,con,par,win);
        Df = fk(:,1) * X(1);
        Sf = fk(:,2) * X(2);
        
        fk_hi = eval_model(vars,omega_hi,geo,con,par,win);
        Df_hi = fk_hi(:,1) * X(1);
        Sf_hi = fk_hi(:,2) * X(2);
        
        % phi selects which direction block of omega is plotted (1 here)
        asd = 1;
        phi = 2*asd-1;
        hold on;
        
        if 1
            np = 256;
            ax = linspace(-pi,pi,np);
            subplot(2,3,1+asd);
            
            Zp = part_ifft2(D.freqs, ( [Z(1);Z(2+(phi-1)*nfreqs:1+phi*nfreqs)] ),np);
            plot(ax,Zp,'linewidth',3,'color','blue')
            
            hold on;
            Dp = part_ifft2(D.freqs, ( [Df(1);Df(2+(phi-1)*nfreqs:1+phi*nfreqs)] ),np);
            Sp = part_ifft2(D.freqs, ( [Sf(1);Sf(2+(phi-1)*nfreqs:1+phi*nfreqs)] ),np);

            Dp_hi = part_ifft2(freqs_hi, ( [Df_hi(1);Df_hi(2+(phi-1)*nfreqs_hi:1+phi*nfreqs_hi)] ),np);
            Sp_hi = part_ifft2(freqs_hi, ( [Sf_hi(1);Sf_hi(2+(phi-1)*nfreqs_hi:1+phi*nfreqs_hi)] ),np);
            
            subplot(2,3,2);
            plot(ax,Dp_hi+Sp_hi,'linewidth',3,'color','red');
            hold on;
            plot(ax,Dp_hi,'r','linewidth',2);
            plot(ax,Sp_hi,'m','linewidth',2);
            axis([-pi pi -0.4 3.1]);
            
            subplot(2,3,4);
            
            semilogy(Data.freqs, abs(squeeze(Data.Z(i,j,1,:,phi))),'-x');
            hold on;
            
            Df_hi_clip = Df_hi(2+(phi-1)*nfreqs_hi:1+phi*nfreqs_hi);
            Sf_hi_clip = Sf_hi(2+(phi-1)*nfreqs_hi:1+phi*nfreqs_hi);
            
            semilogy(freqs_hi,abs(Df_hi_clip + Sf_hi_clip),'r');
            semilogy(freqs_hi,abs(Df_hi_clip),':r');
            semilogy(freqs_hi,abs(Sf_hi_clip),':r');
            
            axis([0 70 1e-3 10^(-0.5)]);
            
            subplot(2,3,5);
            plot(freqs_hi,angle(Df_hi_clip + Sf_hi_clip),'r');
            hold on;
            plot(Data.freqs, angle(squeeze(Data.Z(i,j,1,:,phi))),'o');
            axis([0 70 -pi pi]);
            
        end      
        return
    end

    %% This evaluates the solution to give the "reprojection" of the data.
    % Used for debugging. Skip.
    % NOTE(review): this declares the function's own output `sols` as global.
    % Confirm this is accepted by the MATLAB release in use.
    % NOTE(review): only channel 7 is converted back to log scale here, but
    % extract_vars also applies exp() to channel 10 (kurtosis); confirm
    % whether channel 10 should be logged as well.
    
    if optmode == 2
    	global sols;
        global img_eval;

        soltemp = sols{end};
        
        soltemp(:,:,7) = log(soltemp(:,:,7));
        
        img_eval = zeros(w,w,size(omega,2),2);
        for i = 1:D.imgsize(1)    
            for j = 1:D.imgsize(2)            
                ix = ([i;j]-0.5)./(D.imgsize'-1);
                q = D.T_img_to_floor * [ix(2);ix(1);1];
                p = [q(1:2)/q(3);0];

                con_all{i,j} = opt_geometric(p, geo, par,mon,mog_win);
                con_all{i,j}.p = p;

                vars = extract_vars(squeeze(soltemp(i,j,:)));

                img_eval(i,j,:,:) = eval_model(vars,omega,geo,con_all{i,j},par,win);

            end
            if mod(i,10) == 5
                imagesc(angle(img_eval(:,:,9,2)))
                drawnow;
            end
        end
        return
    end

    
    %% Back to the main algorithm
    
    % Some terminology here (these are a mess for historical reasons.)
    % The optimization variables are referred to by names and numbers, in
    % parallel; this is because the optimizer just deals with indices into
    % a vector of anonymous variables. So:
    %
    % Diffuse albedo: called z0 or diff, indices 1,3,4 (for R,G,B)
    % Specular albedo: called z_k or spec, indices 2,5,6 (for R,G,B)
    % Glossiness: called sigma, index 7 (note conventions regarding log)
    % Normals: called n, indices 8,9 (be careful about axis orientations)
    % Kurtosis: called k, index 10 (again, careful with log and +1)
    
    % Initial guesses and prior helpers
    n_init = zeros(w,w,2);  % normal initial
    init_zk = zeros(w,w,3); % specular initial
    init_z0 = zeros(w,w,3); % diffuse initial
    ddc_diff = zeros(w,w);  % diffuse smoothness scale helper
    ddc_spec = zeros(w,w);  % specular -"-

    % For all pixels...
    for i = 1:D.imgsize(1)    
        for j = 1:D.imgsize(2)
            
            % Get the floor coordinates
            ix = ([i;j]-0.5)./(D.imgsize'-1);
            q = D.T_img_to_floor * [ix(2);ix(1);1];
            p = [q(1:2)/q(3);0];

            % Compute the geometric terms
            con_all{i,j} = opt_geometric(p, geo, par,mon,mog_win);
            con_all{i,j}.p = p;

            % Zero normal initial guess for now
            n_init(i,j,:) = [0;0];

            % A sharpish default helper BRDF for computation of initial 
            % guesses
            vars = struct();
            vars.sigma_k = 0.03;
            vars.n = [0;0];
            vars.k = 0.03;
            
            % Compute the FT of the above BRDF at freqs (0,0), (5,5) and
            % (1,1); the resulting matrix is 3x2, with first column
            % corresponding to diffuse, second to specular
            fk = eval_model(vars,[0 hif 1;0 hif 1],geo,con_all{i,j},par,win);
            
            % Diffuse and specular initial guess.
            init_zk(i,j,:) = abs(D.Z(i,j,:,hif,2)) / abs(fk(2,2));
            init_z0(i,j,:) = (D.DC(i,j,:)-abs(D.Z(i,j,:,hif,2))) / fk(1,1);

            % Record the effects of the diffuse and specular component to
            % the DC component; these will be used to scale the DC
            % derivatives respectively when forming smoothness priors
            ddc_diff(i,j) = fk(1,1);
            ddc_spec(i,j) = fk(1,2);

        end
        % Progress output (no semicolon on purpose: prints the row index)
        INITIAL = i
    end
    
    
    %% Construct priors
    % Each prior is a weighted linear residual W*(A*x_var - b) on one
    % variable (or, for integrability, a sum over two variables).
    PRIORS_START = 1
        
    priors = cell(0);
    
    % Nested helper: package one prior term. Its arguments/outputs are local
    % to the nested function, so `prior` here does not clash with the
    % parent's `prior` below.
    function prior = makeprior(var, A, b, W)
        prior = struct();
        prior.var = var;
        prior.A = A;
        prior.b = b;
        prior.W = W;
    end

    zpr = sparse(par.n_p,1); % for convenience
    eyepr = speye(par.n_p);

    % Prior strength scales with resolution (1 at w == 32)
    spmul = w/32;
    
    % MAGNITUDE PRIORS
    % Albedos
    priors{end+1} = makeprior(1, {eyepr}, zeros(par.n_p,1), spmul*eyepr/1);
    priors{end+1} = makeprior(3, {eyepr}, zeros(par.n_p,1), spmul*eyepr/1);
    priors{end+1} = makeprior(4, {eyepr}, zeros(par.n_p,1), spmul*eyepr/1);

    priors{end+1} = makeprior(2, {eyepr}, zeros(par.n_p,1), spmul*eyepr/0.5);
    priors{end+1} = makeprior(5, {eyepr}, zeros(par.n_p,1), spmul*eyepr/0.5);
    priors{end+1} = makeprior(6, {eyepr}, zeros(par.n_p,1), spmul*eyepr/0.5);

    % Glossiness (pulls log(sigma) towards -3)
    priors{end+1} = makeprior(7, {eyepr}, -3*ones(par.n_p,1), spmul*eyepr/1);
    
    % Kurtosis (pulls the log-parameter towards -1)
	priors{end+1} = makeprior(10, {eyepr}, -1*ones(par.n_p,1), spmul*eyepr/0.5);        
    
    % Normals
    priors{end+1} = makeprior(8, {eyepr}, zpr, spmul*eyepr/0.1);
    priors{end+1} = makeprior(9, {eyepr}, zpr, spmul*eyepr/0.1);

    % If at low resolution, add spatial priors
    if w < 200

        % Construct derivative operators (along "directions 1 and 2", a.k.a. x
        % and y, though not necessarily so in the image (see conventions)
        [D1,D2] = D_mats(w);
        flat = @(x) reshape(x,[],1); % Helper: equivalent to (:)

        % x- and y-derivatives of the DC component, scaled by the weights
        % computed above
        edge_diff1r = (D1 * (flat(D.DC(:,:,1)./ddc_diff)));
        edge_diff1g = (D1 * (flat(D.DC(:,:,2)./ddc_diff)));
        edge_diff1b = (D1 * (flat(D.DC(:,:,3)./ddc_diff)));
        edge_diff1 = max(max(edge_diff1r,edge_diff1g),edge_diff1b);

        edge_diff2r = (D2 * (flat(D.DC(:,:,1)./ddc_diff)));
        edge_diff2g = (D2 * (flat(D.DC(:,:,2)./ddc_diff)));
        edge_diff2b = (D2 * (flat(D.DC(:,:,3)./ddc_diff)));
        edge_diff2 = max(max(edge_diff2r,edge_diff2g),edge_diff2b);

        edge_spec1r = (D1 * (flat(D.DC(:,:,1)./ddc_spec)));
        edge_spec1g = (D1 * (flat(D.DC(:,:,2)./ddc_spec)));
        edge_spec1b = (D1 * (flat(D.DC(:,:,3)./ddc_spec)));
        edge_spec1 = max(max(edge_spec1r,edge_spec1g),edge_spec1b);

        edge_spec2r = (D2 * (flat(D.DC(:,:,1)./ddc_spec)));
        edge_spec2g = (D2 * (flat(D.DC(:,:,2)./ddc_spec)));
        edge_spec2b = (D2 * (flat(D.DC(:,:,3)./ddc_spec)));
        edge_spec2 = max(max(edge_spec2r,edge_spec2g),edge_spec2b);


        % Spatially varying albedo smoothness: weight each difference by the
        % inverse magnitude of the corresponding DC difference.
        % NOTE(review): where that DC difference is exactly 0 on a row that
        % D1/D2 actually populate, the weight is Inf. Only the combined *_m
        % matrices are used below; the per-channel *_mr/_mg/_mb variants are
        % computed but unused.
        edge_diff1_mr = spdiags(1./abs(edge_diff1r),0,w^2,w^2)/2;
        edge_diff1_mg = spdiags(1./abs(edge_diff1g),0,w^2,w^2)/2;
        edge_diff1_mb = spdiags(1./abs(edge_diff1b),0,w^2,w^2)/2;
        edge_diff1_m = spdiags(1./abs(edge_diff1),0,w^2,w^2)/2;

        edge_diff2_mr = spdiags(1./abs(edge_diff2r),0,w^2,w^2)/2;
        edge_diff2_mg = spdiags(1./abs(edge_diff2g),0,w^2,w^2)/2; 
        edge_diff2_mb = spdiags(1./abs(edge_diff2b),0,w^2,w^2)/2; 
        edge_diff2_m = spdiags(1./abs(edge_diff2),0,w^2,w^2)/2;

        edge_spec1_mr = spdiags(1./abs(edge_spec1r),0,w^2,w^2)/2;
        edge_spec1_mg = spdiags(1./abs(edge_spec1g),0,w^2,w^2)/2;
        edge_spec1_mb = spdiags(1./abs(edge_spec1b),0,w^2,w^2)/2;
        edge_spec1_m = spdiags(1./abs(edge_spec1),0,w^2,w^2)/2;

        edge_spec2_mr = spdiags(1./abs(edge_spec2r),0,w^2,w^2)/2;
        edge_spec2_mg = spdiags(1./abs(edge_spec2g),0,w^2,w^2)/2;
        edge_spec2_mb = spdiags(1./abs(edge_spec2b),0,w^2,w^2)/2;
        edge_spec2_m = spdiags(1./abs(edge_spec2),0,w^2,w^2)/2;

        priors{end+1} = makeprior(1, {D1}, zpr, edge_diff1_m);
        priors{end+1} = makeprior(1, {D2}, zpr, edge_diff2_m);
        priors{end+1} = makeprior(3, {D1}, zpr, edge_diff1_m);
        priors{end+1} = makeprior(3, {D2}, zpr, edge_diff2_m);
        priors{end+1} = makeprior(4, {D1}, zpr, edge_diff1_m);
        priors{end+1} = makeprior(4, {D2}, zpr, edge_diff2_m);

        priors{end+1} = makeprior(2, {D1}, zpr, edge_spec1_m);
        priors{end+1} = makeprior(2, {D2}, zpr, edge_spec2_m);
        priors{end+1} = makeprior(5, {D1}, zpr, edge_spec1_m);
        priors{end+1} = makeprior(5, {D2}, zpr, edge_spec2_m);
        priors{end+1} = makeprior(6, {D1}, zpr, edge_spec1_m);
        priors{end+1} = makeprior(6, {D2}, zpr, edge_spec2_m);

        % Integrability: -D2*n_8 + D1*n_9 ~ 0
        priors{end+1} = makeprior([8 9], {-D2,D1}, zpr, eyepr/0.01);
        
        % Smoothness
        priors{end+1} = makeprior(7, {D1}, zpr, eyepr/1.0);
        priors{end+1} = makeprior(7, {D2}, zpr, eyepr/1.0);
        priors{end+1} = makeprior(10, {D1}, zpr, eyepr/1.0);
        priors{end+1} = makeprior(10, {D2}, zpr, eyepr/1.0);
    end

    % Rearrange to a big sparse A and vector b for easy application later.
    % Row-block i holds prior i; column-block j holds variable j
    % (W*A where the prior touches j, an all-zero block otherwise).
    prior = struct();
    prior.A = [];
    prior.b = [];
    
    priorcell_A = cell(numel(priors), par.n_var);
    priorcell_b = cell(numel(priors),1);
    for i = 1:numel(priors)
        spz = sparse(size(priors{i}.A{1},1), par.n_p);
        v = 1;
        for j = 1:par.n_var
            % Walk priors{i}.var in order; v counts how many of its
            % variables have been placed so far.
            if j == priors{i}.var(min(numel(priors{i}.var),v))
                if numel(priors{i}.var) > 1
                    priorcell_A{i,j} = priors{i}.W * priors{i}.A{v};
                else
                    priorcell_A{i,j} = priors{i}.W * priors{i}.A{1};
                end
                v = v + 1;
            else
                priorcell_A{i,j} = spz;
            end            
        end

        priorcell_b{i} = priors{i}.W * priors{i}.b;
    end
    prior.A = cell2mat(priorcell_A);
    prior.b = cell2mat(priorcell_b);
    
    clear priors;
    
    % Permute the A-matrix from [a1 a2 a3 ... b1 b2 b3 ... c1 c2 c3 ...]
    % into [a1 b1 c1 a2 b2 c2 a3 b3 c3 ... ] (where alphabets are variables
    % and numbers are pixel indices). This matches the layout of X(:) for
    % X of size n_var x w x w.
    perm = repmat((0:(par.n_var-1)),[1 par.n_p]) * par.n_p + kron(1:par.n_p, ones(1,par.n_var));
    prior.A = prior.A(:,perm);

    PRIORS_DONE = 1

    
    %% Initial guess
    % Internal state x is n_var x w x w, with channels 7 and 10 in log scale.
        
    if numel(x) == 0
        x = repmat([0.1;0.1;0.1;0.1;0.1;0.1;log(0.03);0;0;-2], [1 D.imgsize(1), D.imgsize(2)]);

        n_init = permute(n_init, [3 1 2]);
        x(8:9,:,:) = n_init; 

        x([1 3 4],:,:) = permute(init_z0, [3 1 2]);
        x([2 5 6],:,:) = permute(init_zk, [3 1 2]);

    else
        % Incoming guess is w x w x n_var in linear scale: log channels 7
        % and 10 (clamped away from 0) and move variables to the first dim.
        x(:,:,7) = log(max(0.00001,x(:,:,7)));
        x(:,:,10) = log(max(0.00001,x(:,:,10)));
        x = permute(x,[3 1 2]);
    end
    
    drawsol(x);
    drawnow;

    %% Levenberg-Marquardt
    % Not a particularly nice implementation, but the Matlab built in
    % wasn't quite flexible enough, and also this was much faster in some
    % cases for whatever reason. It works, but could probably be improved.
    
    v = 2;

    [JTJ, JTr, r] = comp_jtj(x,D,con_all,omega,geo,par,prior,win);
    
    mu = 1;

    e_1 = 0.000001;
    e_2 = 0.000001;
    e_3 = 0.000001;

    solves = 0;
    adjusts = 0;

    % Put initial guess into the solution iteration list
    sols{1} = permute(x,[2 3 1]);
    sols{1}(:,:,7) = exp(sols{1}(:,:,7));
    sols{1}(:,:,10) = exp(sols{1}(:,:,10));

    iter = 1;

    % NOTE(review): when an accepted step happens with iter == iters, x is
    % updated and the loop breaks before sols{iter+1} is stored. The last
    % accepted step therefore does not appear in the returned `sols`;
    % confirm this is intended.
    while iter <= iters
        solves = solves + 1;
        SOLVE_DX = 1
        tic
        % Damped normal equations (Marquardt diagonal scaling plus a small
        % identity term to keep the system nonsingular).
        d_x = -(JTJ+mu*diag(diag(JTJ))+mu*0.0001*speye(size(JTJ,1))) \ JTr;
        toc
        
        if norm(d_x) <= e_2 * norm(x(:))
            break;
        else
            x_new = x + reshape(d_x, [par.n_var D.imgsize(1), D.imgsize(2)]);
            rnew = comp_r(x_new,D,con_all,omega,geo,par,prior,win);
            [iter solves mu norm(r) norm(rnew)]
            if norm(rnew) < norm(r)
                % Step accepted: relax damping
                x = x_new;
                if iter == iters 
                    break
                end
                [JTJ, JTr, r] = comp_jtj(x,D,con_all,omega,geo,par,prior,win);
                if max(abs(JTr)) < e_1 || norm(r)^2 <= e_3
                    break
                end
                mu = mu / 5;
                v = 2;
    
                iter = iter + 1;
                
                % store the current solution (also checkpointed to
                % sols_run.mat in the current folder)
                sols{iter} = permute(x,[2 3 1]);
                sols{iter}(:,:,7) = exp(sols{iter}(:,:,7));
                sols{iter}(:,:,10) = exp(sols{iter}(:,:,10));
                save('sols_run.mat','sols');
                
            else
                % Step rejected: increase damping geometrically
                mu = mu*v;
                v = 2*v;
                adjusts = adjusts + 1;
                if mu > 10^6
                    break;
                end
                
            end
        end

        % just to be sure, shouldn't happen under reasonable 
        % circumstances...
        x(1:6,:,:) = max(0.000001,x(1:6,:,:));
        
        drawsol(x);
        drawnow;
        
    end
    
    OPT_DONE = [solves adjusts]

end



%% Geometric precomputation routines

function [con, par] = opt_geometric(p, geo, par, mon, mog_win)
    % OPT_GEOMETRIC  Per-point geometric constants for the reflectance model.
    %
    %   p       - 3x1 point in floor coordinates.
    %   geo     - calibration struct (screen frame Rt, origin t, scale s).
    %   par     - parameter struct; par.L is set to the number of
    %             monitor x geometric Gaussians.
    %   mon     - monitor emission parameters (z_m, y_m, s_m).
    %   mog_win - cell array of window Gaussians (fields mult, mu, sigma).
    %
    %   con receives:
    %     n_s, d, p          screen normal, p-to-screen distance, and p;
    %     C_l, mu_l, z_l     stacked monitor x geometric-term Gaussians;
    %     C_v, mu_v, z_v     stacked monitor x window Gaussians.

    % Screen normal and orthogonal distance from p to the screen
    con.n_s = -cross(geo.Rt(:,1), geo.Rt(:,2));
    con.d = con.n_s'*(p-geo.t);
    con.p = p;
    dist = con.d;

    % Single-Gaussian monitor emission profile
    monGauss.mult = dist^2 * mon.z_m;
    monGauss.mu = (dist * [0; mon.y_m] + geo.Rt'/geo.s * (p-geo.t));
    monGauss.sigma = dist^2 * diag(mon.s_m.^2);
    mog_mon = {monGauss};

    % Parameters of the geometric Gaussian transformation (see notes)
    g_a = geo.s/dist;
    g_b = 1/dist*geo.Rt.'*(geo.t-p);
    g_ct = 1/dist^4;

    geoScale = dist*g_ct/g_a^2;
    geoMu = -g_b/g_a;
    geoSigma = 1/g_a^2 * eye(2);

    % Hard-coded 3-component mixture: per-component weight and
    % standard-deviation scale (see notes).
    compWeight = [1.0391 0.4583 1.5027];
    compStd = [1.5617 0.3975 0.7138];
    mog_geo = cell(1, numel(compWeight));
    for c = 1:numel(compWeight)
        comp.mult = geoScale*compWeight(c);
        comp.mu = geoMu;
        comp.sigma = geoSigma * compStd(c)^2;
        mog_geo{c} = comp;
    end

    mog_mongeo = mog_mult(mog_mon,mog_geo);
    mog_monwin = mog_mult(mog_mon,mog_win);

    par.L = numel(mog_mongeo);

    % Stack the monitor x geometric mixture into arrays
    con.C_l = zeros(2,2,par.L);
    con.mu_l = zeros(2,par.L);
    con.z_l = zeros(par.L,1);
    for g = 1:par.L
        gauss = mog_mongeo{g};
        con.C_l(:,:,g) = gauss.sigma;
        con.mu_l(:,g) = gauss.mu;
        con.z_l(g) = gauss.mult;
    end

    % Stack the monitor x window mixture into arrays
    nv = numel(mog_monwin);
    con.C_v = zeros(2,2,nv);
    con.mu_v = zeros(2,nv);
    con.z_v = zeros(nv,1);
    for g = 1:nv
        gauss = mog_monwin{g};
        con.C_v(:,:,g) = gauss.sigma;
        con.mu_v(:,g) = gauss.mu;
        con.z_v(g) = gauss.mult;
    end
end



function gc = g_mult(ga, gb)
    % G_MULT  Product of two weighted Gaussians.
    %   ga.mult*N(ga.mu, ga.sigma) * gb.mult*N(gb.mu, gb.sigma) is returned as
    %   gc.mult*N(gc.mu, gc.sigma) (fields set in the order sigma, mu, mult).
    covSum = ga.sigma + gb.sigma;
    meanDiff = ga.mu - gb.mu;

    gc.sigma = inv(inv(ga.sigma) + inv(gb.sigma));
    gc.mu = gc.sigma * (ga.sigma \ ga.mu + gb.sigma \ gb.mu);

    % Scale = product of weights times N(mu_a; mu_b, Sigma_a + Sigma_b)
    gc.mult = ga.mult * gb.mult * ...
        1/sqrt(det(2*pi*covSum)) * ...
        exp(-0.5*meanDiff'*(covSum\meanDiff));
end

function r = mog_mult(a,b)
    % MOG_MULT  Pairwise product of two Gaussian mixtures.
    %   a, b are cell arrays of structs accepted by g_mult. The result holds
    %   every product g_mult(a{ia}, b{ib}), with the index into b varying
    %   fastest. If either input is empty the result is {}.
    r = {};
    nb = numel(b);
    for ia = 1:numel(a)
        for ib = 1:nb
            r{(ia-1)*nb + ib} = g_mult(a{ia}, b{ib});
        end
    end
end



%% Derivative computation, model

% Compute the J^\top J -matrix and the J^\top r -vector, and residual.
% This is kind of ugly, in particular due to the way one needs to do
% parallelization in Matlab (you can't just assign to a big 2D array from
% the individual threads, because it will not be able to figure that the
% assignments are safe...)
function [JTJ, JTr, R] = comp_jtj(X,D,con_all,omega,geo,par,prior,win)
    % COMP_JTJ  Gauss-Newton quantities for the whole image plus priors.
    %
    %   X is indexed X(:,i,j) (n_var x rows x cols), matching the optimizer
    %   state. Returns:
    %     JTJ - sparse block-diagonal J'*J of the per-pixel data terms
    %           (n_var x n_var blocks in X(:) order), plus prior.A'*prior.A
    %     JTr - J'*r as a column vector in X(:) order, plus the prior term
    %     R   - stacked residual: all per-pixel data residuals, then the
    %           prior residual prior.A*X(:) - prior.b
    %   NOTE(review): both loops run over 1:D.imgsize(1), so this assumes a
    %   square image.
    EVAL_JTJ=1

    tic
	w = D.imgsize(1);
    
    JTJ = zeros(par.n_var, par.n_var, D.imgsize(1),D.imgsize(2));
    JTr = zeros(par.n_var, D.imgsize(1),D.imgsize(2));
    R = [];
    
    % Hack... evaluate one pixel only to learn the residual length nr.
    i=1;j=1;r = eval_e_lsqnonlin(X(:,i,j),double([squeeze(D.DC(1,1,:));reshape(squeeze(D.Z(1,1,:,:,:)), [], 1)]),omega,geo,con_all{i,j},par,win);

    nr = numel(r);
    nv = par.n_var;
	R = zeros(nr,D.imgsize(1),D.imgsize(2));
    
    % One parfor iteration per image row. Each worker fills row-local
    % buffers (JTr_, JTJ_, R_) and writes them back as whole slices.
    parfor i = 1:w
        concon = con_all(i,:);

        XX = X(:,i,:);
        ZZ = D.Z(i,:,:,:,:);
        DCDC = D.DC(i,:,:);
        JTr_ = zeros(nv,w);
        JTJ_ = zeros(nv,nv,w);
        R_ = zeros(nr,w);
        for j = 1:w
            con = concon{j};
            x = XX(:,:,j);
            Z = ZZ(1,j,:,:,:);
            % Observation vector: per-channel DC values, then the Fourier data
            Z = double([squeeze(DCDC(1,j,:)); Z(:)]);

            [r,J] = eval_e_lsqnonlin_diff2(x,Z,omega,geo,con,par,win);
            
            JTr_(:,j) = J'*r;
            JTJ_(:,:,j) = J'*J;
            R_(:,j) = r;            
        end
        
        JTr(:,i,:) = JTr_;
        JTJ(:,:,i,:) = JTJ_;
        R(:,i,:) = R_;
    end
    toc

    % Now apply the priors using the A and b constructed out of them
    APPLY_PRIORS = 1
    tic

    prior_R = prior.A*X(:)-prior.b;
    R = [R(:); prior_R];
    JTr = JTr(:);
    JTr = JTr + prior.A' * prior_R;
    
    % Split into per-pixel n_var x n_var blocks and assemble a sparse
    % block-diagonal matrix. Making the first block sparse makes blkdiag
    % return a sparse result.
    JTJ = num2cell(JTJ,[1 2]);
    JTJ{1,1} = sparse(JTJ{1,1}); % make the block mat sparse
    JTJ = blkdiag(JTJ{:});
    
    JTJ = JTJ + prior.A'*prior.A;
    
    toc
end

% Just evaluate the residual without bothering about the matrices
function R = comp_r(X,D,con_all,omega,geo,par,prior,win)
    % COMP_R  Full residual vector (per-pixel data terms, then priors)
    % without forming any Jacobians. The layout matches the R returned by
    % comp_jtj. Used by the LM loop to test a trial step.
    % NOTE(review): both loops run over 1:D.imgsize(1), so this assumes a
    % square image.
    R = [];
    
    COMP_R = 1

    % Hack... evaluate one pixel only to learn the residual length.
    i=1;j=1;r = eval_e_lsqnonlin(X(:,i,j),double([squeeze(D.DC(1,1,:));reshape(squeeze(D.Z(1,1,:,:,:)), [], 1)]),omega,geo,con_all{i,j},par,win);
    R = zeros(numel(r),D.imgsize(1),D.imgsize(2));
    
    nr = numel(r);
    w = D.imgsize(1);
    

    % One parfor iteration per image row; results are written back as a
    % whole row slice.
    parfor i = 1:w
        concon = con_all(i,:);
        XX = X(:,i,:);
        ZZ = D.Z(i,:,:,:,:);
        DCDC = D.DC(i,:,:);
        R_ = zeros(nr,w);
        for j = 1:w
            con = concon{j};
            x = XX(:,:,j);
            Z = ZZ(1,j,:,:,:);
            Z = double([squeeze(DCDC(1,j,:)); Z(:)]);

            r = eval_e_lsqnonlin(x,Z,omega,geo,con,par,win);
            
            R_(:,j) = r;            
        end
        
        R(:,i,:) = R_;
    end
    
    R = [R(:); prior.A*X(:)-prior.b];
end

% A function for extracting the variables from the optimization vector
function vars = extract_vars(v)
    % EXTRACT_VARS  Unpack the 10-element optimization vector into a struct.
    %   Channel ordering is a historical artifact from monochromatic times:
    %     v(1)    diffuse albedo R     -> z0
    %     v(2)    specular albedo R    -> z_k
    %     v(3:4)  diffuse albedo G,B   -> cd
    %     v(5:6)  specular albedo G,B  -> cs
    %     v(7)    log glossiness       -> sigma_k = exp(v(7))
    %     v(8:9)  normal components    -> n
    %     v(10)   log kurtosis         -> k = exp(v(10))
    vars = struct('z0', v(1), ...
                  'z_k', v(2), ...
                  'cd', v(3:4), ...
                  'cs', v(5:6), ...
                  'sigma_k', exp(v(7)), ...
                  'n', v(8:9), ...
                  'k', exp(v(10)));
end


function [fx, J] = eval_e_lsqnonlin_diff2(v, Z, omega, geo, con, par, win)
    % EVAL_E_LSQNONLIN_DIFF2  Residual and forward-difference Jacobian for a
    % single pixel.
    %   fx is the residual from eval_e_lsqnonlin at v. Column c of J is
    %   (residual(v + ep*e_c) - fx) / ep for c = 1..10.
    %   The albedos (1-6) are outer multipliers of the model spectra, so
    %   their columns reuse the unperturbed spectra fk0. Glossiness, normals
    %   and kurtosis (7-10) re-evaluate the model.
    ep = 0.000001;  % Finite difference epsilon
    [fx, fk0] = eval_e_lsqnonlin(v,Z,omega,geo,con,par,win);

    J = zeros(numel(fx), par.n_var);

    for col = 1:10
        shifted = v;
        shifted(col) = shifted(col) + ep;
        pv = extract_vars(shifted);

        if col <= 6
            spectra = fk0;
        else
            spectra = eval_model(pv,omega,geo,con,par,win);
        end

        % Apply albedos outside, interleave colors, compare with data
        pred = spectra(:,1) * [pv.z0 pv.cd'] + ...
               spectra(:,2) * [pv.z_k pv.cs'];
        diffs = Z - reshape(pred.',[],1);

        J(:,col) = (par.lh_mul*([real(diffs); imag(diffs(4:end));]) - fx) / ep;
    end
end


function [f,fk0] = eval_e_lsqnonlin(v, Z, omega, geo, con, par,win)
    % EVAL_E_LSQNONLIN  Weighted real residual for one pixel.
    %   fk0 is the raw model output (diffuse spectrum, specular spectrum),
    %   with no albedo or color applied. The prediction multiplies these by
    %   the per-channel albedos and interleaves the channels per frequency.
    %   f = lh_mul * [real(Z - pred); imag of all but the first 3 entries].
    %   The first 3 entries are the per-channel DC terms.
    params = extract_vars(v);

    fk0 = eval_model(params,omega,geo,con,par,win);

    diffuseRGB = [params.z0 params.cd'];
    specularRGB = [params.z_k params.cs'];
    pred = fk0(:,1) * diffuseRGB + fk0(:,2) * specularRGB;

    resid = Z - reshape(pred.',[],1);

    f = par.lh_mul*([real(resid); imag(resid(4:end));]);
end


% Evaluate the model itself at desired frequencies, with given parameters.
% The notation here is really outdated and in general the structure is very
% messy, but it boils down to what is said in the paper.
%
% We currently also include a variable s that describes a uniform scale 
% between the screen and world coordinates (so X = s*Tx). This was omitted
% in the paper as it is unnecessary (it can be made to be 1 by adjusting
% other transformations accordingly), but removing it from the code would 
% take some careful effort.
function f = eval_model(vars, omega, geo, con, par, win)
    % EVAL_MODEL  Evaluate the reflectance model spectrum at the 2D
    % frequencies in omega (2 x m).
    %
    %   vars - struct with sigma_k (glossiness), n (2x1 normal slope),
    %          k (kurtosis; offset by +1 internally).
    %   con  - per-point geometric constants from opt_geometric.
    %   win  - window mixture (C_w, mu_w, z_w).
    %
    %   Returns an m x 2 matrix. Column 1 is the diffuse spectrum and column
    %   2 the specular spectrum, neither multiplied by albedo. They are kept
    %   separate so derivatives w.r.t. the albedos are cheap.
    %
    % The notation here is outdated, but it boils down to what is said in
    % the paper. A uniform screen/world scale geo.s is also included here
    % (omitted in the paper).
    IM = sqrt(-1);

    n_win = numel(win.z_w);
    n_geo = numel(con.z_l);
    L = n_win*n_geo;

    % Multiply the window by the geometric terms
    % (could be moved outside, there used to be a reason for this).
    % Preallocate for all L products so the arrays do not grow in the loop.
    C_wl = zeros(2,2,L);
    mu_wl = zeros(2,L);
    z_wl = zeros(L,1);
    for w = 1:n_win
        for l = 1:n_geo
            wl = (w-1)*n_geo + l;

            % Gaussian multiplication spelled out in full
            C_wl(:,:,wl) = inv(inv(win.C_w(:,:,w)) + inv(con.C_l(:,:,l)));
            mu_wl(:,wl) = C_wl(:,:,wl) * ((win.C_w(:,:,w) \ win.mu_w(:,w)) ...
                            + (con.C_l(:,:,l) \ con.mu_l(:,l)));
            z_wl(wl) = win.z_w(w) * con.z_l(l) * ...
                  (1/sqrt(det(2*pi*(win.C_w(:,:,w)+con.C_l(:,:,l))))) * ...
                exp(-0.5*(win.mu_w(:,w)-con.mu_l(:,l))'* ...
                ((win.C_w(:,:,w)+con.C_l(:,:,l)) ...
                \(win.mu_w(:,w)-con.mu_l(:,l))));

        end
    end


    nn = [vars.n;1] / norm([vars.n;1]);   % normal on unit sphere

    % The affine term multipliers
    a_a = (nn.' * geo.R).';
    a_b = nn.' * (geo.t-con.p);

    e = geo.E-con.p;    % point-to-eye unit vector
    e = e/norm(e);


    % Reflection of e about nn, plus an orthonormal frame around it
    R3 = 2*nn*nn'*e - e;
    R1 = cross(R3,nn);
    R1 = R1 / norm(R1);
    R2 = -cross(R1,R3); % minus for left handed...
    Rt = [R1 R2];

    % Center and Jacobian of the ray hit on the screen
    c = -R3'*con.n_s;
    x0 = 1/geo.s * geo.Rt'*(con.d/c * R3 + con.p - geo.t);
    J = con.d/(geo.s*c) * geo.Rt'*(eye(3)-(R3*con.n_s')/(R3'*con.n_s))*Rt;

    % Kurtosis handling: two equal-weight lobes centred at x0, scaled by
    % 1/k and k respectively.
    mu_k = [x0 x0];
    C_k = zeros(2,2,2);

    vars.k = vars.k + 1;

    z_k = [0.5 0.5];
    corr2 = (2/(vars.k^(-2)+vars.k^2));
    C_k(:,:,1) = corr2*(1/vars.k*vars.sigma_k)^2*J*diag([2*e'*nn;2].^2)*J';
    C_k(:,:,2) = corr2*(vars.k*vars.sigma_k)^2*J*diag([2*e'*nn;2].^2)*J';

    % Compute the covariances and the means of all pairs of lobe and
    % geometric Gaussian multiplications
    nv = numel(con.z_v);
    C_kv = zeros(2,2,2,nv);
    mu_kv = zeros(2,2,nv);
    z_kv = zeros(2,nv);
    for k = 1:2
        for v = 1:nv
            C_kv(:,:,k,v) = inv(inv(C_k(:,:,k)) + inv(con.C_v(:,:,v)));
            mu_kv(:,k,v) = C_kv(:,:,k,v) * ((C_k(:,:,k) \ mu_k(:,k)) + ...
                (con.C_v(:,:,v) \ con.mu_v(:,v)));
            z_kv(k,v) = z_k(k) * con.z_v(v);
        end
    end

    % Evaluate the Gaussians and affine terms at desired frequencies
    omega_n = size(omega,2);
    A_l = zeros(omega_n,L);
    G_l = zeros(omega_n,L);
    for l = 1:L
        G_l(:,l) = z_wl(l) * exp(-IM*mu_wl(:,l).' * omega - ...
                    0.5*qfbulk(C_wl(:,:,l),omega));
        A_l(:,l) = (-IM*C_wl(:,:,l)*a_a).' * omega + ...
                    (mu_wl(:,l).' * a_a + a_b);
    end
    fd = sum(G_l.*A_l, 2) / pi;

    % Evaluate speculars. G_kv is indexed (frequency, lobe k, window v).
    % It was previously preallocated as (omega_n,nv,2), i.e. with the last
    % two dims swapped relative to this indexing.
    G_kv = zeros(omega_n,2,nv);
    for v = 1:nv
       for k = 1:2
            G_kv(:,k,v) = z_kv(k,v) * ...
                (1/sqrt(det(2*pi*(C_k(:,:,k)+con.C_v(:,:,v))))) * ...
                exp(-0.5*(mu_k(:,k)-con.mu_v(:,v))'* ...
                ((C_k(:,:,k)+con.C_v(:,:,v))\(mu_k(:,k)-con.mu_v(:,v)))) * ...
                exp(-IM*mu_kv(:,k,v).' * omega - 0.5*qfbulk(C_kv(:,:,k,v),omega));

       end

    end

    % Sum the specular Gaussian evaluations
    fs = sum(sum(G_kv,2),3);

    % Output m x 2 matrix: [diffuse specular]. The final value is
    % a_d * fd + a_s * fs, where the a's are albedos.
    f = [fd fs];

end


function [img, aimg] = fresnel_correction(Data, sol, z)
    % FRESNEL_CORRECTION  Per-pixel Schlick-style Fresnel factor.
    %
    %   sol  - solution array (rows x cols x >=9); channels 8,9 are the
    %          normal slopes, so the normal is [sol(i,j,8); sol(i,j,9); 1].
    %   z    - reflectance at normal incidence (F0).
    %
    %   aimg(i,j) = cos of the angle between the unit normal and the
    %               point-to-camera direction.
    %   img       = z + (1-z) * (1-aimg).^5
    %
    %   Pixel -> floor mapping matches the optimizer: pixel centres are
    %   normalized per axis and mapped through Data.T_img_to_floor.
    [nrows, ncols] = deal(size(sol,1), size(sol,2));

    % Camera centre in world coordinates; identical for every pixel, so it
    % is computed once outside the loops.
    E = -Data.cam.R' * Data.cam.T;

    img = zeros(nrows, ncols);
    for i = 1:nrows
        for j = 1:ncols
            % Get the floor coordinates
            ix = ([i;j]-0.5)./([nrows;ncols]-1);
            q = Data.T_img_to_floor * [ix(2);ix(1);1];
            p = [q(1:2)/q(3);0];

            e = E - p;
            e = e/norm(e);

            nn = [sol(i,j,8);sol(i,j,9);1];
            n = nn / norm(nn);

            img(i,j) = n'*e;
        end
    end
    aimg = img;
    img = z + (1-z) * (1-img).^5;
end

%% UTILITIES

% Evaluates a quadratic form at a bunch of points
function q = qfbulk(A,x)
    % QFBULK  Column-wise quadratic forms: q(k) = x(:,k).' * A * x(:,k).
    %   Returns a 1 x size(x,2) row vector (non-conjugating transpose).
    Ax = A*x;
    q = sum(Ax .* x, 1);
end

% Derivative matrices
function [Dx,Dy] = D_mats(n)
    % D_MATS  Sparse backward-difference operators on an n-by-n grid.
    %
    %   Pixels use column-major linear indices k = sub2ind([n n], i, j).
    %   Dx: row k computes f(i,j) - f(i-1,j) for i >= 2 (rows with i == 1
    %       are empty).
    %   Dy: row k computes f(i,j) - f(i,j-1) for j >= 2 (rows with j == 1
    %       are empty).
    %   Both are n^2-by-n^2 sparse matrices.
    %
    %   Built directly from index vectors. The previous loop grew its
    %   triplet arrays past their n^2-row preallocation (2n(n-1) rows are
    %   needed) and failed for n == 1.
    N = n^2;
    idx = reshape(1:N, n, n);   % idx(i,j) == sub2ind([n n], i, j)

    % Differences along the first index
    cur = idx(2:n, :);
    prev = idx(1:n-1, :);
    m = numel(cur);
    Dx = sparse([cur(:); cur(:)], [prev(:); cur(:)], ...
                [-ones(m,1); ones(m,1)], N, N);

    % Differences along the second index
    cur = idx(:, 2:n);
    prev = idx(:, 1:n-1);
    m = numel(cur);
    Dy = sparse([cur(:); cur(:)], [prev(:); cur(:)], ...
                [-ones(m,1); ones(m,1)], N, N);
end


%% VISUALIZTION

function F = part_ifft2(freqs, S, n)
    % PART_IFFT2  Synthesize a 1D signal on n points spanning [-pi, pi].
    %   S(1) is the DC term and S(k+1) the coefficient at frequency freqs(k).
    %   Each coefficient is added together with its conjugate at the
    %   negative frequency, so the result is real (up to rounding) whenever
    %   S(1) is real. Returns an n x 1 vector.
    theta = linspace(-pi,pi,n)';
    F = zeros(n,1) + S(1);
    for k = 1:numel(freqs)
        coef = S(k+1);
        F = F + exp(sqrt(-1)*freqs(k)*theta) * coef ...
              + exp(-sqrt(-1)*freqs(k)*theta) * conj(coef);
    end
end

function drawsol(img)
    % DRAWSOL  Display the current optimizer state in a 2x4 subplot grid.
    %   img is indexed (variable, row, col), as in the optimizer:
    %     1 diffuse RGB (1,3,4)   2 specular RGB (2,5,6)   3 kurtosis (10)
    %     4 normal map (8,9)      5 glossiness (7)         7/8 normal slopes
    %   imagec, imagepn and image_nmap are helpers defined elsewhere.
    layer = @(k) squeeze(img(k,:,:));
    clamped = @(k) real(min(10,max(-10,layer(k))));

    subplot(2,4,1);
    imagec(0.8*cat(3, layer(1), layer(3), layer(4)))
    subplot(2,4,2);
    imagec(0.8*cat(3, layer(2), layer(5), layer(6)))
    subplot(2,4,3);
    image(sqrt(exp(layer(10)))*32)
    subplot(2,4,4);
    subplot(2,4,7);
    imagepn(layer(8))
    subplot(2,4,8);
    imagepn(layer(9))
    subplot(2,4,5);
    image(sqrt(exp(layer(7)))*100)
    subplot(2,4,4);

    image_nmap(clamped(8),clamped(9),2);

    colormap gray
    drawnow
end

