# this script analyses the data of the single atom-fwm experiment 2015-16 and reproduces the published plots and numbers.
# the analysis is based on four histogram files.
#
# transmission files:
# non_reversed_with_alternating_09_26_onwards_combined_histo_tx_2ns
# reversed_with_alternating_09_26_onwards_combined_histo_tx_2ns
#
# reflection files:
# non_reversed_with_alternating_09_26_onwards_combined_histo_rx_5ns
# reversed_with_alternating_09_26_onwards_combined_histo_rx_5ns_4

import numpy as np 
import matplotlib.pyplot as plt
from scipy.integrate import ode

# Physical model constants (SI units: seconds and 1/s).
lifetime, tau_p = 26.2348e-9, 13.3e-9  # atom lifetime; probe photon decay time
gamma_0, gamma_p = 1 / lifetime, 1 / tau_p  # atom decay rate; probe photon decay rate
# Spatial overlap factor with the dipole pattern: 1 = 100% overlap, 0 = no overlap.
overlap = 0.033

print(" -------------------------- transmission --------------------------")
print()

time_offset = 879 # time offset (ns) used for all transmission data
time_offset_index = 439 # index for time offset used for all transmission data


# first part deals with the heralding efficiency, the transmission delta and the extinction

print(" -------------------------- decaying photon --------------------------")
timebin = 2 # 2ns. time bin size for transmission histogram
data_fromfile = np.genfromtxt('non_reversed_with_alternating_09_26_onwards_combined_histo_tx_2ns')
trigcounts_1 = 135534347  # total number of trigger, obtained from file header
trigbgcounts_1 = 179327264 # total number of trigger in background phase , obtained from file header

# Histogram columns as used below: 1 = time bin (ns); 4/5 = coincidences per
# trigger and their error, atom in trap; 8/9 = the same without atom.
# Columns 2/6 are summed and treated as Poisson counts in the error estimate,
# so they are presumably raw coincidence counts with/without atom -- confirm
# against the file header.

