% Iterative velocity-model update script.
%
% For each iteration the script:
%   1. models every shot with the current velocity v1, forms the data
%      residual against the recorded files, and accumulates the misfit;
%   2. builds a gradient per shot with a2d_rtm_abc24, stacks over shots,
%      normalises by the stacked illumination and applies sm_mark0;
%   3. derives a search direction with conjugate_direction;
%   4. runs a line search in slowness (1./v): two trial steps (f1, f2) plus
%      the minimum of the parabola through (0,res0), (f1,res1), (f2,res2);
%   5. keeps the best model, clips it to [vmin, vmax] and writes it out.
%
% Inputs  : ./model/<model_dir>/{v1,acquisition,vel_mark}.mat
%           ./data/<model_dir>/csg_refl_<is>.bin, csg_dir_<is>.bin
% Outputs : ./results/<model_dir>/gradient_<it>.bin, vel_<it>.bin
%           workspace variable `misfit` (nit+1 entries)
%
% NOTE(review): nz, nx, ns, ng, nt, dt, dx, nbc, s, sx, sz, gx, gz, isFS,
% vel0 and mark are not defined in this file; presumably they come from the
% loaded .mat files - confirm.
clear all;close all;clc;
model_dir='marmousi2_151_401_10m';

% load the initial velocity model
load(['./model/',model_dir,'/v1.mat']);
% load the acquisition and marked zone
load(['./model/',model_dir,'/acquisition.mat']);
load(['./model/',model_dir,'/vel_mark.mat']);

% Define the iteration number
nit=10;

% misfit(it) holds the misfit at the start of iteration it; the last entry
% holds the misfit after the final update.
v1=vel0;misfit=zeros(nit+1,1);
% pur scales the trial step relative to the slowness norm; vmin/vmax are
% the bounds every updated velocity model is clipped to.
pur=0.05;vmin=1500;vmax=3550;

parallel_init;
% Previous gradient and previous search direction (zero before iteration 1).
gk=zeros(nz,nx);dk=zeros(nz,nx);

