#!/usr/bin/env python2

from Whitley_Stroud import *

import matplotlib.pyplot as plt
import numpy as np

from uncertainties import ufloat
from uncertainties import unumpy
from uncertainties.unumpy import uarray

# import lmfit
from lmfit import Minimizer
from lmfit import Parameters
from lmfit import fit_report
# from lmfit import conf_interval


""" Contant parameters"""
# Fixed model/experiment constants. NOTE(review): none of these names is
# referenced later in this script (the fit below uses its own initial values,
# e.g. Delta is passed as the literal 40) -- confirm whether they are still
# needed.
DeltaT = 30e-9  # integration time window
Gamma1, Gamma2 = 6.066, 0.666
Delta = 40.
lw = 1.4
pump1, pump2 = .450, 15

"""
data import and column names
"""

raw_data = np.genfromtxt('data/plot_data_tau.dat')

freq = (raw_data[:, 0] - 68) * 2
pairs = raw_data[:, 1] / 1.875
pairs_err = np.sqrt(pairs)
pairs_err[pairs_err == 0] = 1
signal = raw_data[:, 2] / 1.875
signal_err = np.sqrt(signal)
signal_err[signal_err == 0] = 1
idler = raw_data[:, 3] / 1.875
idler_err = np.sqrt(idler)
idler_err[idler_err == 0] = 1
taus = raw_data[:, 5]
taus_err = raw_data[:, 6]
eff_s = pairs / signal
eff_i = pairs / idler
eff_s_err = unumpy.std_devs(uarray(pairs, pairs_err) /
                            uarray(signal, signal_err))
eff_i_err = unumpy.std_devs(uarray(pairs, pairs_err) /
                            uarray(idler, idler_err))


""" from tau to number"""
mu = 0.0808
mu_err = 0.002
mu_u = ufloat(mu, mu_err)
tau_u = uarray(taus, taus_err)
tau_0 = 1000 / (2 * np.pi * 6.067)


N_u = (tau_0 / tau_u - 1) / mu_u
N = unumpy.nominal_values(N_u)
N_err = unumpy.std_devs(N_u)

plt.figure('N vs detuning')
plt.errorbar(freq, N, yerr=N_err)
# plt.plot(freq, idler / N, 'o')
# plt.plot(freq, signal / N, 'o')

plt.show()

""" renormalize rates """

signal_u = uarray(signal, signal_err) / N_u
idler_u = uarray(idler, idler_err) / N_u
pairs_u = uarray(pairs, pairs_err) / N_u

signal = unumpy.nominal_values(signal_u)
signal_err = unumpy.std_devs(signal_u)
idler = unumpy.nominal_values(idler_u)
idler_err = unumpy.std_devs(idler_u)
pairs = unumpy.nominal_values(pairs_u)
pairs_err = unumpy.std_devs(pairs_u)


def signal_f(x, parvals):
    """Model for the signal singles rate on the frequency axis x.

    Returns num * etas * single_lw(x, Delta, Oma, Omb, x0, lw) + dc_s, with
    every parameter looked up in the mapping ``parvals`` (for example
    ``Parameters.valuesdict()``). ``single_lw`` is not defined in this file;
    presumably it comes from ``from Whitley_Stroud import *``.
    """
    shape_args = [parvals[key] for key in ('Delta', 'Oma', 'Omb', 'x0', 'lw')]
    amplitude = parvals['num'] * parvals['etas']
    offset = parvals['dc_s']
    return amplitude * single_lw(x, *shape_args) + offset


def idler_f(x, parvals):
    """Model for the idler singles rate on the frequency axis x.

    Same lineshape as the signal model but scaled by ``etai`` and offset by
    ``dc_i``: num * etai * single_lw(x, Delta, Oma, Omb, x0, lw) + dc_i.
    ``single_lw`` is presumably provided by the Whitley_Stroud star import.
    """
    shape_args = [parvals[key] for key in ('Delta', 'Oma', 'Omb', 'x0', 'lw')]
    amplitude = parvals['num'] * parvals['etai']
    offset = parvals['dc_i']
    return amplitude * single_lw(x, *shape_args) + offset


def pair_f(x, parvals):
    """Model for the pair rate on the frequency axis x.

    Returns num * etai * etas * pairs_lw(x, Delta, Oma, Omb, x0, lw). Unlike
    the singles models there is no constant offset term. ``pairs_lw`` is
    presumably provided by the Whitley_Stroud star import.
    """
    shape_args = [parvals[key] for key in ('Delta', 'Oma', 'Omb', 'x0', 'lw')]
    amplitude = parvals['num'] * parvals['etai'] * parvals['etas']
    return amplitude * pairs_lw(x, *shape_args)


def eff_s_f(x, parvals):
    """Model for eff_s: the pair model divided by the signal singles model."""
    pairs_model = pair_f(x, parvals)
    singles_model = signal_f(x, parvals)
    return pairs_model / singles_model


def eff_i_f(x, parvals):
    """Model for eff_i: the pair model divided by the idler singles model."""
    pairs_model = pair_f(x, parvals)
    singles_model = idler_f(x, parvals)
    return pairs_model / singles_model


