import numpy as np 
import matplotlib.pyplot as plt
import glob
from lmfit import minimize, Parameters, fit_report

numberofbins = 80  # number of probe-pulse time bins used for plots and export
bg_bins = [81, 82, 83]  # time bins used for background correction

# One *.dat file per probe frequency. Sorting the names fixes the row order of
# every matrix below (the first 6 characters of each name are parsed as the
# frequency label further down).
filelist = sorted(glob.glob("*.dat"))
if not filelist:
    raise IOError('no *.dat files found in the working directory')

# All files are assumed to hold the same number of time bins as the first one.
n_time_bins = len(np.genfromtxt(filelist[0])[:, 0])
tx_matrix = np.zeros((len(filelist), n_time_bins))
scat_matrix = np.zeros_like(tx_matrix)
ac_scat_matrix = np.zeros_like(tx_matrix)
freqlist = np.zeros(len(filelist))
atombin_list = np.zeros(len(filelist))

for i, fname in enumerate(filelist):
    data_fromfile = np.genfromtxt(fname)
    t_data = data_fromfile[:, 0]
    with_atom = data_fromfile[:, 1]     # counts with atom
    without_atom = data_fromfile[:, 2]  # counts without atom
    # Common background: mean of both traces over bg_bins (presumably bins
    # outside the probe window, since bg_bins > numberofbins -- confirm).
    bg = np.array([with_atom[bg_bins], without_atom[bg_bins]]).mean()
    with_atom = with_atom - bg
    without_atom = without_atom - bg
    tx_matrix[i, :] = with_atom / without_atom
    freqlist[i] = int(fname[0:6])
    # Last value parsed from the lines preceding the n_time_bins data rows.
    # comments='$' so that '#' header lines are not discarded.
    # NOTE(review): presumably the header's number of atom bins -- confirm.
    atombin_list[i] = np.genfromtxt(fname, comments='$', skip_footer=n_time_bins)[-1]
    scat_matrix[i, :] = (without_atom - with_atom) / atombin_list[i]
    # Exclusive running sum: ac_scat_matrix[i, k] == sum(scat_matrix[i, 0:k]),
    # so bin 0 stays 0. One O(n) cumsum replaces the former O(n^2) re-summing.
    ac_scat_matrix[i, 1:] = np.cumsum(scat_matrix[i, :-1])
	    

# Relative probe-frequency range (MHz) of the data set.
# NOTE(review): the 190000 offset and the factor 2e-3 are taken as given;
# presumably kHz file labels of a double-passed modulator -- confirm.
freqmin = (min(freqlist)-190000)*2*1e-3
freqmax = (max(freqlist)-190000)*2*1e-3
freqticks = [40, 50]
timeticks = [0, 10, 20, 30]
# The time axis spans the `numberofbins` plotted bins at 0.5 ms per bin (the
# same bin width the text export of tx_matrix uses). It was previously a
# hard-coded 40 that silently went wrong whenever numberofbins changed.
bin_duration_ms = 0.5
extent = [freqmin, freqmax, numberofbins*bin_duration_ms, 0]

fig1 = plt.figure()
plt.imshow(100*tx_matrix[:, 0:numberofbins].transpose(), extent=extent,
           cmap='RdBu', interpolation='none', vmin=82, vmax=100, aspect=1)
plt.xticks(freqticks)
plt.yticks(timeticks)
plt.colorbar(orientation='vertical', label='rel. transmission (%)', ticks=[82, 90, 100])
plt.xlabel('rel. probe frequency (MHz)')
plt.ylabel('probe time (ms)')
fig1.savefig('tx_matrix.pdf', bbox_inches='tight')

############# create a new matrix tx over frequency and over # of photons scattered
scat_step_size = 9.832
max_scat = 400
n_scat_bins = int(max_scat/scat_step_size)

