import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import os
import re
from lmfit import Model, Parameters
from lmfit.models import ConstantModel

# --- I/O configuration ---------------------------------------------------
# Input histograms are read from the current working directory; swap the
# explicit list below for a directory listing to batch-process every file.
data_folder = os.getcwd()
data_file_list = ["cross_g2_mercury_ts7_17over256ns_bins.dat"]  # or: np.flip(np.asarray(os.listdir(data_folder)))
# Output locations (used by the currently commented-out savefig/savetxt calls).
plot_folder = "normal_g2_plots"
fit_folder = "normal_g2_fit"

def constant(c):
	"""Flat-background model: returns the floor offset *c* unchanged.

	Kept so it can be wrapped in an lmfit ``Model`` (see the commented-out
	``constant_model`` below), which is why the parameter name matters.
	"""
	return c

def one_gaussian_g2_model(tau, c, A_0, tau_fwhm, tau_0):
	"""
	Single Gaussian g(2) peak on a flat floor, parameterised by its FWHM.

	tau: time-bin vector [ns]
	c: floor offset
	A_0: peak amplitude above the floor
	tau_fwhm: full width at half maximum of the peak [ns]
	tau_0: timing offset of the peak centre [ns]
	"""
	# Convert FWHM to the Gaussian standard deviation.
	tau_sigma = tau_fwhm / (2 * np.sqrt(2 * np.log(2)))
	delta = tau - tau_0
	return c + A_0 * np.exp(-delta**2 / (2 * (tau_sigma**2)))


def simple_gaussian_g2_model(tau, c, A_0, tau_c, tau_0):
	"""
	Gaussian g(2) peak on a flat floor, parameterised by its 1/e half-width.

	tau: time-bin vector [ns]
	c: floor offset
	A_0: peak amplitude above the floor
	tau_c: width parameter [ns]; the peak falls to 1/e of A_0 at |tau - tau_0| = tau_c
	tau_0: timing offset of the peak centre [ns]
	"""
	delta = tau - tau_0
	return c + A_0 * np.exp(-delta**2 / ((tau_c**2)))

def two_exponential_g2_model(tau, c, A_0, tau_c, tau_0, tau_delay):
	"""
	Two symmetric two-sided exponential g(2) peaks on a flat floor.

	The peaks sit at tau_0 - tau_delay and tau_0 + tau_delay, each carrying
	a quarter of the amplitude A_0.

	tau: time-bin vector [ns]
	c: floor offset
	A_0: total amplitude scale
	tau_c: exponential decay time constant [ns]
	tau_0: common timing offset [ns]
	tau_delay: half-separation between the two peaks [ns]
	"""
	early_peak = np.exp(-2 * np.abs(tau - tau_0 + tau_delay) / (tau_c))
	late_peak = np.exp(-2 * np.abs(tau - tau_0 - tau_delay) / (tau_c))
	return c + 0.25 * A_0 * (early_peak) + 0.25 * A_0 * (late_peak)

def two_gaussian_g2_model(tau, c, A_0, tau_c, tau_0, tau_delay):
	"""
	Two symmetric Gaussian g(2) peaks on a flat floor.

	The peaks sit at tau_0 - tau_delay and tau_0 + tau_delay, each carrying
	a quarter of the amplitude A_0; tau_c is the 1/e half-width of each peak.

	tau: time-bin vector [ns]
	c: floor offset
	A_0: total amplitude scale
	tau_c: 1/e half-width of each Gaussian peak [ns]
	tau_0: common timing offset [ns]
	tau_delay: half-separation between the two peaks [ns]
	"""
	early_peak = np.exp(-(tau - tau_0 + tau_delay)**2 / ((tau_c**2)))
	late_peak = np.exp(-(tau - tau_0 - tau_delay)**2 / ((tau_c**2)))
	return c + 0.25 * A_0 * (early_peak) + 0.25 * A_0 * (late_peak)