def save_fit(x_val, y_val, f_name, result=None):
    """Write a fitted curve to a tab-separated text file.

    x_val, y_val: paired element-wise and written one pair per line after a
        '#freq\\tFitValue' header (zip stops at the shorter of the two).
    f_name: output path; an existing file is overwritten.
    result: optional lmfit fit result; if given, its fit_report() text is
        appended after the blank separator lines.
    """
    with open(f_name, 'w') as f:
        f.write('#freq\tFitValue\n')
        # writelines with a generator: no throwaway list built just to
        # trigger the writes.
        f.writelines('{}\t{}\n'.format(a, b) for a, b in zip(x_val, y_val))
        f.write('\n\n')
        if result is not None:
            f.write(fit_report(result))


def fit_function(params, freq, signal, idler, pairs, eff_s, eff_i,
                 include=('eff_s', 'eff_i')):
    """Objective for lmfit's Minimizer: concatenated weighted residuals.

    Each residual set is (model(freq) - data) / err. The data arrays come in
    as arguments. The uncertainties are the module-level arrays signal_err,
    idler_err, pairs_err, eff_s_err and eff_i_err.

    include: names of the residual sets to concatenate, in order. Any of
        'signal', 'idler', 'pairs', 'eff_s', 'eff_i'. The default fits only
        the two efficiencies, matching the previous behaviour. An empty
        selection makes np.concatenate raise ValueError.
    """
    parvals = params.valuesdict()
    # name -> (model function, measured data, uncertainty)
    components = {
        'signal': (signal_f, signal, signal_err),
        'idler': (idler_f, idler, idler_err),
        'pairs': (pair_f, pairs, pairs_err),
        'eff_s': (eff_s_f, eff_s, eff_s_err),
        'eff_i': (eff_i_f, eff_i, eff_i_err),
    }
    residuals = []
    # Only the selected sets are evaluated; unused model calls are skipped.
    for name in include:
        model, data, err = components[name]
        residuals.append((model(freq, parvals) - data) / err)
    return np.concatenate(residuals)


# Fit parameters with their initial values, vary flags and bounds. Bounds not
# given are left open. Delta and the constant offsets dc_s/dc_i are held
# fixed. Delta is kept as the int 40 exactly as before.
p = Parameters()
p.add('num', value=np.max(signal), vary=True)
p.add('etas', value=.16, vary=True)
p.add('etai', value=.14, vary=True)
p.add('Delta', value=40, vary=False)
p.add('Oma', value=3, vary=True)
p.add('Omb', value=3, vary=True)
p.add('x0', value=0, vary=True, min=-2, max=2)
p.add('lw', value=1, vary=True, min=0)
p.add('dc_s', value=0, vary=False, min=0)
p.add('dc_i', value=0, vary=False, min=0)


# Run the fit. fit_function receives (freq, signal, idler, pairs, eff_s, eff_i)
# as its extra positional arguments.
mini = Minimizer(fit_function, p, (freq, signal, idler, pairs, eff_s, eff_i))
result = mini.minimize()
# Alternatives tried earlier: mini.least_squares(), or
# lmfit.minimize(fit_function, p, args=..., method='nelder').

# Parameters.pretty_print() prints the table itself and returns None, so it is
# called directly; wrapping it in print() would emit an extra "None" line.
result.params.pretty_print()
print(fit_report(result))
""" Plotting """
parvals = result.params.valuesdict()
ext = 2

f, axarr = plt.subplots(3, sharex=True)
f.set_figheight(15)

freq_plot = np.linspace(ext * np.min(freq), ext * np.max(freq), 1000)
# plt.figure('Singles')
axarr[0].errorbar(freq, signal, yerr=signal_err, fmt='o')
axarr[0].plot(freq_plot, signal_f(freq_plot, parvals))
axarr[0].errorbar(freq, idler, yerr=idler_err, fmt='o')
axarr[0].plot(freq_plot, idler_f(freq_plot, parvals))

save_fit(freq_plot, signal_f(freq_plot, parvals), 'fit_signal.dat', result)
save_fit(freq_plot, idler_f(freq_plot, parvals), 'fit_idler.dat', result)

# plt.figure('Pairs')
axarr[1].errorbar(freq, pairs, yerr=pairs_err, fmt='o')
axarr[1].plot(freq_plot, pair_f(freq_plot, parvals))

save_fit(freq_plot, pair_f(freq_plot, parvals), 'fit_pairs.dat', result)

# plt.figure('Efficiencies')
axarr[2].errorbar(freq, eff_s, yerr=eff_s_err, fmt='o')
axarr[2].plot(freq_plot, eff_s_f(freq_plot, parvals))
axarr[2].errorbar(freq, eff_i, yerr=eff_s_err, fmt='o')
axarr[2].plot(freq_plot, eff_i_f(freq_plot, parvals))


save_fit(freq_plot, eff_s_f(freq_plot, parvals), 'fit_eff_s.dat', result)
save_fit(freq_plot, eff_i_f(freq_plot, parvals), 'fit_eff_i.dat', result)

plt.tight_layout()

with open('eff_data.data', 'w') as f:
    f.write('#freq\teff_s\teff_s_err\teff_i\teff_i_err\n')
    [f.write('{}\t{}\t{}\t{}\t{}\n'.format(a, b, c, d, e))
     for a, b, c, d, e
     in zip(freq, eff_s, eff_s_err, eff_i, eff_i_err)]



plt.show()
