import numpy as np 
import matplotlib.pyplot as plt
import argparse 

# V2: threshold on the sum of the blue and red signals instead of post-selecting on each one individually
# V3: threshold method changed to keeping 0.5% (default) or X% of the atombins
#     added the "bad" threshold mode

# Command-line interface. NOTE: argparse %-formats help strings, so a literal
# percent sign must not appear in them.
parser = argparse.ArgumentParser(description=' script ... ')
# data file parsed with np.genfromtxt (whitespace-separated columns by default)
parser.add_argument('-i', '--input', help='Input file name', required=True)
# The code below shows plot windows only when this equals 1; figures are
# written to PDF regardless.
parser.add_argument('-p', '--plot',
                    help='show plot windows; 1: show, 0: do not show (default). '
                         'PDF files are written either way',
                    type=int, default=0)
parser.add_argument('-ht', '--histogram',
                    help='make detector histograms (saved as PDF); 1: yes, 0: no (default)',
                    type=int, default=0)
# Compared against (kept atombins / total atombins * 100) in the post selection.
parser.add_argument('-th', '--threshold',
                    help='threshold for postselecting: percentage of atombins to keep (default 0.5)',
                    type=float, default=0.5)
# Any value other than 'good' selects the "bad" mode.
parser.add_argument('-tm', '--threshold_mode',
                    help='threshold mode: good(default) or bad',
                    type=str, default='good')
#parser.add_argument('-ot','--overtime', help='show transmission over time', type=int,default = 0)

# read in parser arguments and input file
args = parser.parse_args()
data_fromfile = np.genfromtxt(args.input)

# Constant offsets subtracted from every blue / red count column below.
# NOTE(review): presumably a per-bin dark/background count; confirm origin.
bg_corr_blue = 0.38
bg_corr_red = 0.275
# ----- init ------
# Input columns 0-7 alternate blue (even) / red (odd):
#   0,1 atom and 2,3 bg for probe cycle 1; 4,5 atom and 6,7 bg for probe cycle 2.
# Column 8 is the atombin number and is left uncorrected.
(blue_atom_1, red_atom_1, blue_bg_1, red_bg_1,
 blue_atom_2, red_atom_2, blue_bg_2, red_bg_2) = (
    data_fromfile[:, column] - (bg_corr_red if column % 2 else bg_corr_blue)
    for column in range(8)
)
atombin = data_fromfile[:, 8]  # atombin number for each atom

#-----probe cycle 1------#
# Transmission = mean(atom counts) / mean(bg counts).
# Error: sqrt(sum) is taken as the uncertainty of each summed count and
# propagated through the ratio sum(atom) / sum(bg).
# np.sum is used instead of the builtin sum (a Python-level loop over the
# array), and each sum is computed only once.
avg_blue_atom_1 = np.mean(blue_atom_1)
avg_blue_bg_1 = np.mean(blue_bg_1)
tx_blue_1 = avg_blue_atom_1 / avg_blue_bg_1
sum_blue_atom_1 = np.sum(blue_atom_1)
sum_blue_bg_1 = np.sum(blue_bg_1)
tx_blue_error_1 = np.sqrt((np.sqrt(sum_blue_atom_1) / sum_blue_bg_1) ** 2
                          + (sum_blue_atom_1 * np.sqrt(sum_blue_bg_1) / sum_blue_bg_1 ** 2) ** 2)

avg_red_atom_1 = np.mean(red_atom_1)
avg_red_bg_1 = np.mean(red_bg_1)
tx_red_1 = avg_red_atom_1 / avg_red_bg_1
sum_red_atom_1 = np.sum(red_atom_1)
sum_red_bg_1 = np.sum(red_bg_1)
tx_red_error_1 = np.sqrt((np.sqrt(sum_red_atom_1) / sum_red_bg_1) ** 2
                         + (sum_red_atom_1 * np.sqrt(sum_red_bg_1) / sum_red_bg_1 ** 2) ** 2)

std_blue_atom_1 = np.std(blue_atom_1)
std_blue_bg_1 = np.std(blue_bg_1)
std_red_atom_1 = np.std(red_atom_1)
std_red_bg_1 = np.std(red_bg_1)

# Table: mean ( std ) of counts per detector, then transmission ( error ).
print("")
print("-----------")
print("first probe cycle")
print("\tblue        |\tred")
print("atom1\t%.1f ( %.1f )   | %.1f ( %.1f )"
      % (avg_blue_atom_1, std_blue_atom_1, avg_red_atom_1, std_red_atom_1))
