import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from uncertainties import ufloat

from lmfit import Model, CompositeModel
from lmfit import Parameters
from lmfit.models import ConstantModel


# Global plot styling: poster-sized fonts, then the "ticks" style with
# ticks pointing into the axes.
sns.set_context("poster")
sns.set_style("white")
sns.set_style("ticks", {"xtick.direction": "in", "ytick.direction": "in"})

# Named colours: the *_line names come from the default palette, the plain
# names from the 'dark' palette (entry 0 = blue, entry 2 = red).
_default_palette = sns.color_palette()
_dark_palette = sns.color_palette('dark')
blue_line, red_line = _default_palette[0], _default_palette[2]
blue, red = _dark_palette[0], _dark_palette[2]

# Data import and column names.
# plot_data.dat columns as used below: 0 -> freq (after offset and scaling),
# 1 -> pairs, 2 -> signal singles, 3 -> idler singles.
FREQ_OFFSET = 68    # subtracted from column 0 before scaling
FREQ_SCALE = 2      # NOTE(review): presumably a double-pass factor — confirm
RATE_SCALE = 1.875  # divisor for columns 1-3; presumably the integration
                    # time in seconds, turning counts into rates — confirm


def _sqrt_err(values):
    """Return sqrt(values) as a per-point error, with zero errors set to 1.

    A zero error bar would become an infinite weight if these errors are
    ever used as fit weights, so the same floor is applied to every column.

    NOTE(review): the square root is taken of the already rescaled values.
    If the raw columns are counts, the Poisson error of the rate would be
    sqrt(raw) / RATE_SCALE instead — confirm which is intended.
    """
    err = np.sqrt(values)
    err[err == 0] = 1
    return err


raw_data = np.genfromtxt('plot_data.dat')

freq = (raw_data[:, 0] - FREQ_OFFSET) * FREQ_SCALE
pairs = raw_data[:, 1] / RATE_SCALE
pairs_err = _sqrt_err(pairs)
signal = raw_data[:, 2] / RATE_SCALE
signal_err = _sqrt_err(signal)
idler = raw_data[:, 3] / RATE_SCALE
idler_err = _sqrt_err(idler)

# Fitting
# g = 2 * np.pi * .666
# g = 8


def co_f(x, Om, x0, g):
    """Coherent line-shape term.

    With d = x - x0 this evaluates

        (4 d**2 + g**2) * Om**2 / (4 d**2 + g**2 + 2 Om**2)**2

    Works element-wise on numpy arrays as well as on scalars.
    """
    width_term = 4 * (x - x0)**2 + g**2
    return width_term * Om**2 / (width_term + 2 * Om**2)**2


def inco_f(x, Om, x0, g):
    """Incoherent line-shape term.

    With d = x - x0 this evaluates

        2 * Om**2 / (4 d**2 + g**2 + 2 Om**2)**2

    It shares its denominator with co_f. Works element-wise on numpy arrays.

    NOTE(review): the textbook two-level incoherent term scales as Om**4
    rather than Om**2. While Om is held fixed, the difference is only a
    constant factor absorbed by ratio_c, but confirm the intended normalisation.
    """
    denominator = 4 * (x - x0)**2 + g**2 + 2 * Om**2
    return 2 * Om**2 / denominator**2


def eff(a, b):
    """Return the fraction a / (a + b), element-wise.

    Used as the CompositeModel operator that combines the coherent term
    (a) with the weighted incoherent term (b).

    Mathematically equal to 1 / (1 + b / a), but written this way so that
    a == 0 gives 0 instead of raising ZeroDivisionError on scalars or
    emitting a divide-by-zero warning on arrays.
    """
    return a / (a + b)


# Line-shape components, each with its own parameter prefix.
fit_co = Model(co_f, prefix='co_')
fit_inc = Model(inco_f, prefix='inco_')

# Pair rate:  Amp_c * co_f
fit_pair = ConstantModel(prefix='Amp_') * fit_co