# Main processing loop: for each histogram file, load the coincidence data,
# fit the two-Gaussian g(2) model with lmfit, plot data + fit, and save the
# fitted curve and fit report to disk.
for filename in data_file_list:
	root_filename = filename[:-4]  # strip the ".dat" extension
	print("Processing: {}".format(filename))

	# # extract the normalisation factor
	# g2_file = open(os.path.join(data_folder,filename))
	# line1 = re.split(' |,',g2_file.readline())

	# counts1 = int(line1[2])
	# counts2 = int(line1[4])

	# line2 = re.split(' |,',g2_file.readline())
	# total_time =  int(line2[2])
	# time_interval = int(line2[5])
	# bin_size = time_interval/256

	# Normalisation factor hard-coded instead of being parsed from the file
	# header (the parsing code above is commented out).
	# NOTE(review): value presumably equals total_time/(counts1*counts2*time_interval)
	# for this specific dataset — confirm if reused on other files.
	scaling_factor = 7.784299004740641e-05#total_time/(counts1*counts2*time_interval)
	# Bin width in ns; matches the "17over256ns_bins" tag in the filename.
	bin_size = 17/256

	# g2_file.close() 

	# Start of the fit region [ns]; rows before this are skipped (dead time).
	skip_time= 207.40-15#205.6 #203.59#ns
	# Length of the fitted region [ns].
	fit_window= 30 #ns

	# skip the parts with deadtime
	# Load only the fit window: skip_header/max_rows are converted from ns to
	# row counts using bin_size. Transposed so row 0 = time, row 1 = counts.
	data_arr = np.genfromtxt(os.path.join(data_folder,filename),skip_header=int(skip_time/bin_size),max_rows=int(fit_window/bin_size)).T
	time_vector = data_arr[0]#*1e9 # just to change to time in units of ns for fitting
	print(time_vector)

	# MODELS USED IN FIT
	# g2_model = Model(three_peak_single_decay_g2_model_symmetric,nan_policy="propagate")
	# one_peak_model = Model(simple_gaussian_g2_model,nan_policy="propagate")
	# Despite the name, this wraps the TWO-Gaussian model defined above.
	one_peak_model = Model(two_gaussian_g2_model,nan_policy="propagate")
	# constant_model = Model(constant,nan_policy="propagate")



	# Find sharpest point
	# time_highest=time_vector[np.argmax(data_arr[1])]
	# g2_highest = data_arr[1,np.argmax(data_arr[1])]

	# Initial guesses (in raw-count units; scaling_factor converts to g(2)).
	c_guess = 1/scaling_factor #data_arr[1,2]
	# NOTE(review): 1900/13196 looks like a peak/floor count ratio measured
	# from this dataset — confirm before reusing on other files.
	A_0_guess = 1900/13196*c_guess ## data_arr[1,np.argmax(data_arr[1])] - c_guess
	tau_0_guess =  0 #time_vector[np.argmax(data_arr[1])]

	# current_coincidence_peak_amp = data_arr[1,np.argmax(data_arr[1])] - c_guess
	# tau_c_position = np.argmax(data_arr[1])
	# while current_coincidence_peak_amp > 0.5*A_0_guess:
	# 	tau_c_position += 1
	# 	current_coincidence_peak_amp = data_arr[1,tau_c_position] - c_guess

	tau_c_0_guess = 0.23 #ns np.abs(time_vector[tau_c_position]-time_highest) #s
	tau_delay_guess = 10.3
	# start_point = 0#211-3 # bin position
	# end_point = len(data_arr[0])#start_point + 60 #number of bins 
	# time_vector = data_arr[0][start_point:end_point]
	# data_vector = data_arr[1][start_point:end_point]
	# # print(time_vector)

	# to help with fitting put all in ns
	# Only tau_0 and tau_delay are free; width, floor and amplitude are fixed
	# to their guesses, so the fit just locates the peaks.
	fit_params = Parameters()
	fit_params.add('tau_c', value=tau_c_0_guess,vary=False)#, max=5e8, min=0)
	fit_params.add('tau_0', value=tau_0_guess)
	# tau_delay constrained to within one width of its guess.
	fit_params.add('tau_delay',value = tau_delay_guess, min=tau_delay_guess- tau_c_0_guess, max=tau_delay_guess+tau_c_0_guess)
	# fit_params.add('tau_0', value=241*2e-9, max=243*2e-9, min=239*2e-9)
	fit_params.add('c',value=c_guess, vary=False)
	fit_params.add('A_0',value = A_0_guess, vary= False)

	data_vector = data_arr[1]

	# Try to see if there are side peaks to fit to else 
	# result = g2_model.fit(data_arr[1],tau=data_arr[0]*bin_size, c= c_guess, A_0 = A_0_guess,reci_tau_c_0 = reci_tau_c_0_guess,A_S = A_S_guess, reci_tau_c_S = reci_tau_c_S_guess,weights=1/np.sqrt(data_arr[1]))
	# Weights 1/sqrt(N): Poisson counting error on each bin.
	peak_result = one_peak_model.fit(data_vector,fit_params,tau=time_vector, weights=1/np.sqrt(data_vector),nan_policy="propagate")
	# constant_result = ConstantModel().fit(data_vector,x=time_vector*bin_size, c= c_guess,nan_policy="propagate",weights=1/np.sqrt(data_vector))

	print(peak_result.fit_report())
	# print(constant_result.fit_report())



