#!/usr/bin/env python

import atom_interaction as AI
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from uncertainties import ufloat
from uncertainties import unumpy
from scipy.stats import norm

# Global plot styling: seaborn "poster" context, then the "ticks" style with
# inward-pointing x/y ticks. The "white" style call is presumably superseded
# by the "ticks" call that follows it — kept as in the original.
sns.set_context(context="poster")
sns.set_style(style="white")
sns.set_style(style="ticks",
              rc={"xtick.direction": "in", "ytick.direction": "in"})

# Useful parameters.
# All times are expressed in nanoseconds and all frequencies in GHz (1/ns).
l = 0.032  # passed as the trailing argument to the AI.* model functions
gamma_0 = 1 / 26.2  # rate of 1/(26.2 ns), in GHz
dt_tx = 2  # tx histogram bin width in ns (cf. the "_tx_2ns" input files)
# Data import: combined histograms. Going by the file names, the "tx" files
# are binned in 2 ns and the "rx" files in 5 ns.
infile_tx_decaying = 'non_reversed_with_alternating_09_26_onwards_combined_histo_tx_2ns'
infile_tx_rising = 'reversed_with_alternating_09_26_onwards_combined_histo_tx_2ns'
infile_rx_decaying = 'non_reversed_with_alternating_09_26_onwards_combined_histo_rx_5ns'
infile_rx_rising = 'reversed_with_alternating_09_26_onwards_combined_histo_rx_5ns_4'

# Load order: tx decaying, rx decaying, tx rising, rx rising.
raw_tx_decaying, raw_rx_decaying, raw_tx_rising, raw_rx_rising = (
    np.genfromtxt(name)
    for name in (infile_tx_decaying, infile_rx_decaying,
                 infile_tx_rising, infile_rx_rising)
)

# Common time-zero offset (ns) subtracted from both tx time axes.
h_shift = 878.0 + 1.5

time_rise = raw_tx_rising[:, 1] - h_shift
time_deca = raw_tx_decaying[:, 1] - h_shift

# Earlier per-file offsets, kept for reference:
# time_rise = (raw_tx_rising[:, 1] - 878.5+.5)
# time_deca = (raw_tx_decaying[:, 1] - 878.5)


def _closest(times, t0):
    """Return the index of the element of ``times`` nearest to ``t0``."""
    return np.argmin(abs(times - t0))


def _ucolumn(raw, rows, col):
    """Return ``raw[rows, col]`` as ufloats whose errors are column ``col + 1``."""
    return unumpy.uarray(raw[rows, col], raw[rows, col + 1])


# Integration limits for accidental noise subtraction.
# idx   : analysis window, roughly -150 ns .. 300 ns (stop index exclusive)
# idx_r : background window of the rising data, roughly -150 .. -100 ns
# idx_d : background window of the decaying data, roughly -100 .. -50 ns
# NOTE(review): idx is located on time_rise but is also applied to the
# decaying data below; this is only exact if both histograms share bins.
idx = np.arange(_closest(time_rise, -150), _closest(time_rise, 300))
idx_r = np.arange(_closest(time_rise, -150), _closest(time_rise, -100))
idx_d = np.arange(_closest(time_deca, -100), _closest(time_deca, -50))

# Rising data: columns 8/9 give G_r0, columns 4/5 give G_r; each has the mean
# of its own background window subtracted.
bg_r0 = np.mean(_ucolumn(raw_tx_rising, idx_r, 8))
G_r0 = _ucolumn(raw_tx_rising, idx, 8) - bg_r0
bg_r = np.mean(_ucolumn(raw_tx_rising, idx_r, 4))
G_r = _ucolumn(raw_tx_rising, idx, 4) - bg_r

# Normalized per-bin difference: sum(delta_r) * dt_tx equals
# sum(G_r0 - G_r) / sum(G_r0).
delta_r = (G_r0 - G_r) / sum(G_r0) / dt_tx

# Decaying data, same treatment with its own background window.
bg_d0 = np.mean(_ucolumn(raw_tx_decaying, idx_d, 8))
G_d0 = _ucolumn(raw_tx_decaying, idx, 8) - bg_d0
bg_d = np.mean(_ucolumn(raw_tx_decaying, idx_d, 4))
G_d = _ucolumn(raw_tx_decaying, idx, 4) - bg_d

delta_d = (G_d0 - G_d) / sum(G_d0) / dt_tx

time_r, time_d = time_rise[idx], time_deca[idx]


# Bootstrap evaluation of the error: number of Monte-Carlo resamples.
reps = 2000


def data_i(delta):
    """Draw one Gaussian realisation of every element of ``delta``.

    ``delta`` is an iterable of objects exposing a nominal value ``.n`` and a
    standard deviation ``.s`` (e.g. ``uncertainties`` numbers). Each element
    is sampled independently with ``np.random.normal`` from NumPy's global
    random state, in input order.

    Returns a list holding one sample per element.
    """
    samples = []
    for value in delta:
        draw = np.random.normal(loc=value.n, scale=value.s, size=1)
        samples.append(draw[0])
    return samples

def _bootstrap(delta, times):
    """Monte-Carlo propagate the uncertainties of ``delta`` through AI.my_int.

    For each of ``reps`` iterations a fresh Gaussian realisation of ``delta``
    is drawn with ``data_i``, linearly interpolated on ``times`` and passed to
    ``AI.my_int(f, times, dt_tx, gamma_0, l)``.

    Returns
    -------
    samples : ndarray, shape (reps, len(times))
        One AI.my_int result per realisation.
    mean, std : ndarray, shape (len(times),)
        Point-wise mean and standard deviation over the realisations.
    peak : ufloat
        Mean +/- standard deviation of the per-realisation maximum.
    """
    # Sized by this ensemble's own time grid (the rising ensemble was
    # previously allocated with len(time_d)).
    samples = np.zeros((reps, len(times)))
    for k in range(reps):
        draw = data_i(delta)
        # NOTE(review): the lambda late-binds ``draw``; this assumes
        # AI.my_int evaluates it before returning (as the original did).
        samples[k, :] = AI.my_int(lambda t: np.interp(t, times, draw),
                                  times, dt_tx, gamma_0, l)
    peaks = np.max(samples, 1)
    return (samples, np.mean(samples, 0), np.std(samples, 0),
            ufloat(np.mean(peaks), np.std(peaks)))


# Decaying first, then rising: keeps the global random-number sequence the
# same as drawing the two ensembles one after the other.
decays, decay, decay_err, decay_max = _bootstrap(delta_d, time_d)
rises, rise, rise_err, rise_max = _bootstrap(delta_r, time_r)


# Printing.
# Integrated extinction in percent: sum(delta) * dt_tx equals
# sum(G0 - G) / sum(G0) over the analysis window.
print('Rising extinction: {:.2f} %'.format(sum(delta_r) * dt_tx * 100))
# This value comes from the decaying data (delta_d); it was previously
# mislabelled "Rising decay".
print('Decaying extinction: {:.2f} %'.format(sum(delta_d) * dt_tx * 100))

# Bootstrapped maxima (ufloats, so the output carries +/- errors) and their
# ratio rising / decaying.
print('Max Pe for decaying: {:.3f} %\n'
      'Max Pe for rising:  {:.3f} %\n'
      'Ratio: {:.3f}'.format(decay_max * 100.,
                             rise_max * 100.,
                             rise_max / decay_max))

# Plotting.
# Diagnostic histograms of the bootstrapped maxima (disabled):
# plt.figure()
# sns.distplot(np.max(decays, 1), kde=False, fit=norm)
# plt.figure()
# sns.distplot(np.max(rises, 1), kde=False, fit=norm)

dt_rx = 5  # rx histogram bin width in ns (cf. the "_rx_5ns" input files)
eta = 3.68e-3 * 2 * 2
rx_det_eff = 0.0078
rx_decaying_offset = 3.11e-7
rx_rising_offset = 3.2e-7
rx_decaying_offset_err = 4.67e-9
rx_rising_offset_err = 4.95e-9
# Divisor converting offset-subtracted rx histogram values into P.
rx_scalefactor = gamma_0 * eta * rx_det_eff * dt_rx


def _rx_to_p(raw, offset, offset_err):
    """Scale rx histogram column 4 (value) and column 5 (error) into P.

    Returns ``(P, P_err)`` with ``P = (raw[:, 4] - offset) / rx_scalefactor``.
    The offset uncertainty is added linearly to the column-5 error (not in
    quadrature), as in the original analysis.
    """
    p = (raw[:, 4] - offset) / rx_scalefactor
    p_err = (raw[:, 5] + offset_err) / rx_scalefactor
    return p, p_err


P_r, P_r_err = _rx_to_p(raw_rx_rising, rx_rising_offset,
                        rx_rising_offset_err)
time_pr = raw_rx_rising[:, 1] - (878.5 + 12)
P_d, P_d_err = _rx_to_p(raw_rx_decaying, rx_decaying_offset,
                        rx_decaying_offset_err)
time_pd = raw_rx_decaying[:, 1] - (881.5 + 12)

# rx plotting window: indices of time_pr nearest to -150 ns and 300 ns.
# NOTE(review): this window is located on time_pr (rising) but is reused for
# the decaying rx data, whose time axis has a different offset — confirm.
idx = np.arange(np.argmin(abs(time_pr + 150)),
                np.argmin(abs(time_pr - 300)))

# Fine grid (1000 points) over the tx window for the model curves.
time_v, dt = np.linspace(min(time_d), max(time_d), 1000, retstep=True)
gamma_p = 1 / 13.5  # second rate passed to the AI.P_exp_* model curves

# Decaying: rx data, bootstrapped curve, model.
f, ax1 = plt.subplots()
ax1.errorbar(time_pd[idx], P_d[idx], yerr=P_d_err[idx], fmt='o')
ax1.errorbar(time_d, decay, yerr=decay_err, fmt='o')
ax1.plot(time_v, AI.P_exp_decaying(time_v, gamma_0, gamma_p, l))
ax1.set_xlim(time_d[[0, -1]])

# Rising: rx data, bootstrapped curve, model.
# NOTE(review): limits use the decaying grid time_d, not time_r.
f, ax2 = plt.subplots()
ax2.errorbar(time_pr[idx], P_r[idx], yerr=P_r_err[idx], fmt='o')
ax2.errorbar(time_r, rise, yerr=rise_err, fmt='o')
ax2.plot(time_v, AI.P_exp_rising(time_v, gamma_0, gamma_p, l))
ax2.set_xlim(time_d[[0, -1]])

# Overlay of all bootstrap realisations (disabled):
# [plt.plot(time_d, decays[k, :]) for k in range(reps)]
# [plt.plot(time_r, rises[k, :]) for k in range(reps)]

# Background-subtracted reference counts (nominal values only).
plt.figure()
plt.plot(time_d, unumpy.nominal_values(G_d0), 'o-')
plt.plot(time_r, unumpy.nominal_values(G_r0), 'o-')

# Saving.


def _save_pe(path, times, pe, pe_err):
    """Write a tab-separated ``#time Pe Pe_err`` table to ``path``.

    Times are written with one decimal, Pe and its error in 5-digit
    scientific notation; ``path`` is overwritten.
    """
    with open(path, 'w') as out:
        out.write('#time\tPe\tPe_err\n')
        out.writelines('{:.1f}\t{:.5e}\t{:.5e}\n'.format(t, p, e)
                       for t, p, e in zip(times, pe, pe_err))


_save_pe('ODE_rising.dat', time_r, rise, rise_err)
# Each curve is paired with its own time axis (the decaying file was
# previously written against time_r).
_save_pe('ODE_decaying.dat', time_d, decay, decay_err)

plt.show()


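# Reference output from a previous run (the second line was printed under the
# label "Rising decay" but is the extinction of the decaying data):
# Rising extinction: 4.44+/-0.38 %
# Rising decay: 4.27+/-0.37 %
# Max Pe for decaying: 1.760+/-0.088 %
# Max Pe for rising:  2.745+/-0.126 %
# Ratio: 0.641+/-0.044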