print("bg1\t%.1f ( %.1f )   | %.1f ( %.1f )"
      % (avg_blue_bg_1, std_blue_bg_1, avg_red_bg_1, std_red_bg_1))
print("tx1\t%.3f ( %.3f )| %.3f ( %.3f )"
      % (tx_blue_1, tx_blue_error_1, tx_red_1, tx_red_error_1))

#-----tx over time for probe 1------#
# Split probe cycle 1 into consecutive, non-overlapping windows of `step`
# atombins. For each window, record:
#   - the blue/red transmission, with the same sqrt(sum) error propagation
#     as the full-cycle value;
#   - the mean bg level, with error sqrt(sum(bg)) / step.
# Windows are contiguous and every complete window is used. The previous
# loop advanced by step + 1, which silently skipped one atombin between
# windows and dropped the last complete window.
step = 600
tx_bin_red_1 = []
tx_bin_red_error_1 = []
tx_bin_blue_1 = []
tx_bin_blue_error_1 = []
bin_blue_bg_1 = []
bin_blue_bg_error_1 = []
bin_red_bg_1 = []
bin_red_bg_error_1 = []
for i in range(0, len(red_atom_1) - step + 1, step):
    blue_a = blue_atom_1[i:i + step]
    blue_b = blue_bg_1[i:i + step]
    red_a = red_atom_1[i:i + step]
    red_b = red_bg_1[i:i + step]
    s_blue_a, s_blue_b = np.sum(blue_a), np.sum(blue_b)
    s_red_a, s_red_b = np.sum(red_a), np.sum(red_b)

    tx_bin_red_1.append(red_a.mean() / red_b.mean())
    tx_bin_blue_1.append(blue_a.mean() / blue_b.mean())
    tx_bin_blue_error_1.append(np.sqrt((np.sqrt(s_blue_a) / s_blue_b) ** 2
                                       + (s_blue_a * np.sqrt(s_blue_b) / s_blue_b ** 2) ** 2))
    tx_bin_red_error_1.append(np.sqrt((np.sqrt(s_red_a) / s_red_b) ** 2
                                      + (s_red_a * np.sqrt(s_red_b) / s_red_b ** 2) ** 2))
    bin_blue_bg_1.append(blue_b.mean())
    bin_blue_bg_error_1.append(np.sqrt(s_blue_b) / step)
    bin_red_bg_1.append(red_b.mean())
    bin_red_bg_error_1.append(np.sqrt(s_red_b) / step)

#-----probe cycle 2------#
# Same statistics as for probe cycle 1, on the cycle-2 columns.
# Transmission = mean(atom) / mean(bg). The error propagates sqrt(sum)
# uncertainties through sum(atom) / sum(bg). np.sum replaces the builtin
# sum (a Python-level loop), and each sum is computed once.
avg_blue_atom_2 = np.mean(blue_atom_2)
avg_blue_bg_2 = np.mean(blue_bg_2)
tx_blue_2 = avg_blue_atom_2 / avg_blue_bg_2
sum_blue_atom_2 = np.sum(blue_atom_2)
sum_blue_bg_2 = np.sum(blue_bg_2)
tx_blue_error_2 = np.sqrt((np.sqrt(sum_blue_atom_2) / sum_blue_bg_2) ** 2
                          + (sum_blue_atom_2 * np.sqrt(sum_blue_bg_2) / sum_blue_bg_2 ** 2) ** 2)

avg_red_atom_2 = np.mean(red_atom_2)
avg_red_bg_2 = np.mean(red_bg_2)
tx_red_2 = avg_red_atom_2 / avg_red_bg_2
sum_red_atom_2 = np.sum(red_atom_2)
sum_red_bg_2 = np.sum(red_bg_2)
tx_red_error_2 = np.sqrt((np.sqrt(sum_red_atom_2) / sum_red_bg_2) ** 2
                         + (sum_red_atom_2 * np.sqrt(sum_red_bg_2) / sum_red_bg_2 ** 2) ** 2)

std_blue_atom_2 = np.std(blue_atom_2)
std_blue_bg_2 = np.std(blue_bg_2)
std_red_atom_2 = np.std(red_atom_2)
std_red_bg_2 = np.std(red_bg_2)

# Table: mean ( std ) of counts per detector, then transmission ( error ).
print("")
print("second probe cycle")
print("\tblue        |\tred")
print("atom2\t%.1f ( %.1f )   | %.1f ( %.1f )"
      % (avg_blue_atom_2, std_blue_atom_2, avg_red_atom_2, std_red_atom_2))