%Iteration Starts
for it=1:nit
    tic;
    display(['Iteration ',num2str(it)]);
    gk1=zeros(nz,nx,ns);illum=zeros(nz,nx,ns);
    res0=zeros(ns,1);
    % Calculate the new misfit functional and the gradient
    parfor is=1:ns
        [seis,bc_top,bc_bottom,bc_left,bc_right,bc_p_nt,bc_p_nt_1]...
            =a2d_mod_abc24(v1,nbc,dx,nt,dt,s,sx(is),sz(is),gx,gz,isFS);
        csg_true=read_bin(['./data/',model_dir,'/csg_refl_',num2str(is),'.bin'],nt,ng);
        csg_dir=read_bin(['./data/',model_dir,'/csg_dir_',num2str(is),'.bin'],nt,ng);
        data_res=seis-csg_dir-csg_true;
        res0(is)=sum(data_res(:).*data_res(:));
        % Pass only this shot's source position: the saved boundary arrays
        % belong to shot `is`, matching the forward call above.
        [gk1(:,:,is),illum(:,:,is)]=a2d_rtm_abc24(data_res,v1,nbc,dx,nt,dt,s,sx(is),sz(is),gx,gz,...
            bc_top,bc_bottom,bc_left,bc_right,bc_p_nt,bc_p_nt_1);
    end
    res0=sum(res0);misfit(it)=res0;
    display(['misfit = ',num2str(res0)]);
    % Stack the per-shot gradients and illumination, then normalise.
    gk1=sum(gk1,3);
    illum=sum(illum,3);
    gk1=gk1./illum;
    gk1=sm_mark0(gk1,mark);
    write_bin(['./results/',model_dir,'/gradient_',num2str(it),'.bin'],gk1);
    % Calculate the conjugate direction
    dk1=conjugate_direction(it,gk1,gk,dk);
    
    % Setup the trial step length: pur times the ratio of the L2 norm of
    % the slowness model to the L2 norm of the search direction.
    ss=1.0./v1;
    s_mean=(sum(ss(:).*ss(:)))^0.5;
    g_mean=(sum(dk1(:).*dk1(:)))^0.5;
    alpha=s_mean/g_mean*pur;
    display(['s_mean=',num2str(s_mean),' g_mean=',num2str(g_mean),' alpha=',num2str(alpha)]);
    
    % Back tracking to find the numerical step length
    % First trial: half of the nominal step.
    f1=0.5;
    ss1=ss+alpha*f1*dk1;
    v1=1./ss1;
    v1(v1<vmin)=vmin;v1(v1>vmax)=vmax;
    res1=zeros(ns,1);
    parfor is=1:ns
        seis=a2d_mod_abc24(v1,nbc,dx,nt,dt,s,sx(is),sz(is),gx,gz,isFS);
        csg_true=read_bin(['./data/',model_dir,'/csg_refl_',num2str(is),'.bin'],nt,ng);
        csg_dir=read_bin(['./data/',model_dir,'/csg_dir_',num2str(is),'.bin'],nt,ng);
        data_res=seis-csg_dir-csg_true;
        res1(is)=sum(data_res(:).*data_res(:));
    end
    res1=sum(res1);
    display(['f1= ',num2str(f1),' res1= ',num2str(res1)]);
    if res1>res0
        % The first trial increased the misfit: keep halving f1 until the
        % misfit no longer exceeds res0 or f1 drops to 0.0001. The previous
        % (larger) trial is remembered as (f2,res2). The loop body always
        % runs at least once here, so f2/res2 are defined afterwards.
        while res1>res0 && f1>0.0001
            f2=f1; res2=res1;
            f1=f1*0.5;
            ss1=ss+alpha*f1*dk1;
            v1=1./ss1;
            v1(v1<vmin)=vmin;v1(v1>vmax)=vmax;
            res1=zeros(ns,1);
            parfor is=1:ns
                seis=a2d_mod_abc24(v1,nbc,dx,nt,dt,s,sx(is),sz(is),gx,gz,isFS);
                csg_true=read_bin(['./data/',model_dir,'/csg_refl_',num2str(is),'.bin'],nt,ng);
                csg_dir=read_bin(['./data/',model_dir,'/csg_dir_',num2str(is),'.bin'],nt,ng);
                data_res=seis-csg_dir-csg_true;
                res1(is)=sum(data_res(:).*data_res(:));
            end
            res1=sum(res1);
            display(['f1= ',num2str(f1),' res1= ',num2str(res1)]);
        end
    else
        % The first trial did not increase the misfit: probe twice as far.
        f2=f1*2;
        ss1=ss+alpha*f2*dk1;
        v1=1./ss1;
        v1(v1<vmin)=vmin;v1(v1>vmax)=vmax;
        res2=zeros(ns,1);
        parfor is=1:ns
            seis=a2d_mod_abc24(v1,nbc,dx,nt,dt,s,sx(is),sz(is),gx,gz,isFS);
            csg_true=read_bin(['./data/',model_dir,'/csg_refl_',num2str(is),'.bin'],nt,ng);
            csg_dir=read_bin(['./data/',model_dir,'/csg_dir_',num2str(is),'.bin'],nt,ng);
            data_res=seis-csg_dir-csg_true;
            res2(is)=sum(data_res(:).*data_res(:));
        end
        res2=sum(res2);
        display(['f2= ',num2str(f2),' res2= ',num2str(res2)]);
    end
    % Stationary point of the parabola through (0,res0), (f1,res1), (f2,res2).
    gama=(f1^2*(res0-res2)+f2^2*(res1-res0))/(2*res0*(f1-f2)+2*res1*f2-2*res2*f1);
    display(['gama= ',num2str(gama),' numerical step_length= ',num2str(gama*alpha)]);
    if isfinite(gama)
        ss1=ss+alpha*gama*dk1;
        v1=1./ss1;
        v1(v1<vmin)=vmin;v1(v1>vmax)=vmax;
        res3=zeros(ns,1);
        parfor is=1:ns
            seis=a2d_mod_abc24(v1,nbc,dx,nt,dt,s,sx(is),sz(is),gx,gz,isFS);
            csg_true=read_bin(['./data/',model_dir,'/csg_refl_',num2str(is),'.bin'],nt,ng);
            csg_dir=read_bin(['./data/',model_dir,'/csg_dir_',num2str(is),'.bin'],nt,ng);
            data_res=seis-csg_dir-csg_true;
            res3(is)=sum(data_res(:).*data_res(:));
        end
        res3=sum(res3);
    else
        % Degenerate fit (zero denominator): a NaN/Inf step would poison
        % the model, so skip the trial and force the fallback below.
        res3=Inf;
    end
    display(['res3= ',num2str(res3)]);
    if (res3>res1 || res3>res2)
        % The fitted step is worse than one of the trial steps: fall back
        % to whichever trial step gave the smaller misfit.
        if res1>res2
            res0=res2;
            lamta=f2;
        else
            res0=res1;
            lamta=f1;
        end
        % Rebuild the model with the selected trial step (not gama), so the
        % saved velocity matches the misfit stored in res0.
        ss1=ss+alpha*lamta*dk1;
        v1=1./ss1;
        v1(v1<vmin)=vmin;v1(v1>vmax)=vmax;
    else
        % v1 already holds the model for the fitted step.
        res0=res3;
    end
    
    % Refresh: current gradient/direction become the "previous" ones.
    gk=gk1;
    dk=dk1;
    write_bin(['./results/',model_dir,'/vel_',num2str(it),'.bin'],v1);
    toc;
end
% Misfit of the final accepted model.
misfit(nit+1)=res0;

parallel_stop;