import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from lmfit import Model
from lmfit import Parameters

# constants
# Subtracted from the raw time column (column 1 of the input files) so that
# the fitted time axis is measured relative to this value.
offset_tx = 879
# Input file for the decay fit.
infile_long_dec = (
    'non_reversed_with_alternating_09_26_onwards_'
    'combined_histo_tx_2ns'
)
# Input file for the rise fit.
infile_long_ris = (
    'reversed_with_alternating_09_26_onwards_'
    'combined_histo_tx_2ns'
)


def val_cross(vec, value):
    """Return the index of the entry of ``vec`` closest to ``value``.

    Parameters
    ----------
    vec : array_like
        Numeric samples to search. Lists and tuples are accepted as well as
        numpy arrays.
    value : scalar
        Target value.

    Returns
    -------
    int
        Index of the entry with the smallest absolute distance to ``value``.
        On ties the first such index is returned (``np.argmin`` semantics).

    Raises
    ------
    ValueError
        If ``vec`` is empty.
    """
    # Coerce first so plain Python sequences support the subtraction below.
    distances = np.abs(np.asarray(vec) - value)
    return np.argmin(distances)


def exp_profile(x, Amplitude, tau, x_offset, y_offset):
    """One-sided exponential on top of a constant baseline.

    With ``s = x - x_offset`` the result is
    ``Amplitude * exp(-s / tau) + y_offset`` where ``s`` has the same sign as
    ``tau`` (or is zero), and ``y_offset`` everywhere else. A positive
    ``tau`` therefore gives a decay for ``x >= x_offset``; a negative ``tau``
    gives a rise for ``x <= x_offset``.

    Parameters
    ----------
    x : scalar or numpy array
        Points at which to evaluate the profile.
    Amplitude : float
        Height of the exponential at ``x == x_offset``.
    tau : float
        Time constant; its sign selects the side on which the exponential
        is switched on.
    x_offset : float
        Position of the edge.
    y_offset : float
        Constant baseline added everywhere.

    Returns
    -------
    Same shape as ``x``.
    """
    shifted = x - x_offset
    outside = shifted * np.sign(tau) < 0
    # Gate the exponent rather than the result: on the switched-off side
    # exp(-s / tau) grows without bound and may overflow to inf, and
    # inf * 0 would be NaN. exp(-inf) is exactly 0, so those samples sit on
    # the baseline. NaNs in the inputs still propagate.
    exponent = np.where(outside, -np.inf, -shifted / tau)
    return Amplitude * np.exp(exponent) + y_offset

# lmfit model built from exp_profile; the fits below pass `x=` as the data
# axis and Amplitude, tau, x_offset, y_offset as fit parameters.
exp_model = Model(exp_profile)


""" decay """
# data import, following gnuplot file
data_decay = np.genfromtxt(infile_long_dec)
time = data_decay[:, 1] - offset_tx
signal = data_decay[:, 8]
signal_err = data_decay[:, 9]

# Restrict the fit to the samples between the ones closest to t_min and
# t_max. The slice is half-open, so the sample closest to t_max is excluded.
# NOTE(review): assumes the time column is sorted in increasing order.
t_min = 4
idx0 = val_cross(time, t_min)
t_max = 200
idx1 = val_cross(time, t_max)

time = time[idx0:idx1]
signal = signal[idx0:idx1]
signal_err = signal_err[idx0:idx1]


# Starting values: amplitude from the largest sample in the window; the
# baseline is held fixed at the mean of the last 30 samples of the window.
p = Parameters()
p.add('Amplitude', np.max(signal))
p.add('tau', 13)
p.add('x_offset', 1.29, vary=True)
p.add('y_offset', np.mean(signal[-30:]), vary=False)

fit_decay = exp_model.fit(signal, x=time, params=p, weights=1 / signal_err)
print(fit_decay.fit_report())

# Evaluate the best-fit curve once and reuse it for both plots and the file.
model_decay = fit_decay.eval(x=time)
residual_decay = signal - model_decay

plt.figure()
gs = gridspec.GridSpec(3, 1)

# fit plot
ax1 = plt.subplot(gs[:2, :])
plt.errorbar(time, signal, yerr=signal_err, fmt='o')
plt.plot(time, model_decay)
plt.setp(ax1.get_xticklabels(), visible=False)

# residual plot
ax2 = plt.subplot(gs[2, :], sharex=ax1)
plt.errorbar(time, residual_decay, yerr=signal_err, fmt='o')
# plt.xlim([0.5, 180])

# NOTE(review): '{:3e}' sets a minimum field width of 3, not a precision;
# '{:.3e}' may have been intended. Left unchanged to keep the file format.
with open('residuals_decay.dat', 'w') as f:
    f.write('#Delta_t(ns)\tresidual\tresidual_err\n')
    f.writelines('{:.1f}\t{:3e}\t{:3e}\n'.format(t, res, err)
                 for t, res, err
                 in zip(time, residual_decay, signal_err))

""" rise """
# data import, following gnuplot file
data_rise = np.genfromtxt(infile_long_ris)
time = data_rise[:, 1] - offset_tx
signal = data_rise[:, 8]
signal_err = data_rise[:, 9]

# Restrict the fit to the samples between the ones closest to t_min and
# t_max. The slice is half-open, so the sample closest to t_max is excluded.
# NOTE(review): assumes the time column is sorted in increasing order.
t_min = -200
idx0 = val_cross(time, t_min)
t_max = -2
idx1 = val_cross(time, t_max)

time = time[idx0:idx1]
signal = signal[idx0:idx1]
signal_err = signal_err[idx0:idx1]


# Starting values: amplitude from the largest sample in the window; tau is
# bounded above by -1 so it stays negative (rising exponential); the baseline
# is held fixed at the mean of the first 30 samples of the window.
p = Parameters()
p.add('Amplitude', np.max(signal))
p.add('tau', -13, max=-1)
p.add('x_offset', 1.29, vary=True)
p.add('y_offset', np.mean(signal[:30]), vary=False)

fit_rise = exp_model.fit(signal, x=time, params=p, weights=1 / signal_err)
print(fit_rise.fit_report())

# Evaluate the best-fit curve once and reuse it for both plots and the file.
model_rise = fit_rise.eval(x=time)
residual_rise = signal - model_rise

plt.figure()
gs = gridspec.GridSpec(3, 1)

# fit plot
ax1 = plt.subplot(gs[:2, :])
plt.errorbar(time, signal, yerr=signal_err, fmt='o')
plt.plot(time, model_rise)
plt.setp(ax1.get_xticklabels(), visible=False)

# residual plot
ax2 = plt.subplot(gs[2, :], sharex=ax1)
plt.errorbar(time, residual_rise, yerr=signal_err, fmt='o')
# plt.xlim([-200, 10])


# NOTE(review): '{:3e}' sets a minimum field width of 3, not a precision;
# '{:.3e}' may have been intended. Left unchanged to keep the file format.
with open('residuals_rise.dat', 'w') as f:
    f.write('#Delta_t(ns)\tresidual\tresidual_err\n')
    f.writelines('{:.1f}\t{:3e}\t{:3e}\n'.format(t, res, err)
                 for t, res, err
                 in zip(time, residual_rise, signal_err))


plt.show()