print("bg2\t%.1f ( %.1f )   | %.1f ( %.1f )"
      % (avg_blue_bg_2, std_blue_bg_2, avg_red_bg_2, std_red_bg_2))
print("tx2\t%.3f ( %.3f )| %.3f ( %.3f )"
      % (tx_blue_2, tx_blue_error_2, tx_red_2, tx_red_error_2))

print("" )
print("total atombins: "+'%d'%len(blue_atom_1))

print("" ) 
print("--- post selection: ----" )
#thresholding 

th_atom = blue_atom_2 + red_atom_2
total_atombins = len(blue_atom_1)
#threshold = np.arange(th-3.9,th+25,1)*1.0
thresholdsteps=40
thresholdoffset = 7
if args.threshold_mode == 'good':
	threshold = np.linspace(min(th_atom)+thresholdoffset,max(th_atom)/1.8,thresholdsteps)
else:
	threshold = np.linspace(max(th_atom)-thresholdoffset,max(th_atom)/1.5,thresholdsteps)
tx_post_blue = 0*threshold
tx_post_err_blue = 0*threshold
tx_post_red = 0*threshold
tx_post_err_red = 0*threshold
tx_post_sum = 0*threshold
tx_post_err_sum = 0*threshold
tx_post_atombins = 0*threshold
best_th_idx = 0

for j in range(0,len(threshold)):
	post_blue = []
	post_blue_bg = []
	post_red = []
	post_red_bg =[]
	post_sum = []
	post_sum_bg =[]

	for i in range(0,total_atombins-1):
		if args.threshold_mode == 'good':
			if th_atom[i] < threshold[j]:
				post_blue.append(blue_atom_1[i])
				post_blue_bg.append(blue_bg_1[i])
				post_red.append(red_atom_1[i])
				post_red_bg.append(red_bg_1[i])
				post_sum.append(blue_atom_1[i]+red_atom_1[i])
				post_sum_bg.append(blue_bg_1[i]+red_bg_1[i])
		else:
			if th_atom[i] > threshold[j]:
				post_blue.append(blue_atom_1[i])
				post_blue_bg.append(blue_bg_1[i])
				post_red.append(red_atom_1[i])
				post_red_bg.append(red_bg_1[i])
				post_sum.append(blue_atom_1[i]+red_atom_1[i])
				post_sum_bg.append(blue_bg_1[i]+red_bg_1[i])

	tx_post_atombins[j] = len(post_blue)*1.0
	#evaluate tx
	post_blue = np.array(post_blue)
	post_blue_bg = np.array(post_blue_bg)
	avg_post_blue = np.mean(post_blue)
	avg_post_blue_bg = np.mean(post_blue_bg)
	tx_post_blue[j] = avg_post_blue/avg_post_blue_bg
	tx_post_err_blue[j] = np.sqrt( (np.sqrt(sum(post_blue))/sum(post_blue_bg) )**2 + (sum(post_blue)*np.sqrt(sum(post_blue_bg))/sum(post_blue_bg)**2 )**2 )

	post_red = np.array(post_red)
	post_red_bg = np.array(post_red_bg)
	avg_post_red = np.mean(post_red)
	avg_post_red_bg = np.mean(post_red_bg)
	tx_post_red[j] = avg_post_red/avg_post_red_bg
	tx_post_err_red[j] = np.sqrt( (np.sqrt(sum(post_red))/sum(post_red_bg) )**2 + (sum(post_red)*np.sqrt(sum(post_red_bg))/sum(post_red_bg)**2 )**2 )

	tx_post_sum[j] = np.mean(post_sum)/np.mean(post_sum_bg)
	tx_post_err_sum[j] = np.sqrt( (np.sqrt(sum(post_sum))/sum(post_sum_bg) )**2 + (sum(post_sum)*np.sqrt(sum(post_sum_bg))/sum(post_sum_bg)**2 )**2 )

	#if j<best_th_sum_idx and tx_post_err_sum[j]<0.01:
	if j>best_th_idx and tx_post_atombins[j]/total_atombins*100 < args.threshold:
		best_th_idx = j


print("post tx "+'%.3f'% tx_post_blue[best_th_idx]+" ( "+'%.3f'% tx_post_err_blue[best_th_idx]+" ) | "+'%.3f'%tx_post_red[best_th_idx]+" ( "+'%.3f'% tx_post_err_red[best_th_idx]+" ) " ) 
print("sum "+'%.3f'% tx_post_sum[best_th_idx]+" ( "+'%.3f'% tx_post_err_sum[best_th_idx]+" ) " ) 
print("threshold: "+'%.3f'%threshold[best_th_idx]+ "( "+'%.3f'%(tx_post_atombins[best_th_idx]/total_atombins*100)+" %)")
print("" )


