# This script analyses data from the single-atom four-wave mixing experiment performed in 2015-2016.
# Inputs are the five transmission histograms:
# 
# non_reversed_with_alternating_09_26_onwards_combined_histo_tx_1ns
# 2015_11_11-11_17_non_reversed_combined_histo_tx_1ns
# 2015_10_30-11_03_non_reversed_combined_histo_tx_1ns
# 2015_11_03-11_06_non_reversed_combined_histo_tx_1ns
# 2015_11_06-11_11_non_reversed_combined_histo_tx_1ns
#

import numpy as np 
import matplotlib.pyplot as plt
from lmfit import minimize, Parameters, fit_report
from scipy.integrate import ode

lifetime = 26.2348  # in ns
gamma_0 = 1 / (lifetime*1e-9)  # atom decay rate (1/s)
overlap = 0.033  # spatial overlap factor; 1 = 100% overlap with dipole pattern, 0 = no overlap
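# For reference: with the values above, gamma_0 is about 3.81e7 s^-1, i.e. about
# 0.0381 ns^-1; it is the latter value (gamma_0*1e-9) that is handed to the
# ODE integration further below.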

timebin = 1 # 1ns. time bin size for transmission histogram
t_offset_for_fit = 2 # start decay time fit a few bins from max

tau = np.zeros(5)
tau_err = np.zeros(5)
tx = np.zeros(5)
tx_err = np.zeros(5)
eta = np.zeros(5)
eta_err = np.zeros(5)
pe_max =  np.zeros(5)
pe_max_err = np.zeros(5)
pe_max_time = np.zeros(5)

# interpolate the measured / bootstrap-resampled extinction signal delta(t);
# time_p_e_range, delta_pe and dummy_delta_pe are set inside the file loop below
def delta_decay(t):
    return np.interp(t, time_p_e_range, delta_pe)

def dummy_delta_decay(t):
    return np.interp(t, time_p_e_range, dummy_delta_pe)

def my_int(f, time_v, dt, gamma_0, overlap):
    """
    Routine to integrate the atom interaction
    differential equation from the extinction data
    """
    def atom(t, y):
        dydt = f(t) - gamma_0 * (1-overlap) * y
        return dydt
    y = np.zeros(len(time_v))
    y0, t0 = 0, np.min(time_v)
    sol = ode(atom).set_integrator('dopri5')
    sol.set_initial_value(y0, t0)
    for k in range(len(time_v)):
        y[k] = sol.integrate(sol.t + dt)[0]
    return np.roll(y, 1)
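# Optional sanity check for my_int() (illustrative sketch, not called by the
# analysis): for an exponential drive f(t) = A*exp(-k*t) starting at t = 0 and
# zero overlap, dy/dt = f(t) - gamma*y with y(0) = 0 has the closed-form
# solution A/(gamma-k) * (exp(-k*t) - exp(-gamma*t)). All names below are
# local to this check and not used elsewhere in the script.
def _check_my_int(A=1e-3, k=0.1, gamma=0.04, dt=1.0):
    t = np.arange(0.0, 150.0, dt)            # ns grid, similar to the data below
    drive = lambda s: A * np.exp(-k * s)
    num = my_int(drive, t, dt, gamma, 0.0)   # overlap = 0 -> bare decay rate gamma
    ana = A / (gamma - k) * (np.exp(-k * t) - np.exp(-gamma * t))
    # index 0 is wrapped around by the np.roll() inside my_int, so skip it;
    # the result should be small (limited by the integrator tolerance)
    return np.max(np.abs(num[1:] - ana[1:]))
# print('my_int deviation from analytic test case:', _check_my_int())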

def pemax(tau_p, tau_0):
    # peak excited-state population for an exponentially decaying photon (time constant tau_p) and atomic lifetime tau_0
    return 4*overlap*(tau_0/tau_p)**((tau_0+tau_p)/(tau_p-tau_0))

def Theta(x):  # Heaviside step function
    return 1 * (x >= 0)

def p_decay(t,tau_p):
    gamma_p = 1/(tau_p*1e-9)
    return Theta(t)*4*overlap*gamma_0*gamma_p / (gamma_p-gamma_0)**2 * ( np.exp(-gamma_0/2*t) - np.exp(-gamma_p/2*t) )**2 
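# Optional consistency check (illustrative, not called by the analysis): the
# maximum of p_decay(t, tau_p) coincides with pemax(tau_p, lifetime) -- it is the
# same analytic expression evaluated in two ways -- and in the matched-bandwidth
# limit tau_p -> lifetime both approach 4*overlap*exp(-2), about 0.018 for overlap = 0.033.
def _check_pemax(tau_p=13.0):
    t = np.arange(0.0, 500.0, 0.01) * 1e-9             # fine time grid in seconds
    peak_numeric = np.max(p_decay(t, tau_p))            # peak of the analytic P_e(t)
    peak_formula = pemax(tau_p, lifetime)               # closed-form peak value
    matched_limit = pemax(lifetime * 1.001, lifetime)   # close to 4*overlap*exp(-2)
    return peak_numeric, peak_formula, matched_limit
# print(_check_pemax())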

### loop through the five files
i = 0
for filename in ['non_reversed_with_alternating_09_26_onwards_combined_histo_tx_1ns',
                 '2015_11_11-11_17_non_reversed_combined_histo_tx_1ns',
                 '2015_10_30-11_03_non_reversed_combined_histo_tx_1ns',
                 '2015_11_03-11_06_non_reversed_combined_histo_tx_1ns',
                 '2015_11_06-11_11_non_reversed_combined_histo_tx_1ns']:
    data_fromfile = np.genfromtxt(filename)
    trigcounts = data_fromfile[0][2]  # total number of triggers, from the file header
    trigbgcounts = data_fromfile[0][6]  # total number of triggers in the background (no atom) phase, from the file header
    data_fromfile = np.genfromtxt(filename, skip_header=2)
        
    # time bins:
    t_data = data_fromfile[:,1]
    # coincidences per trigger, atom in trap:
    with_atom = data_fromfile[:,4]
    # statistical uncertainty (square-root-N counting error):
    err_with_atom = data_fromfile[:,5]
    # accidental events: two time intervals to determine the background, total length 300 ns.
    # The accidental intervals were chosen to be far away from the photon but within the AOM-open interval.
    acc_window_start_1 = 750
    acc_window_end_1 = 800
    acc_window_start_2 = 1024
    acc_window_end_2 = 1274
    acc_window_number_of_bins = 300
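    # illustrative consistency check: the two accidental windows above together
    # span acc_window_number_of_bins time bins
    assert ((acc_window_end_1 - acc_window_start_1)
            + (acc_window_end_2 - acc_window_start_2)) // timebin == acc_window_number_of_bins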
    with_atom_bg = np.mean( np.concatenate( (with_atom[acc_window_start_1//timebin:acc_window_end_1//timebin], with_atom[acc_window_start_2//timebin:acc_window_end_2//timebin]) ) )
    # coincidences per trigger, no atom in trap:
    without_atom = data_fromfile[:,8]
    # statistical uncertainty:
    err_without_atom = data_fromfile[:,9]
    # accidental events
    without_atom_bg = np.mean( np.concatenate( (without_atom[acc_window_start_1//timebin:acc_window_end_1//timebin], without_atom[acc_window_start_2//timebin:acc_window_end_2//timebin]) ) )
    # squared statistical uncertainties of the accidental bins (summed later for the eta_f error)
    without_atom_bg_err = (np.concatenate( (err_without_atom[acc_window_start_1//timebin:acc_window_end_1//timebin], err_without_atom[acc_window_start_2//timebin:acc_window_end_2//timebin]) ))**2
    
    ## the decaying pulse starts at the maximum of the no-atom (reference) histogram
    t_start = without_atom.argmax()
        
    # fit the no-atom (reference) histogram to get the photon decay time
    # define fit parameters
    params = Parameters()
    params.add('amp', value=without_atom[without_atom.argmax()], vary=True, min=0.0)  # amplitude
    params.add('y0', value=np.mean(without_atom), vary=True, min=0.0)  # offset
    params.add('tau', value=9, vary=True, min=0.0)  # time constant (ns)
    params.add('t0', value=t_start, vary=False)  # start point of the exponential decay
    
    # chop interesting part of trace
    fit_range = 100
    t0_index = without_atom.argmax() + t_offset_for_fit
    data_ROI = without_atom[t0_index:t0_index+fit_range]
    data_error_ROI = err_without_atom[t0_index:t0_index+fit_range]
    t_data_ROI = t_data[t0_index:t0_index+fit_range]
    #
    # fit function and residual function (p is the Parameters object passed in by lmfit)
    def fit_function_decay_time(p, x):
        return p['amp'].value * np.exp( (p['t0'].value - x)/p['tau'].value ) + p['y0'].value

    def fit_residual_decay_time(p, x, y):
        return fit_function_decay_time(p, x) - y

    # fit
    fitout = minimize(fit_residual_decay_time,params,args=(t_data_ROI, data_ROI ))
    
    # range of summing for eta_f and tx
    summing_range_short = 10 #*round(params['tau'].value)  
    summing_range_long = 100#8*round(params['tau'].value)  # sum from -1 tau to +8 tau for extinction
    sumrange_min = t_start - summing_range_short
    sumrange_max = t_start + summing_range_long
    sumrange_number_of_bins = sumrange_max-sumrange_min
    
    window_ratio = (summing_range_long+summing_range_short)/acc_window_number_of_bins
    #heralding efficiency eta_f: 
    eta_f = sum( without_atom[sumrange_min:sumrange_max] - without_atom_bg )
    eta_f_error = np.sqrt( sum( (err_without_atom[sumrange_min:sumrange_max])**2 ) + window_ratio**2*sum(without_atom_bg_err) )
    
    delta = ( (without_atom-without_atom_bg) - (with_atom-with_atom_bg) ) / ( eta_f * timebin*1e-9 )
    err_delta = np.sqrt( err_with_atom**2 + err_without_atom**2 ) / ( eta_f * timebin*1e-9 )
    extinction = sum(delta[sumrange_min:sumrange_max]) *timebin*1e-9
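    # Note: delta(t) is the background-corrected difference of the two histograms,
    # normalised by the heralding efficiency eta_f and the bin width, i.e. an
    # extinction rate (1/s) per heralded photon; summing it over the signal window
    # and multiplying by the bin width gives the dimensionless extinction above.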
    
    ## error on the extinction, propagated from the raw (Poissonian) counts
    counts_with_atom = sum( data_fromfile[sumrange_min:sumrange_max,2] )
    counts_without_atom = sum( data_fromfile[sumrange_min:sumrange_max,6] )
    acc_with_atom = sum( np.concatenate( (data_fromfile[acc_window_start_1//timebin:acc_window_end_1//timebin,2], data_fromfile[acc_window_start_2//timebin:acc_window_end_2//timebin,2]) ) )
    acc_without_atom = sum( np.concatenate( (data_fromfile[acc_window_start_1//timebin:acc_window_end_1//timebin,6], data_fromfile[acc_window_start_2//timebin:acc_window_end_2//timebin,6]) ) )
    
    window_scaling = sumrange_number_of_bins/acc_window_number_of_bins
    error_on_tx = (1-extinction)*np.sqrt(
        (counts_with_atom + window_scaling**2*acc_with_atom) / (counts_with_atom - window_scaling**2*acc_with_atom)**2
        + (counts_without_atom + window_scaling**2*acc_without_atom) / (counts_without_atom - window_scaling**2*acc_without_atom)**2
        + 1/trigcounts + 1/trigbgcounts )
    
    error_on_acc_atom = (np.sqrt(acc_with_atom)/acc_window_number_of_bins/trigcounts) 
    error_on_acc_without_atom = (np.sqrt(acc_without_atom)/acc_window_number_of_bins/trigbgcounts)
    

    print()
    print(" -------------------- " + filename + " ----------------")
    print()
    print("time of maximum (ns): ", t_data[t_start]) 
    print("heralding efficiency eta_f, acc corrected: {:.4} +/- {:.2}".format(eta_f,eta_f_error) )
    print("extinction, acc corrected:  {:.4} +/- {:.4}".format(extinction,error_on_tx))
    print("accidentals per trigger per time bin, with atom:  {:.3} +/- {:.2}".format(with_atom_bg,error_on_acc_atom) ) 
    print("accidentals per trigger per time bin, without atom:  {:.3} +/- {:.2}".format(without_atom_bg,error_on_acc_without_atom)) 
    print("Number of trigger with atom : {:.2} ".format(trigcounts))    
    print("Number of trigger without atom : {:.2} ".format(trigbgcounts))  
    print()
    print('fit report:')
    print(fit_report(fitout.params))
    print()

				
    # second part derives P_e from transmission histogram
    print('')
    print('P_e from transmission:')

    # chop the time interval for the ODE; we choose -20 ns to +150 ns around the pulse maximum
    time_p_e_range_min = t_start - 20
    time_p_e_range_max = t_start + 150
    time_p_e_range = t_data[time_p_e_range_min:time_p_e_range_max] - t_start
    # slice the deltas in the interval
    delta_pe = delta[time_p_e_range_min:time_p_e_range_max]*1e-9 # times 1e-9 because ODE is done in ns-timescale 
    err_delta_pe = err_delta[time_p_e_range_min:time_p_e_range_max]*1e-9

    """ Bootstrap evaluation of the error of P_e """
    print('Bootstrap evaluation of the error of P_e. hang on this takes about 10sec')
    reps = 300
    # 
    decays = np.zeros((reps, len(time_p_e_range)))
    dummy_delta_pe = np.zeros( len(time_p_e_range))
    for k in range(reps):
        # produce new data set
        for h in range(0,len(time_p_e_range)):
            dummy_delta_pe[h] = np.random.normal(delta_pe[h],err_delta_pe[h])
        decays[k, :] =  my_int(dummy_delta_decay, time_p_e_range, timebin, gamma_0*1e-9, overlap)
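    # Note (illustrative): the per-bin resampling loop above could equivalently be
    # written with numpy broadcasting, e.g.
    #     dummy_delta_pe[:] = np.random.normal(delta_pe, err_delta_pe)
    # which draws one full bootstrap replica in a single call; the explicit loop is
    # kept to preserve the original structure.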
    # get averaged curves and errors    
    decay = np.mean(decays, 0)
    decay_err = np.std(decays, 0)

    print('decaying photon, P_e,max : ' + '%.4f ( %.4f )'  % (np.max(decay), decay_err[decay.argmax()] ) )
        
    """ Plotting and results """
    plt.figure()
    plt.errorbar(time_p_e_range, decay, yerr=decay_err, fmt= 'o',color='red')
    plt.xlabel('time since herald (ns)')
    plt.ylabel('P_e ')
    plt.title('excited state population P_e from transmission')
    plt.xlim([-20,150])
				
    t_theory = np.arange(-200,200,0.1)*1e-9 # time axis for theory curves
    plt.plot(t_theory*1e9-2, p_decay(t_theory, fitout.params['tau'].value), color='blue')
    
    with open('ODE_'+str(int(fitout.params['tau'].value))+'.dat', 'w') as f:
        f.write('#time\tPe\tPe_err\n')
        for t, P, P_err in zip(time_p_e_range, decay, decay_err):
            f.write('{:.1f}\t{:.5e}\t{:.5e}\n'.format(t, P, P_err))

    # store extinction, fitted tau, heralding efficiency and P_e,max for the summary
    tx[i] = extinction
    tx_err[i] = error_on_tx
    tau[i] = fitout.params['tau'].value
    tau_err[i] = fitout.params['tau'].stderr
    eta[i] = eta_f
    eta_err[i] = eta_f_error
    pe_max[i] = np.max(decay)
    pe_max_err[i] = decay_err[decay.argmax()]
    pe_max_time[i] = time_p_e_range[np.argmax(decay)]
    i = i + 1

## end of file-loop

print()
print('summary:')
for i in range(0,5):
    print("tau: {:.3} +/- {:.2},   tx: {:.3} +/- {:.2},  eta_f:  {:.3} +/- {:.2}".format(tau[i],tau_err[i],tx[i],tx_err[i],eta[i],eta_err[i]))    
  
plt.figure()
plt.xlim([0,7])
plt.ylim([0,0.12])
plt.errorbar(lifetime/tau,tx,yerr=tx_err, fmt='ro', label='data')

tau_p=np.arange(1,300,0.2)
theory = overlap*(1-overlap) * 4 *tau_p/ (tau_p+lifetime)
#plt.plot(lifetime/tau_p,theory, label='Lambda = 0.033')
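# For orientation (illustrative): at matched bandwidth, tau_p = lifetime, the model
# above reduces to 4*overlap*(1-overlap)/2 = 2*overlap*(1-overlap), about 0.064 for
# overlap = 0.033.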

# define fit parameters
params_tx = Parameters()
params_tx.add('overlap_fit', value=0.033, vary=True, min=0.0)  # spatial overlap parameter

# fit functions and residual function
def fit_function_tx(p, x):
    return p['overlap_fit'].value * (1-p['overlap_fit'].value) * 4*x / (x+lifetime)

def fit_function_pemax(p, x):
    return 4*p['overlap_fit'].value * (lifetime/x)**((lifetime+x)/(x-lifetime))

def fit_residual_tx(p, x, y):
    resid = 0.0*y[:]
    resid[0, :] = fit_function_tx(p, x) - y[0, :]
    resid[1, :] = fit_function_pemax(p, x) - y[1, :]
    return resid.flatten()
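# Note: the two model curves (extinction and P_e,max) share the single parameter
# 'overlap_fit'; stacking their residuals into one array and flattening it lets
# lmfit minimise both data sets simultaneously with a common overlap.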
# fit
data_for_fit = np.array([tx,pe_max])
fitout_tx = minimize(fit_residual_tx,params_tx,args=(tau, data_for_fit ))
# get fitted curve
fitarray_tx = fitout_tx.params['overlap_fit'].value * (1-fitout_tx.params['overlap_fit'].value) * 4*tau_p / (tau_p+lifetime)
plt.plot(lifetime/tau_p, fitarray_tx, label="fit, Lambda = {:.3}".format(fitout_tx.params['overlap_fit'].value))
plt.legend(shadow=True, fancybox=True)
plt.xlabel('bandwidth ratio: Gamma_p / Gamma_0')
plt.ylabel('transmission extinction')
print()
print('overlap fit:')
print(fit_report(fitout_tx.params))
print()

"""
Saving
"""
with open('tx_safwm_summary.dat', 'w') as f:
    f.write('#tau\ttau_err\ttx\ttx_err\tpemax\tpemax_err\teta\teta_err\tpemax_time\n')
    
    for row in zip(tau, tau_err, tx, tx_err, pe_max, pe_max_err, eta, eta_err, pe_max_time):
        f.write('\t'.join('{:.5e}'.format(v) for v in row) + '\n')

plt.figure()
plt.xlim([0,7])
plt.ylim([0,0.02])
plt.errorbar(lifetime/tau,pe_max,yerr=pe_max_err, fmt='ro', label='data')
overlap = fitout_tx.params['overlap_fit'].value  # use the fitted overlap for the theory curve below
plt.plot(lifetime/tau_p,pemax(tau_p,lifetime)  )
plt.legend(shadow=True, fancybox=True)
plt.xlabel('bandwidth ratio: Gamma_p / Gamma_0')
plt.ylabel('P_e max')
plt.show()