# 	# best fit arrays
# 	dip_best_fit_array = dip_result.best_fit

# 	# fit results
# 	dip_dic = dip_result.params
# 	# constant_dic = constant_result.params

	# --- Plotting: g(2) on the left axis, raw coincidences on the right ---
	plt.clf()
	plt.rcParams['font.size'] = '22'
	fig, axs = plt.subplots(figsize=(11.7,8.3),constrained_layout=True)



	axs.set_ylabel(r"$g^{(2)}$")	
	axs.set_xlabel("time [ns]")
	# axs.set_ylim([0.95,2])#1.25])

	# NOTE(review): self-assignment is a no-op, kept only to preserve the
	# commented-out alternative normalisation from the fitted floor.
	scaling_factor = scaling_factor#1/peak_result.params['c'].value
	# plot raw data

	# Time axis re-centred on the fitted tau_0; counts scaled to g(2).
	axs.errorbar(time_vector-peak_result.params['tau_0'].value,data_vector*scaling_factor,yerr=np.sqrt(data_arr[1])*scaling_factor,fmt='r.')
	# plot fit
	axs.plot(time_vector-peak_result.params['tau_0'].value,peak_result.best_fit*scaling_factor,"k--")

	# plot coherence time
	# Arrow geometry for the (currently disabled) coherence-time annotation.
	peak_time = 0#peak_result.params['tau_0'].value
	arrow_length = peak_result.params['tau_c'].value/2
	arrow_g2_value = (peak_result.params['A_0'].value)/2 + peak_result.params['c'].value

	# axs.arrow(peak_time, arrow_g2_value,-arrow_length,0,length_includes_head=True,linestyle='--',color='blue')
	# axs.arrow(peak_time, arrow_g2_value,+arrow_length,0,length_includes_head=True,linestyle='--',color='blue')

	modified_time_vector = time_vector-peak_result.params['tau_0'].value
	# axs.arrow(peak_time, arrow_g2_value,-arrow_length,0,length_includes_head=True,linestyle='--',color='blue')
	# axs.arrow(peak_time, arrow_g2_value,+arrow_length,0,length_includes_head=True,linestyle='--',color='blue')

	# np.savetxt(os.path.join(fit_folder,root_filename+"_fitted.dat"),np.asarray([modified_time_vector,data_vector*scaling_factor,np.sqrt(data_vector)*scaling_factor,peak_result.best_fit*scaling_factor]),header="time[ns],raw g(2),error bar, fit value")


	# plot secondary coincidences axis
	# Right-hand axis shows raw coincidence counts for the same y-range.
	axs2 = axs.twinx()
	axs2.set_ylim(np.asarray(axs.get_ylim())/scaling_factor)
	axs2.set_ylabel("coincidences")

	# plt.xlim([-1,1])
	plt.show()
	# plt.savefig(os.path.join(plot_folder,root_filename+".pdf"),format="pdf",dpi=1000)

	# Evaluate the fitted model on a dense grid (centred at tau_0 = 0) and
	# save both g(2) and raw-coincidence versions of the curve.
	fit_time_vector = np.linspace(-20,20,10000) #ns
	fit_data_vector = two_gaussian_g2_model(fit_time_vector, peak_result.params['c'], peak_result.params['A_0'], peak_result.params['tau_c'],0,peak_result.params['tau_delay'])
	fit_coincidence_vector = fit_data_vector*scaling_factor

	# NOTE(review): header column order looks swapped — the saved columns are
	# [time, g(2), coincidences] but the header says "coincidences, g(2)".
	np.savetxt("g2X_fitted_Hg.dat",np.asarray([fit_time_vector,fit_data_vector,fit_coincidence_vector]).T,header="time[ns],coincidences, g(2)")
	# Abuse of savetxt: writes a single zero with the full fit report as header.
	np.savetxt("g2X_fit_report_Hg.txt",np.asarray([0]),header=peak_result.fit_report())