#time bins:
t_data_1 = data_fromfile[:,1]
# coincidences per trigger, atom in trap:
with_atom_1 = data_fromfile[:,4]
# statistical uncertainty (square-root-N counting error):
err_with_atom_1 = data_fromfile[:,5]
# accidental events: two time intervals to determine background, total length 150 bins = 300 ns.
# The accidental intervals were chosen far away from the photon but within the AOM open interval.
acc_window_start_1 = 750
acc_window_end_1 = 800
acc_window_start_2 = 1024
acc_window_end_2 = 1274
acc_window_number_of_bins = 150
# Bin indices covering both accidental windows. Floor division keeps them
# integral: under Python 3 ``750/2`` is a float and NumPy rejects float slice bounds.
acc_bins_1 = np.concatenate((
    np.arange(acc_window_start_1 // timebin, acc_window_end_1 // timebin),
    np.arange(acc_window_start_2 // timebin, acc_window_end_2 // timebin),
))
with_atom_bg_1 = np.mean(with_atom_1[acc_bins_1])
print("decaying photon, accidentals per trigger per time bin, with atom: "+ '%.9f'  %  with_atom_bg_1) 
#-------
#-------
# coincidences per trigger, no atom in trap:
without_atom_1 = data_fromfile[:,8]
# statistical uncertainty:
err_without_atom_1 = data_fromfile[:,9]
# accidental events
without_atom_bg_1 = np.mean(without_atom_1[acc_bins_1])
# squared per-bin errors (variances) inside the accidental windows
without_atom_bg_1_err = err_without_atom_1[acc_bins_1]**2
print("decaying photon, accidentals per trigger per time bin, without atom: "+ '%.9f'  %  without_atom_bg_1) 

## decaying pulse starts at t=881ns
t_start_1 = without_atom_1.argmax()
print("decaying photon, time of maximum (ns): ", t_data_1[t_start_1]) 

# range of summing for eta_f and deltas: sum from -14 ns to +100 ns, total length 57 bins = 114 ns
summing_range_short = 7 # sum 7 bins = 14 ns from t<0 for the decaying photon and vice versa 
summing_range_long = 50 # sum 50 bins = 100 ns from t>0 for the decaying photon and vice versa 
sumrange_min_1 = time_offset_index - summing_range_short
sumrange_max_1 = time_offset_index + summing_range_long
sumrange_number_of_bins = sumrange_max_1-sumrange_min_1

# (summing-window bins) / (accidental-window bins): scale factor from the
# accidental sum to the background expected inside the summing window
window_ratio = (summing_range_long+summing_range_short)/acc_window_number_of_bins
#heralding efficiency eta_f: 
eta_f_1 = sum( without_atom_1[sumrange_min_1:sumrange_max_1] - without_atom_bg_1 )
eta_f_1_error = np.sqrt( sum( (err_without_atom_1[sumrange_min_1:sumrange_max_1])**2 ) + window_ratio**2*sum(without_atom_bg_1_err) )
print("decaying photon, heralding efficiency eta_f, acc corrected: {:.8f} +/- {:.8f}".format(eta_f_1,eta_f_1_error) )

### correcting for collection efficiency (46+/-1??? %) and detector efficiency (51.9+/-.1??? %) <--- now change to 52+/-1
eta_f_1_tilde = eta_f_1/(0.46*0.52)
eta_f_1_tilde_error = eta_f_1_tilde * np.sqrt( (eta_f_1_error/eta_f_1)**2 + (1.0/46.0)**2 + (1.0/52.0)**2 )
print("decaying photon, efficiency corrected for collection/detection eta_f_1_tilde, acc corrected: {:.8f} +/- {:.8f}".format(eta_f_1_tilde,eta_f_1_tilde_error) )


# transmission delta: accidental-corrected (without - with) difference per bin,
# normalised by eta_f and the bin width in seconds (result in 1/s)
delta_1 = ( (without_atom_1-without_atom_bg_1) - (with_atom_1-with_atom_bg_1) ) / ( eta_f_1 * timebin*1e-9 )
err_delta_1 = np.sqrt( err_with_atom_1**2 + err_without_atom_1**2 ) / ( eta_f_1 * timebin*1e-9 )
# plot and print results
plt.errorbar( (t_data_1-time_offset), delta_1*1e-6, yerr=err_delta_1*1e-6 , color='red', fmt='o' )

extinction_1 = sum(delta_1[sumrange_min_1:sumrange_max_1]) *timebin*1e-9
print("decaying photon, extinction, acc corrected: "+ '%.5f'  %  extinction_1)
#plt.show()

## error on the extinction
counts_with_atom_1 = sum( data_fromfile[sumrange_min_1:sumrange_max_1,2] )
counts_without_atom_1 = sum( data_fromfile[sumrange_min_1:sumrange_max_1,6] )
acc_with_atom_1 = sum( data_fromfile[acc_bins_1,2] )
acc_without_atom_1 = sum( data_fromfile[acc_bins_1,6] )

# The transmission is T = 1 - extinction. Each accidental-corrected count is
# N - r*A with r = window_ratio (= sumrange_number_of_bins/acc_window_number_of_bins).
# Its variance is N + r**2*A, so its relative variance is
# (N + r**2*A) / (N - r*A)**2. The denominator is the corrected count itself,
# scaled by r, not r**2. The trigger-count terms 1/trig add in quadrature.
rel_var_with_atom_1 = (counts_with_atom_1 + window_ratio**2*acc_with_atom_1) / (counts_with_atom_1 - window_ratio*acc_with_atom_1)**2
rel_var_without_atom_1 = (counts_without_atom_1 + window_ratio**2*acc_without_atom_1) / (counts_without_atom_1 - window_ratio*acc_without_atom_1)**2
error_on_tx = (1-extinction_1)*np.sqrt( rel_var_with_atom_1 + rel_var_without_atom_1 + 1/trigcounts_1 + 1/trigbgcounts_1 )
print("decaying photon, standard deviation of extinction: " + '%.5f'  % error_on_tx)

#print("decaying photon, accidentals per bin, with atom: " + '%.1f'  % (acc_with_atom_1/150) )
print("decaying photon, standard deviation of accidentals per bin per trigger, with atom: " + '%.10f'  % (np.sqrt(acc_with_atom_1)/acc_window_number_of_bins/trigcounts_1) )
#print("decaying photon, accidentals per bin, without atom: " + '%.1f'  % (acc_without_atom_1/150) )
print("decaying photon, standard deviation of accidentals per bin per trigger, without atom: " + '%.10f'  % (np.sqrt(acc_without_atom_1)/acc_window_number_of_bins/trigbgcounts_1) )
#---------------
#------------------
print()
print(" -------------------------- rising photon --------------------------")
# rising photon !!!! :
data_fromfile = np.genfromtxt('reversed_with_alternating_09_26_onwards_combined_histo_tx_2ns')
trigcounts_2 = 124327751  ## total number of trigger in atom phase , obtained from file header
trigbgcounts_2 = 163613920 # total number of trigger in background phase , obtained from file header

#time bins:
t_data_2 = data_fromfile[:,1]
# coincidences per trigger, atom in trap:
with_atom_2 = data_fromfile[:,4]
# statistical uncertainty:
err_with_atom_2 = data_fromfile[:,5]
# accidental events: one time interval (150 bins = 300 ns) to determine background
acc_window_start_3 = 970
acc_window_end_3 = 1270
# Floor division keeps the bin indices integral; NumPy rejects the float
# bounds that ``970/2`` produces under Python 3.
acc_bins_3 = slice(acc_window_start_3 // timebin, acc_window_end_3 // timebin)
with_atom_bg_2 = np.mean(with_atom_2[acc_bins_3])
print("rising photon, accidentals per trigger per time bin, with atom: "+ '%.9f'  %  with_atom_bg_2) 

#-------
#-------
# coincidences per trigger, no atom in trap:
without_atom_2 = data_fromfile[:,8]
# statistical uncertainty:
err_without_atom_2 = data_fromfile[:,9]
# accidental events
without_atom_bg_2 = np.mean(without_atom_2[acc_bins_3])
print("rising photon, accidentals per trigger per time bin, without atom: "+ '%.9f'  %  without_atom_bg_2) 
## rising pulse starts at t=877ns
t_start_2 = without_atom_2.argmax()
print("rising photon, time of maximum (ns): ", t_data_2[t_start_2]) 

# range of summing for eta_f and deltas: sum from -100 ns to +14 ns
sumrange_min_2 = time_offset_index - summing_range_long
sumrange_max_2 = time_offset_index + summing_range_short

#heralding efficiency eta_f: 
eta_f_2 = sum( without_atom_2[sumrange_min_2:sumrange_max_2] - without_atom_bg_2 )  
print("rising photon, heralding efficiency eta_f, acc corrected: " + '%.5f'  % eta_f_2)

# transmission delta: accidental-corrected (without - with) difference per bin,
# normalised by eta_f and the bin width in seconds (result in 1/s)
delta_2 = ( (without_atom_2-without_atom_bg_2) - (with_atom_2-with_atom_bg_2) ) / ( eta_f_2 * timebin*1e-9 )
err_delta_2 = np.sqrt( err_with_atom_2**2 + err_without_atom_2**2 ) / ( eta_f_2 * timebin*1e-9 )
# plot and print results
plt.errorbar((t_data_2-time_offset),delta_2*1e-6, yerr=err_delta_2*1e-6 , color='blue', fmt='o' )
plt.xlim([-100,100])
extinction_2 =(sum(delta_2[sumrange_min_2:sumrange_max_2])*timebin*1e-9)
print("rising photon, extinction, acc corrected: "+ '%.5f'  % extinction_2  )
plt.xlabel('time since herald (ns)')
plt.ylabel('delta (10^6 1/s) ')
plt.title('accidental corrected transmission delta; \n shown in figure 4 of paper')


## error on the extinction
counts_with_atom_2 = sum( data_fromfile[sumrange_min_2:sumrange_max_2,2] )
counts_without_atom_2 = sum( data_fromfile[sumrange_min_2:sumrange_max_2,6] )
# Accidentals come from the same window (acc_window_3, 150 bins) that was used
# for this photon's background subtraction above. Its bin count is what
# acc_window_number_of_bins assumes.
acc_with_atom_2 = sum( data_fromfile[acc_bins_3,2] )
acc_without_atom_2 = sum( data_fromfile[acc_bins_3,6] )

# The transmission is T = 1 - extinction. Each accidental-corrected count is
# N - r*A with r = (summing bins)/(accidental bins). Its variance is N + r**2*A,
# so its relative variance is (N + r**2*A) / (N - r*A)**2. The trigger-count
# terms 1/trig add in quadrature.
sumrange_number_of_bins_2 = sumrange_max_2 - sumrange_min_2
acc_ratio_2 = sumrange_number_of_bins_2/acc_window_number_of_bins
rel_var_with_atom_2 = (counts_with_atom_2 + acc_ratio_2**2*acc_with_atom_2) / (counts_with_atom_2 - acc_ratio_2*acc_with_atom_2)**2
rel_var_without_atom_2 = (counts_without_atom_2 + acc_ratio_2**2*acc_without_atom_2) / (counts_without_atom_2 - acc_ratio_2*acc_without_atom_2)**2
error_on_tx = (1-extinction_2)*np.sqrt( rel_var_with_atom_2 + rel_var_without_atom_2 + 1/trigcounts_2 + 1/trigbgcounts_2 )
print("rising photon, standard deviation of extinction: " + '%.5f'  % error_on_tx)

#print("rising photon, accidentals per bin, with atom: " + '%.1f'  % (acc_with_atom_2/150) )
# with-atom accidentals are normalised by the atom-phase trigger count (as for the decaying photon)
print("rising photon, standard deviation of accidentals per bin per trigger, with atom: " + '%.10f'  % (np.sqrt(acc_with_atom_2)/acc_window_number_of_bins/trigcounts_2) )
#print("rising photon, accidentals per bin, without atom: " + '%.1f'  % (acc_without_atom_2/150) )
print("rising photon, standard deviation of accidentals per bin per trigger, without atom: " + '%.10f'  % (np.sqrt(acc_without_atom_2)/acc_window_number_of_bins/trigbgcounts_2) )

# theory for P_e(t) and deltas

def Theta(x):
    """Heaviside step: 1 where x >= 0 and 0 elsewhere (elementwise for arrays)."""
    return (x >= 0) * 1

def decaying_exp(t):
    """Return exp(-gamma_p * t / 2) for t >= 0 and 0 for t < 0.

    Uses the module-level gamma_p; t in seconds, scalar or NumPy array.
    """
    envelope = np.exp(-gamma_p * t / 2)
    return Theta(t) * envelope

def rising_exp(t):
    """Return exp(gamma_p * t / 2) for t <= 0 and 0 for t > 0.

    Uses the module-level gamma_p; t in seconds, scalar or NumPy array.
    """
    before_herald = Theta(-t)
    return before_herald * np.exp(gamma_p * t / 2)
    
def p_decay(t):
    """Model excited-state population P_e(t) for the decaying probe photon.

    P_e = 4*overlap*gamma_0*gamma_p/(gamma_p-gamma_0)**2
          * (exp(-gamma_0*t/2) - exp(-gamma_p*t/2))**2   for t >= 0, else 0.
    Uses module-level gamma_0, gamma_p, overlap; t in seconds, scalar or array.
    """
    prefactor = 4 * overlap * gamma_0 * gamma_p / (gamma_p - gamma_0)**2
    exp_difference = np.exp(-gamma_0 / 2 * t) - np.exp(-gamma_p / 2 * t)
    return prefactor * Theta(t) * exp_difference**2

def p_rise(t):
    """Model excited-state population P_e(t) for the rising probe photon.

    Both branches share the prefactor
    4*overlap*gamma_0*gamma_p/(gamma_p+gamma_0)**2:
      * t <  0: prefactor * exp(gamma_p * t)   (rise)
      * t >= 0: prefactor * exp(-gamma_0 * t)  (decay)
    The two branches meet continuously at t = 0.
    Uses module-level gamma_0, gamma_p, overlap; t in seconds, scalar or array.
    """
    peak = 4*overlap*gamma_0*gamma_p / (gamma_p+gamma_0)**2
    after = Theta(t)       # 1 for t >= 0
    # Strictly t < 0. Theta(-t) is also 1 at t == 0, which would add both
    # branches there and return twice the peak value.
    before = 1 - after
    return before*peak*np.exp(gamma_p*t) + after*peak*np.exp(-gamma_0*t)

def diff_decay(t):
    """Model transmission delta for the decaying photon (1/s).

    2*sqrt(overlap*gamma_0*gamma_p*P_e)*decaying_exp(t) - overlap*gamma_0*P_e,
    with P_e = p_decay(t). Uses module-level gamma_0, gamma_p, overlap.
    """
    pe = p_decay(t)
    coherent_term = 2 * np.sqrt(overlap * gamma_0 * gamma_p * pe) * decaying_exp(t)
    loss_term = overlap * gamma_0 * pe
    return coherent_term - loss_term
    
def diff_rise(t):
    """Model transmission delta for the rising photon (1/s).

    2*sqrt(overlap*gamma_0*gamma_p*P_e)*rising_exp(t) - overlap*gamma_0*P_e,
    with P_e = p_rise(t). Uses module-level gamma_0, gamma_p, overlap.
    """
    pe = p_rise(t)
    coherent_term = 2 * np.sqrt(overlap * gamma_0 * gamma_p * pe) * rising_exp(t)
    loss_term = overlap * gamma_0 * pe
    return coherent_term - loss_term

# Model curves for the transmission delta on a 0.1 ns grid from -200 ns to
# +200 ns (t_theory in seconds), drawn on the current figure in units of 10^6 1/s.
t_theory = np.arange(-200, 200, 0.1) * 1e-9
t_theory_ns = t_theory * 1e9
for theory_curve, curve_colour in ((diff_decay, 'red'), (diff_rise, 'blue')):
    plt.plot(t_theory_ns, theory_curve(t_theory) * 1e-6, color=curve_colour)

#----------------
# some analysis of the influence of the integration interval on the extinction value => sum 100ns is good

#theory_extinction_decay = np.zeros(len(t_theory)-1)
#theory_extinction_rise = np.zeros(len(t_theory)-1)
#for k in range(0,len(t_theory)-1):
#    theory_extinction_decay[k] = sum(diff_decay(t_theory[0:k+1]))*(0.1*1e-9)
#    theory_extinction_rise[k] = sum(diff_rise(t_theory[k+1:]))*(0.1*1e-9)

#running_extinction_1 = np.zeros(200)
#running_extinction_2 = np.zeros(200)
#for h in range(0,200):
#    running_extinction_1[h] =(sum(delta_1[sumrange_min_1:time_offset_index+h])*timebin*1e-9) 
#    running_extinction_2[h] =(sum(delta_2[time_offset_index-h:sumrange_max_2])*timebin*1e-9)    
#plt.figure()
#plt.plot(t_theory[:-1]*1e9,theory_extinction_decay,color='red')
#plt.plot(-t_theory[:-1]*1e9,theory_extinction_rise,color='blue')
#plt.title('extinction dependence on integration window \n theory + data')

#t_running_extinction = np.arange(0,200)*2 # 2ns bins
#plt.plot(t_running_extinction,running_extinction_1,'o', color='red')
#plt.plot(t_running_extinction,running_extinction_2,'o', color='blue')
#plt.xlim([0,200])


# second part derives P_e from transmission histogram
print('')
print('P_e from transmission:')

# ODE window: 110 bins (220 ns at 2 ns/bin) on each side of the herald index.
time_p_e_range_min_1 = time_offset_index - 110
time_p_e_range_max_1 = time_offset_index + 110
pe_window = slice(time_p_e_range_min_1, time_p_e_range_max_1)
time_p_e_range = t_data_1[pe_window] - time_offset
# Deltas and their errors inside the window, scaled by 1e-9 (1/s -> 1/ns)
# because the ODE is solved on a ns time axis. The rising-photon data reuse
# the same index window and the decaying photon's time axis.
delta_1_pe, err_delta_1_pe, delta_2_pe, err_delta_2_pe = (
    series[pe_window] * 1e-9
    for series in (delta_1, err_delta_1, delta_2, err_delta_2)
)

def delta_decay(t):
    """Measured decaying-photon delta (1/ns) at time t (ns), linearly interpolated.

    np.interp clamps to the end values outside time_p_e_range.
    """
    return np.interp(t, xp=time_p_e_range, fp=delta_1_pe)
    
def delta_rise(t):
    """Measured rising-photon delta (1/ns) at time t (ns), linearly interpolated.

    np.interp clamps to the end values outside time_p_e_range.
    """
    return np.interp(t, xp=time_p_e_range, fp=delta_2_pe)

def dummy_delta_decay(t):
    """Bootstrap-resampled decaying-photon delta at time t (ns), linearly interpolated.

    Reads the module-level array dummy_delta_1_pe at call time.
    """
    return np.interp(t, xp=time_p_e_range, fp=dummy_delta_1_pe)
    
def dummy_delta_rise(t):
    """Bootstrap-resampled rising-photon delta at time t (ns), linearly interpolated.

    Reads the module-level array dummy_delta_2_pe at call time.
    """
    return np.interp(t, xp=time_p_e_range, fp=dummy_delta_2_pe)


def my_int(f, time_v, dt, gamma_0, overlap):
    """
    Integrate the atom interaction differential equation from the extinction data.

    Solves dy/dt = f(t) - gamma_0 * (1 - overlap) * y with y = 0 at
    t0 = min(time_v), using scipy's ``dopri5`` integrator.

    Parameters
    ----------
    f : callable
        Source term f(t), e.g. an interpolated transmission delta.
    time_v : array_like
        Sample times. Assumed evenly spaced by ``dt`` starting at its minimum;
        the solution is sampled at t0 + k*dt.
    dt : float
        Spacing between samples, in the same units as ``time_v``.
    gamma_0 : float
        Decay rate in inverse units of ``time_v``.
    overlap : float
        Spatial overlap factor; the decay term is scaled by (1 - overlap).

    Returns
    -------
    numpy.ndarray
        y at t0 + k*dt for k = 0 .. len(time_v) - 1. Element 0 is the initial
        value 0. An empty ``time_v`` gives an empty array.

    Raises
    ------
    RuntimeError
        If the integrator reports an unsuccessful step.
    """
    def atom(t, y):
        return f(t) - gamma_0 * (1 - overlap) * y

    n_points = len(time_v)
    y = np.zeros(n_points)
    if n_points == 0:
        return y
    t0 = np.min(time_v)
    sol = ode(atom).set_integrator('dopri5')
    sol.set_initial_value(0.0, t0)
    # y[0] is the initial condition itself. Sample the following points at
    # t0 + k*dt. Rolling a one-step-ahead result would wrap the value from the
    # end of the window into index 0.
    for k in range(1, n_points):
        y[k] = sol.integrate(t0 + k * dt)[0]
        if not sol.successful():
            raise RuntimeError(
                'ODE integration failed at t = {}'.format(t0 + k * dt))
    return y
    
# Integration of the ODE. single run
#y_d = my_int(delta_decay, time_p_e_range, timebin, gamma_0*1e-9, overlap)
#y_r = my_int(delta_rise, time_p_e_range, timebin, gamma_0*1e-9, overlap)

#print('decaying photon, P_e,max : ' + '%.4f'  % np.max(y_d))
#print('rising photon, P_e,max : ' + '%.4f'  % np.max(y_r))
#print('ratio, P_e,max rising/decay: ' + '%.3f'  % (np.max(y_r)/np.max(y_d)) ) 

""" Bootstrap evaluation of the error of P_e """
# Method: redraw every delta point from a normal distribution (mean = measured
# delta, sigma = its statistical error), integrate the ODE for each redrawn
# data set, then take the mean and standard deviation over all repetitions.
print('Bootstrap evaluation of the error of P_e. hang on this takes about 30sec')
reps = 300
n_pe_points = len(time_p_e_range)
decays = np.zeros((reps, n_pe_points))
rises = np.zeros((reps, n_pe_points))
# dummy_delta_decay / dummy_delta_rise read these module-level arrays at call
# time, so they are refilled in place for every repetition.
dummy_delta_1_pe = np.zeros(n_pe_points)
dummy_delta_2_pe = np.zeros(n_pe_points)
for k in range(reps):
    # produce new data set: one vectorised normal draw per series (one sample per bin)
    dummy_delta_1_pe[:] = np.random.normal(delta_1_pe, err_delta_1_pe)
    dummy_delta_2_pe[:] = np.random.normal(delta_2_pe, err_delta_2_pe)
    decays[k, :] = my_int(dummy_delta_decay, time_p_e_range, timebin, gamma_0*1e-9, overlap)
    rises[k, :] = my_int(dummy_delta_rise, time_p_e_range, timebin, gamma_0*1e-9, overlap)
# get averaged curves and errors    
decay = np.mean(decays, 0)
decay_err = np.std(decays, 0)
rise = np.mean(rises, 0)
rise_err = np.std(rises, 0)

# maxima of the mean curves and the bootstrap error at those points
i_decay_max = decay.argmax()
i_rise_max = rise.argmax()
decay_max, decay_max_err = decay[i_decay_max], decay_err[i_decay_max]
rise_max, rise_max_err = rise[i_rise_max], rise_err[i_rise_max]

print('decaying photon, P_e,max : ' + '%.4f ( %.4f )'  % (decay_max, decay_max_err ) )
print('rising photon, P_e,max : ' +  '%.4f ( %.4f )'  % (rise_max, rise_max_err ) )
pe_ratio = rise_max/decay_max
# Gaussian error propagation for the ratio rise_max/decay_max
pe_ratio_err = np.sqrt( (rise_max_err/decay_max)**2 + (decay_max_err*rise_max/decay_max**2)**2 )
print('ratio, P_e,max rising/decay: ' +  '%.3f ( %.3f )'  % ( pe_ratio, pe_ratio_err) )

"""
Saving
"""
def _write_pe_table(path, times, pe, pe_err):
    """Write tab-separated columns (time, P_e, P_e error) to *path* after a '#' header line."""
    with open(path, 'w') as f:
        f.write('#time\tPe\tPe_err\n')
        f.writelines('{:.1f}\t{:.5e}\t{:.5e}\n'.format(t, P, P_err)
                     for t, P, P_err in zip(times, pe, pe_err))


_write_pe_table('ODE_rising.dat', time_p_e_range, rise, rise_err)
_write_pe_table('ODE_decaying.dat', time_p_e_range, decay, decay_err)
    
""" Plotting and results """
plt.figure()
# bootstrapped P_e from the transmission data (points)
for pe_curve, pe_curve_err, pe_colour in ((decay, decay_err, 'red'), (rise, rise_err, 'blue')):
    plt.errorbar(time_p_e_range, pe_curve, yerr=pe_curve_err, fmt='o', color=pe_colour)
plt.xlabel('time since herald (ns)')
plt.ylabel('P_e ')
plt.title('excited state population P_e from transmission; \n shown in figure 5 of paper')
plt.xlim([-200,200])

# model P_e curves (lines); t_theory is in seconds, the axis in ns
for pe_model, pe_colour in ((p_decay, 'red'), (p_rise, 'blue')):
    plt.plot(t_theory*1e9, pe_model(t_theory), color=pe_colour)

print()
print()


# now P_e from the reflection histogram
print(" -------------------------- Reflection --------------------------")
print()
raw_rx_decaying = np.genfromtxt('non_reversed_with_alternating_09_26_onwards_combined_histo_rx_5ns')
raw_rx_rising = np.genfromtxt('reversed_with_alternating_09_26_onwards_combined_histo_rx_5ns_4')

timebin_rx = 5  # 5 ns time bins for the reflection histogram
# Overall efficiency for the two shapes: coupling of the tx mode into the fiber
# = 46% (fiber coupling + optical path losses), efficiency of the tx APD = 51.9%.
# NOTE(review): the transmission section uses 0.46*0.52 for eta_f_1_tilde;
# confirm which detector efficiency is intended here.
eta_f_tilde = eta_f_1 / (0.46 * 0.519)
# 0.0078 is from a saturation-curve measurement with APD Tauhuay (eff 61.86%);
# rescaled to the experiment APD efficiency of 56%.
rx_det_eff = 0.0078 * 0.56 / 0.6186
rx_scalefactor = gamma_0 * eta_f_tilde * rx_det_eff * timebin_rx * 1e-9


def _rx_accidentals(raw, triggers):
    """Mean of column 4 over rows 220..319 and its Poisson error per trigger per bin."""
    window = raw[220:320, 4]
    offset = np.mean(window)
    return offset, np.sqrt(offset) / np.sqrt((triggers * len(window)))


rx_decaying_trigger = 135534347  # from file header
rx_decaying_offset, rx_decaying_offset_err = _rx_accidentals(raw_rx_decaying, rx_decaying_trigger)
print('decaying photon, accidentals per trigger per bin: ' + '%.9f ( %.10f )' % (rx_decaying_offset, rx_decaying_offset_err))
rx_rising_trigger = 124327751  # from file header
rx_rising_offset, rx_rising_offset_err = _rx_accidentals(raw_rx_rising, rx_rising_trigger)
print('rising photon, accidentals per trigger per bin: ' + '%.9f ( %.10f )' % (rx_rising_offset, rx_rising_offset_err))


def _rx_population(raw, offset, offset_err):
    """Accidental-corrected P_e, its error, and time axis (ns) from a reflection histogram."""
    population = (raw[:, 4] - offset) / rx_scalefactor
    population_err = np.sqrt((raw[:, 5]**2 + offset_err**2)) / rx_scalefactor
    # NOTE(review): the extra 12 ns is presumably a delay of the reflection path
    # relative to the transmission time_offset -- confirm.
    times = (raw[:, 1] - (time_offset + 12))
    return population, population_err, times


P_r, P_r_err, time_pr = _rx_population(raw_rx_rising, rx_rising_offset, rx_rising_offset_err)
P_d, P_d_err, time_pd = _rx_population(raw_rx_decaying, rx_decaying_offset, rx_decaying_offset_err)

plt.figure()
plt.errorbar(time_pd, P_d, yerr=P_d_err, color='red', fmt='o')
plt.errorbar(time_pr, P_r, yerr=P_r_err, color='blue', fmt='o')
plt.plot(t_theory * 1e9, p_decay(t_theory), color='red')
plt.plot(t_theory * 1e9, p_rise(t_theory), color='blue')
plt.xlabel('time since herald (ns)')
plt.ylabel('P_e ')
plt.title('excited state population P_e from reflection; \n shown in figure 5 of paper')
plt.xlim([-200,200])
plt.show()

