import matplotlib.pyplot as plt
import numpy as np

from lmfit.models import LinearModel
from lmfit.models import QuadraticModel

import seaborn as sns

from scipy.stats import pearsonr
from sklearn.cluster import KMeans

# Data import: load the whitespace-separated table and unpack its columns.
# Only columns 0, 2, 3, 4, 7 and 8 are bound to names; the others are discarded.
time, _, rtrip, prs1, prs2, _, _, delta, err, _ = np.genfromtxt('offset_rtt_2_4_15.dat').T

# Choice between a linear or a parabolic model for the clock-frequency drift.
linear = False

driftfit = LinearModel() if linear else QuadraticModel()
# Label used in output file names so they name the model actually fitted.
model_tag = 'linear' if linear else 'parabolic'

# Clock-frequency drift fit: delta as a function of time.
# guess() derives starting values for the model parameters from the data.
p = driftfit.guess(delta, x=time)
result = driftfit.fit(delta, x=time, params=p)
print(result.fit_report())

# Save time, measured delta, fitted delta and fit residual (tab-separated),
# with the full fit report appended as a footer.
np.savetxt('residual_offset_{}.dat'.format(model_tag),
           np.c_[time, delta, result.best_fit, result.residual],
           header='Time\tdelta\tdelta_fit\tresidual', delimiter='\t',
           footer=result.fit_report())

"""sorting of the round trip times for aesthetic reasons when plotting
shaded area"""
sort_mask = np.argsort(rtrip)
rtrip_s = rtrip[sort_mask]
res = result.residual[sort_mask]

"""Pearson correlation"""
print('Pearson r: {:.3f}, P-value: {:.3f}'.format(*pearsonr(res, rtrip_s)))


"""Fit of the residuals with a linear model to asses correlation"""
residual_fit = LinearModel()
p_res = residual_fit.guess(rtrip_s, res)
result_res = residual_fit.fit(res, x=rtrip_s, params=p_res)
print(result_res.fit_report())

# calculates the confidence interval by bootstrapping
result_res.conf_interval(sigmas=[.99])
print(result_res.ci_report())
print(result_res.ci_out['slope'])

""" organize RTT in 4 clusters"""
kmeans = KMeans(n_clusters=4, random_state=0).fit(rtrip.reshape(-1, 1))
centers = kmeans.cluster_centers_[:, 0]

"""Correlation of Offset with Fibre Length"""
DresDr = result_res.params['slope'].value # slope of residual vs round trip time
refractive_indx = 1.45
c = 3e8
v = c/refractive_indx
RTT_per_L = 2 * 1. / v * 1e12 # psec
print('correlation of offset with L: {} ps / m'.format(DresDr * RTT_per_L))

# RTT per meter of fiber = 459.41772365 ns/m

# assign k_means to a fiber, from shortest to longest
fibers = np.searchsorted(np.sort(centers), centers)
fibers_lab = ['Fiber {}\nRTT {:.2f} ns'.format(j, k)
              for j, k
              in zip(fibers, centers)]

"""plotting"""
fig0 = plt.figure('Clock offset vs time')
result.plot(fig=fig0)

######
plt.figure('Residuals vs RTT with linear fit')
plt.plot(rtrip_s, res, 'o', alpha=.1)
dely = result_res.eval_uncertainty(sigma=3)
plt.fill_between(rtrip_s, result_res.best_fit - dely,
                 result_res.best_fit + dely, color="#ABABAB")
plt.plot(rtrip, result_res.best_fit)

plt.figure('Distribution of residuals for each fiber')
# sns.violinplot(x=[centers[j] for j in kmeans.labels_], y=res, bw=.15)
sns.violinplot(x=[fibers_lab[j] for j in kmeans.labels_],
               y=res,
               bw=.15,
               order=np.sort(fibers_lab))
plt.ylabel('residual (ns)')


fit_x_ext = time.ptp() / 5
t_fit_plot = np.linspace(time.min() - fit_x_ext, time.max() + fit_x_ext, 200)
np.savetxt('residual_offset_parabolic-fit_line.dat',
           np.c_[t_fit_plot,
                 result.eval(x=t_fit_plot)],
           header='Time\tdelta_fit', delimiter='\t',
           footer=result.fit_report())

plt.show()