# Efficiency:  eta_c * eff(co_f, ratio_c * inco_f)
_weighted_inc = ConstantModel(prefix='ratio_') * fit_inc
_eff_core = CompositeModel(fit_co, _weighted_inc, eff)
fit_eff = ConstantModel(prefix='eta_') * _eff_core
# Alternative signal model, kept for reference:
# fit_signal = ConstantModel(prefix='Amp_') / ConstantModel(prefix='eta_') *\
#     (fit_co + ConstantModel(prefix='ratio_') * fit_inc)

# Signal singles:  Amp_c * inco_f
fit_signal = ConstantModel(prefix='Amp_') * fit_inc

# Idler singles:  Amp_c * (co_f + ratio_c * inco_f)
fit_idler = ConstantModel(prefix='Amp_') * (
    fit_co + ConstantModel(prefix='ratio_') * fit_inc)


# Start by fitting the pairs with the coherent line shape (fit_pair).
p = Parameters()
p.add('co_x0', -1.3, vary=True)
p.add('co_Om', 3.159, vary=True)
p.add('Amp_c', 1e5)
p.add('co_g', 7.95, vary=False)  # width held fixed in this first fit

result = fit_pair.fit(pairs, x=freq, params=p)
print(result.fit_report())
# result.plot()

# Dense grid for drawing the smooth best-fit curve.
x_n = np.linspace(np.min(freq), np.max(freq), int(1e3))

# 3-sigma band of the fit, evaluated at the data points. Sort by frequency
# so fill_between draws one contiguous band even if the scan was not
# recorded in ascending order.
dely = result.eval_uncertainty(sigma=3)
order = np.argsort(freq)
plt.fill_between(freq[order],
                 (result.best_fit - dely)[order],
                 (result.best_fit + dely)[order],
                 color="#ABABAB")
plt.plot(x_n, result.eval(x=x_n), 'r')
plt.errorbar(freq, pairs, yerr=pairs_err, fmt='o')
# plt.show()


# """ now the singles with the incoherent function"""
# p = Parameters()
# p.add('inco_x0', -1.3, vary=1)
# p.add('inco_Om', 3.159, min=0, vary=1)
# p.add('Amp_c', 1e5)
# p.add('inco_g', 2 * np.pi * 0.0666, vary=0)

# result = fit_signal.fit(signal, x=freq, params=p)
# print(result.fit_report())
# result.plot()

# plt.show()

# Second stage: take co_Om, co_x0 and Amp_c from the pair fit and hold them
# fixed, free the width co_g, and fit pairs / idler with fit_eff.
# Work on a copy so the first fit's result object is not mutated.
p = result.params.copy()
p['co_Om'].set(vary=False)
p['co_x0'].set(vary=False)
p['Amp_c'].set(vary=False)
p['co_g'].set(vary=True)
# The incoherent term is tied to the coherent term's line-shape parameters.
p.add('inco_Om', expr='co_Om')
p.add('inco_x0', expr='co_x0')
p.add('inco_g', expr='co_g')
p.add('ratio_c', 5)
p.add('eta_c', .15)

result = fit_eff.fit(pairs / idler, x=freq, params=p,
                     # weights=1 / pairs_err
                     )
print(result.fit_report())

plt.figure('heralded signal efficiency')
# 3-sigma band at the data points, sorted by frequency for fill_between.
dely = result.eval_uncertainty(sigma=3)
order = np.argsort(freq)
plt.fill_between(freq[order],
                 (result.best_fit - dely)[order],
                 (result.best_fit + dely)[order],
                 color="#ABABAB")
plt.plot(x_n, result.eval(x=x_n), 'r')
plt.plot(freq, pairs / idler, 'o')

# Evaluate fit_signal and fit_pair with the efficiency-fit parameters
# alongside the measured idler rate.
plt.figure()
plt.plot(freq, fit_signal.eval(x=freq, params=result.params))
plt.plot(freq, fit_pair.eval(x=freq, params=result.params))
# plt.plot(freq, fit_inc.eval(x=freq, params=result.params))
plt.plot(freq, idler, 'o')
# plt.plot(freq, fit_eff.eval(x=freq, params=result.params))
# plt.plot(freq, pairs / signal, 'o')