# tx_matrix_scat[k, j]: mean transmission of those of the first `numberofbins`
# time bins of file k whose accumulated scattering rounds to bin j.
tx_matrix_scat = np.zeros((tx_matrix.shape[0], n_scat_bins))
for k in range(tx_matrix.shape[0]):
    tx_row = tx_matrix[k, 0:numberofbins]
    # Scattering-bin index of every time bin; loop-invariant for the j loop.
    scat_bin_of = (ac_scat_matrix[k, 0:numberofbins]/scat_step_size).round()
    for j in range(n_scat_bins):
        in_bin = scat_bin_of == j
        # Empty bins get NaN explicitly (np.mean of an empty slice would too,
        # but with a RuntimeWarning); they are mapped to 1 just below.
        tx_matrix_scat[k, j] = tx_row[in_bin].mean() if in_bin.any() else np.nan

# Empty (or NaN) entries are treated as full transmission.
tx_matrix_scat[np.isnan(tx_matrix_scat)] = 1
# NOTE(review): this axis scale divides by 0.65*0.56 (collection * APD
# efficiency) but not by the 0.90 optical-path factor used for the photon axis
# of the extinction plot. It also spans 30 bins while 31 columns (0:31) are
# drawn. Confirm which is intended.
extent = [freqmin, freqmax, 30*scat_step_size/0.65/0.56, 0]
scat_ticks = [0, 300, 600]
fig2 = plt.figure()
plt.imshow(100*tx_matrix_scat[:, 0:31].transpose(), extent=extent, cmap='RdBu',
           interpolation='none', vmin=82, vmax=100, aspect=0.034)
plt.xticks(freqticks)
plt.yticks(scat_ticks)
plt.colorbar(orientation='vertical', label='rel. transmission (%)', ticks=[82, 90, 100])
plt.xlabel('rel. probe frequency (MHz)')
plt.ylabel('# of scattered photons')
fig2.savefig('tx_matrix_scat.pdf', bbox_inches='tight')

# Per scattering bin: largest extinction (1 - tx) over frequency, and the
# frequency where it occurs.
# NOTE(review): the row index is converted to frequency assuming the files
# are equally spaced 0.6 MHz apart, starting at freqmin -- confirm.
tx_min = (1 - tx_matrix_scat).max(axis=0)
freq_tx_min = (1 - tx_matrix_scat).argmax(axis=0)*0.6 + freqmin
###########
def fit_residual(params, x, data, err=None):
    """Return the residual of the transmission line-shape model.

    The model is a Lorentzian dip of depth ``tx`` and full width at half
    maximum ``gamma`` centred at ``x0``, plus a dispersive term scaled by
    ``dis``::

        L     = 4*((x0 - x)/gamma)**2 + 1
        model = 1 - tx/L + dis*(x0 - x)/L

    Parameters
    ----------
    params : mapping
        Name -> object with a ``.value`` attribute (e.g. lmfit ``Parameters``);
        must provide 'tx', 'x0', 'gamma' and 'dis'.
    x : float or ndarray
        Probe frequencies at which the model is evaluated.
    data : float or ndarray
        Measured transmission at ``x``.
    err : float or ndarray, optional
        Uncertainties of ``data``; if given, the residual is divided by them.

    Returns
    -------
    float or ndarray
        ``model - data``, divided by ``err`` when ``err`` is not None.
    """
    detuning = params['x0'].value - x
    lorentz = 4 * (detuning / params['gamma'].value) ** 2 + 1
    model = 1 - params['tx'].value / lorentz + params['dis'].value * detuning / lorentz
    residual = model - data
    # `is not None`: `err != None` is elementwise for an ndarray, and using
    # that array in `if` raises ValueError.
    if err is not None:
        return residual / err
    return residual
	