##--------------------plotting------------------
# Figure 3: post-selected transmission (blue, red, blue+red in black) versus
# the percentage of atombins kept at each scan point, on a log x axis.
# Always written to <input>_post.pdf; shown only with --plot 1.
post_fig = plt.figure(3)
ax = post_fig.add_subplot(111)
ax.set_xscale('log')
plt.xlabel('threshold %')
plt.ylabel('transmission')
plt.title('postselected transmission')

kept_percent = tx_post_atombins / total_atombins * 100
for values, errors, colour in ((tx_post_blue, tx_post_err_blue, 'b'),
                               (tx_post_red, tx_post_err_red, 'r'),
                               (tx_post_sum, tx_post_err_sum, 'k')):
    plt.errorbar(kept_percent, values, yerr=errors, color=colour)

plt.savefig(args.input + '_post' + '.pdf', bbox_inches='tight')
if args.plot == 1:
    post_fig.show()

#overtime
# Figure 4, two panels over the probe-cycle-1 windows (red drawn before blue):
#   top:    per-window transmission with errors
#   bottom: per-window mean bg level with errors
# Always written to <input>_time.pdf; with --plot 1 every open figure is shown.
overtime_fig = plt.figure(4)

plt.subplot(211)
for series, errors, colour in ((tx_bin_red_1, tx_bin_red_error_1, 'r'),
                               (tx_bin_blue_1, tx_bin_blue_error_1, 'b')):
    plt.errorbar(range(0, len(series)), series, yerr=errors, color=colour)
plt.xlim(0, len(tx_bin_red_1))
plt.ylabel('transmission')
plt.title('transmission and probe power over atombins')

plt.subplot(212)
plt.xlim(0, len(bin_red_bg_1))
for series, errors, colour in ((bin_red_bg_1, bin_red_bg_error_1, 'r'),
                               (bin_blue_bg_1, bin_blue_bg_error_1, 'b')):
    plt.errorbar(range(0, len(series)), series, yerr=errors, color=colour, fmt='')
plt.xlabel('every 600 atombins (approx 1 min)')
plt.ylabel('transmitted photons')

plt.savefig(args.input + '_time' + '.pdf', bbox_inches='tight')
if args.plot == 1:
    plt.show()

#histogram
# Optional (--histogram 1) per-detector histograms of probe-cycle-1 counts.
# Atom counts use the default colour and bg counts are drawn in red.
# Bins have unit width with edges 0..299; np.histogram ignores values
# outside that range. Written to <input>_blue_hist.pdf / <input>_red_hist.pdf.
if args.histogram == 1:
    binwidth = 1
    binrange = np.arange(0, 300, binwidth)
    hist_blue_atom_1, bins_blue_atom_1 = np.histogram(blue_atom_1, binrange)
    hist_blue_bg_1, bins_blue_bg_1 = np.histogram(blue_bg_1, binrange)
    width_blue = binwidth * 0.7
    center_blue = (bins_blue_atom_1[:-1] + bins_blue_atom_1[1:]) / 2
    hist_red_atom_1, bins_red_atom_1 = np.histogram(red_atom_1, binrange)
    hist_red_bg_1, bins_red_bg_1 = np.histogram(red_bg_1, binrange)
    width_red = binwidth * 0.7
    center_red = (bins_red_atom_1[:-1] + bins_red_atom_1[1:]) / 2

    blue_fig = plt.figure(1)
    plt.bar(center_blue, hist_blue_atom_1, align='center', width=width_blue, alpha=0.5)
    plt.bar(center_blue, hist_blue_bg_1, align='center', width=width_blue, color='red', alpha=0.5)
    plt.xlabel('counts')
    plt.ylabel('# of occurrences')
    plt.title('BLUE detector')
    plt.savefig(args.input + '_blue_hist' + '.pdf', bbox_inches='tight')

    red_fig = plt.figure(2)
    plt.xlabel('counts')
    plt.ylabel('# of occurrences')
    plt.title('RED detector')
    plt.bar(center_red, hist_red_atom_1, align='center', width=width_red, alpha=0.5)
    plt.bar(center_red, hist_red_bg_1, align='center', width=width_red, color='red', alpha=0.5)
    plt.savefig(args.input + '_red_hist' + '.pdf', bbox_inches='tight')

    if args.plot == 1:
        # Figure.show() does not run a GUI event loop. Called as the last
        # statement of the script, the windows would vanish immediately.
        # plt.show() blocks until the user closes them.
        plt.show()