# Export the fitted efficiency curve on the dense grid.
# NOTE(review): the fitted data is pairs / idler and the file is eff_i.dat,
# but the column header says eff_s (matching the figure title) — confirm.
with open('eff_i.dat', 'w') as f:
    f.write('#2-ph_detuning(MHz)\teff_s\n')
    for detuning, efficiency in zip(x_n, result.eval(x=x_n)):
        f.write('{:.3f}\t{:.3e}\n'.format(detuning, efficiency))

plt.show()


def _style_scan_axes(ax, ylabel, yticks, ylim=None):
    """Apply the shared figure styling to a detuning-scan axes, then tight_layout.

    Sets thick ticks and bottom/left spines, despines with a 20 pt offset and
    moves the bottom spine to y = 0. Then sets y ticks, optional y limits,
    x range (-17, 16) and the end-of-axis labels.

    ax     : the Axes to style. It must be the current axes, because ticks,
             limits and labels are set through pyplot, as in the original
             per-plot code.
    ylabel : text for the y-axis label.
    yticks : tick positions passed to plt.yticks.
    ylim   : optional (bottom, top) passed to plt.ylim; skipped when None.
    """
    ax.tick_params(labelsize=26)
    ax.xaxis.set_tick_params(width=3)
    ax.yaxis.set_tick_params(width=3)
    for axis in ['bottom', 'left']:
        ax.spines[axis].set_linewidth(3)
    sns.despine(trim=False, offset=20)
    ax.spines['bottom'].set_position('zero')
    plt.yticks(yticks)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.xlim(-17, 16)
    plt.xlabel('2-ph detuning (MHz)', x=1, fontsize=32)
    plt.ylabel(ylabel,
               fontsize=32, rotation=0,
               y=1.02, labelpad=-60)
    plt.tight_layout()


def _ratio_err(num, num_err, den, den_err):
    """Return the propagated 1-sigma errors of num / den, point by point.

    Each pair is combined as independent ufloats. A zero denominator raises
    ZeroDivisionError, as uncertainties does for float division.
    """
    return [(ufloat(n, n_e) / ufloat(d, d_e)).s
            for n, n_e, d, d_e in zip(num, num_err, den, den_err)]


# Plot of the singles.
f, ax = plt.subplots()
f.set_size_inches(14, 9)
plt.errorbar(freq, signal, yerr=signal_err, fmt='o', color=blue)
plt.errorbar(freq, idler, yerr=idler_err, fmt='o', color=red)
_style_scan_axes(ax, 'rate (1/s)',
                 yticks=np.arange(0, 200001, 1e5),
                 ylim=(0, 230000))
plt.savefig("singles_vs_freq.pdf", format="pdf")

# Plot of the pair rate.
f, ax = plt.subplots()
f.set_size_inches(14, 9)
plt.errorbar(freq, pairs, yerr=pairs_err, fmt='o',
             color=sns.color_palette()[1])
_style_scan_axes(ax, 'coincidences (1/s)',
                 yticks=np.arange(0, 15000.1, 5000))
plt.savefig("pairs_vs_freq.pdf", format="pdf")

# Plot of the efficiencies.
f, ax = plt.subplots()
f.set_size_inches(14, 9)
eff_s = pairs / signal
eff_i = pairs / idler
eff_s_err = _ratio_err(pairs, pairs_err, signal, signal_err)
eff_i_err = _ratio_err(pairs, pairs_err, idler, idler_err)
plt.errorbar(freq, eff_s, yerr=eff_s_err, fmt='o', color=blue)
plt.errorbar(freq, eff_i, yerr=eff_i_err, fmt='o', color=red)
_style_scan_axes(ax, 'efficiency',
                 yticks=np.arange(0, .21, .05),
                 ylim=(0, .21))
plt.savefig("eff_vs_freq.pdf", format="pdf")


# plt.show()