fitout = []       # lmfit results, one per *fitted* scattering bin, in k order
fitout_list = []
n_fit_bins = tx_matrix_scat.shape[1]
tx_min_fit = np.zeros(n_fit_bins)
tx_err = np.zeros(n_fit_bins)
freq_tx_min_fit = np.zeros(n_fit_bins)
freq_err = np.zeros(n_fit_bins)
for k in range(n_fit_bins):
    below_one = tx_matrix_scat[:, k] < 1
    # Fit only when the column has more than four points below full
    # transmission (the model has four free parameters).
    if below_one.sum() > 4:
        tx_for_fit = tx_matrix_scat[below_one, k]
        # NOTE(review): row index -> frequency assumes 0.6 MHz file spacing.
        freq_for_fit = np.flatnonzero(below_one)*0.6 + freqmin

        params = Parameters()
        params.add('x0', value=44, min=40, max=50, vary=True)
        params.add('tx', value=0.15, vary=True)
        params.add('gamma', value=7, min=6, max=15, vary=True)
        params.add('dis', value=0.002, max=0.01, min=-0.01, vary=True)

        # Use this fit's own result. Indexing fitout[k] breaks as soon as any
        # earlier column was skipped, because fitout only grows on fitted columns.
        result = minimize(fit_residual, params, args=(freq_for_fit, tx_for_fit))
        fitout.append(result)
        x_array = np.linspace(freq_for_fit.min(), freq_for_fit.max(), 100)
        # Fitted model on a fine grid (the residual against zero data is the model).
        fitarray = fit_residual(result.params, x_array, 0.0)
        min_trans = fitarray.min()
        tx_min_fit[k] = 1 - min_trans
        freq_tx_min_fit[k] = result.params['x0'].value
        # lmfit leaves stderr as None when it cannot estimate the covariance;
        # store NaN instead of failing on a None -> float assignment.
        tx_stderr = result.params['tx'].stderr
        tx_err[k] = np.nan if tx_stderr is None else tx_stderr
        x0_stderr = result.params['x0'].stderr
        freq_err[k] = np.nan if x0_stderr is None else x0_stderr
###
plt.figure()
col_eff = 0.65
apd_eff = 0.56  # updated from 0.52 as of 17/08/2016
op_path = 0.90
detection_eff = col_eff*apd_eff*op_path
# x value of each fitted scattering bin: the bin centre, converted from
# detected to scattered photons and rounded. Its length is tied to
# tx_min_fit, rather than relying on np.arange(0, max_scat, step) yielding
# exactly one extra element (which fails when max_scat is a multiple of the step).
bin_edges = np.arange(len(tx_min_fit))*scat_step_size/detection_eff
scat_photons_range = np.round(bin_edges + scat_step_size/detection_eff/2)

plt.errorbar(scat_photons_range, 100*tx_min_fit, yerr=100*tx_err, color='blue', fmt='o')
plt.ylim([8, 19])
plt.xlim([-5, 780])
plt.ylabel('extinction %')
plt.xlabel('# of scattered photons')
plt.title('extinction')
np.savetxt('scat_tx', np.array([scat_photons_range, 100*tx_min_fit, 100*tx_err]).T, fmt='%.3f')

plt.figure()
plt.errorbar(scat_photons_range, freq_tx_min_fit, yerr=freq_err, color='blue', fmt='o')
plt.ylabel('resonance freq (MHz)')
plt.xlabel('# of scattered photons')
plt.ylim([44, 48])
plt.xlim([-5, 780])
plt.title('resonance freq shift')
np.savetxt('scat_resfreq', np.array([scat_photons_range, freq_tx_min_fit, freq_err]).T, fmt='%.2f')

# Export tx_matrix as "frequency(MHz) time(ms) transmission" lines, with one
# blank line before each frequency's block.
# `print >> outfile` was Python-2-only (TypeError on Python 3). write() with
# '%s' produces the same str()-formatted, space-separated lines on both
# versions, and `with` guarantees the file is closed.
# NOTE(review): frequency = freqmin + 0.6*row assumes evenly spaced files.
with open('tx_matrix', 'w') as outfile:
    for i, tx_row in enumerate(tx_matrix):
        outfile.write('\n')
        for k, tx in enumerate(tx_row[0:numberofbins]):
            outfile.write('%s %s %s\n' % (freqmin+0.6*i, 0.5*k, tx))
# NOTE(review): this is the minimum *transmission* of the first time bin,
# although it is labelled 'extinction' -- confirm the label.
print('extinction:', tx_matrix[:,0].min())
# Uncomment to show the figures interactively:
#plt.show()
