#!/usr/bin/env python
import numpy as np
from scipy.integrate import ode


def exp_decaying(t, gamma, t0=0):
    """Decaying-exponential pulse amplitude.

    Returns ``sqrt(gamma) * exp(-gamma * (t - t0) / 2)`` for ``t >= t0`` and
    0 for ``t < t0``; the square of this amplitude integrates to 1 over t.

    Parameters
    ----------
    t : float or numpy.ndarray
        Time(s) at which to evaluate the amplitude.
    gamma : float
        Decay rate of the squared amplitude ``exp(-gamma * t)``.
    t0 : float, optional
        Time at which the pulse switches on.

    Returns
    -------
    float or numpy.ndarray
        The amplitude, with the same shape as ``t``.
    """
    t = t - t0
    # Clamp the exponent's argument: for t < 0 the value is masked to zero
    # anyway, but evaluating exp(-gamma * t / 2) there can overflow to inf
    # and the product 0 * inf would be nan.
    return (t >= 0) * np.sqrt(gamma) * np.exp(-gamma * np.maximum(t, 0) / 2)


def exp_rising(t, gamma, t0=0):
    """Rising-exponential pulse amplitude.

    Returns ``sqrt(gamma) * exp(gamma * (t - t0) / 2)`` for ``t <= t0`` and
    0 for ``t > t0``; the square of this amplitude integrates to 1 over t.

    Parameters
    ----------
    t : float or numpy.ndarray
        Time(s) at which to evaluate the amplitude.
    gamma : float
        Growth rate of the squared amplitude ``exp(gamma * t)``.
    t0 : float, optional
        Time at which the pulse switches off.

    Returns
    -------
    float or numpy.ndarray
        The amplitude, with the same shape as ``t``.
    """
    t = t - t0
    # Clamp the exponent's argument: for t > 0 the value is masked to zero
    # anyway, but evaluating exp(gamma * t / 2) there can overflow to inf
    # and the product 0 * inf would be nan.
    return (t <= 0) * np.sqrt(gamma) * np.exp(gamma * np.minimum(t, 0) / 2)


def P_exp_decaying(t, gamma_0, gamma_p, l, t0=0):
    """Closed-form population driven by a decaying-exponential pulse.

    Equals ``|c(t)|**2`` where ``c`` solves::

        dc/dt = -gamma_0 / 2 * c + sqrt(l * gamma_0) * xi(t)

    with ``xi(t) = sqrt(gamma_p) * exp(-gamma_p * t / 2)`` for ``t >= 0``
    (see ``exp_decaying``) and ``c(0) = 0``.  The result is 0 for
    ``t < t0``.  NOTE(review): presumably the atomic excitation
    probability; ``l`` presumably the fraction of the decay coupled to
    the pulse mode — confirm.

    Parameters
    ----------
    t : float or numpy.ndarray
        Time(s) at which to evaluate.
    gamma_0 : float
        Decay rate of the driven system (``c`` decays at ``gamma_0 / 2``).
    gamma_p : float
        Decay rate of the pulse intensity.
    l : float
        Coupling factor entering as ``sqrt(l * gamma_0)``.
    t0 : float, optional
        Time at which the pulse switches on.

    Returns
    -------
    float or numpy.ndarray
        ``|c|**2`` with the same shape as ``t``.
    """
    t = t - t0
    # Only t >= 0 contributes; clamping keeps exp() from overflowing on the
    # masked-out t < 0 samples (0 * inf or inf - inf would give nan).
    tc = np.maximum(t, 0)
    if gamma_0 == gamma_p:
        # Limit gamma_p -> gamma_0 of the general expression below.
        return (t >= 0) * l * (gamma_0 * tc)**2 * np.exp(-gamma_0 * tc)
    return (t >= 0) * (4 * l * gamma_0 * gamma_p) / \
        (gamma_p - gamma_0)**2 * (np.exp(-gamma_0 * tc / 2) -
                                  np.exp(-gamma_p * tc / 2))**2


def P_exp_rising(t, gamma_0, gamma_p, l, t0=0):
    """Closed-form population driven by a rising-exponential pulse.

    Equals ``|c(t)|**2`` where ``c`` solves::

        dc/dt = -gamma_0 / 2 * c + sqrt(l * gamma_0) * xi(t)

    with ``xi(t) = sqrt(gamma_p) * exp(gamma_p * t / 2)`` for ``t <= 0``
    (see ``exp_rising``) and ``c -> 0`` as ``t -> -inf``.  It grows as
    ``exp(gamma_p * t)`` up to ``t = t0``, peaks there, and then decays as
    ``exp(-gamma_0 * t)``.  NOTE(review): presumably the atomic excitation
    probability — confirm.

    Parameters
    ----------
    t : float or numpy.ndarray
        Time(s) at which to evaluate.
    gamma_0 : float
        Decay rate of the driven system (``c`` decays at ``gamma_0 / 2``).
    gamma_p : float
        Growth rate of the pulse intensity.
    l : float
        Coupling factor entering as ``sqrt(l * gamma_0)``.
    t0 : float, optional
        Time at which the pulse switches off.

    Returns
    -------
    float or numpy.ndarray
        ``|c|**2`` with the same shape as ``t``.
    """
    t = t - t0
    peak = (4 * l * gamma_0 * gamma_p) / (gamma_p + gamma_0)**2
    # Each exponent is evaluated only on its own side of t = 0 (clamped
    # elsewhere) so the masked-out branch cannot overflow to inf and turn
    # 0 * inf into nan.
    rising = (t <= 0) * np.exp(gamma_p * np.minimum(t, 0))
    falling = (t > 0) * np.exp(-gamma_0 * np.maximum(t, 0))
    return peak * (rising + falling)


def r_f(t, gamma_0, gamma_p, l, xi, p, t0=0):
    """Squared difference between the pulse amplitude and the re-emitted field.

    Computes ``(xi(t) - sqrt(l * gamma_0 * p(t)))**2`` on the time axis
    shifted by ``t0``.  ``xi`` is called as ``xi(t, gamma_p)`` (e.g.
    ``exp_rising`` / ``exp_decaying``) and ``p`` as
    ``p(t, gamma_0, gamma_p, l)`` (e.g. ``P_exp_rising`` /
    ``P_exp_decaying``).  NOTE(review): presumably the transmitted
    intensity after interference — confirm.
    """
    shifted = t - t0
    incident = xi(shifted, gamma_p)
    emitted = np.sqrt(l * gamma_0 * p(shifted, gamma_0, gamma_p, l))
    return (incident - emitted)**2


def delta(t, gamma_0, gamma_p, l, xi, p, t0=0):
    """Drop in intensity caused by interference with the re-emitted field.

    Computes ``xi(t)**2 - (xi(t) - sqrt(l * gamma_0 * p(t)))**2`` on the
    time axis shifted by ``t0``, where ``xi`` is called as
    ``xi(t, gamma_p)`` and ``p`` as ``p(t, gamma_0, gamma_p, l)``.
    """
    shifted = t - t0
    field = xi(shifted, gamma_p)
    scattered = np.sqrt(l * gamma_0 * p(shifted, gamma_0, gamma_p, l))
    return field**2 - (field - scattered)**2


def my_int(f, time_v, dt, gamma_0, l):
    """
    Integrate the atom interaction differential equation from the
    extinction data.

    Solves ``dy/dt = f(t) - gamma_0 * (1 - l) * y`` with ``y = 0`` at
    ``min(time_v)``, using scipy's ``dopri5`` integrator.

    Parameters
    ----------
    f : callable
        Source term ``f(t)`` returning a scalar.
    time_v : array_like
        Sample times.  Only its minimum and its length are used: the
        samples are assumed ascending and uniformly spaced by ``dt``.
    dt : float
        Spacing between consecutive samples.
    gamma_0 : float
        Decay rate; the effective damping is ``gamma_0 * (1 - l)``.
    l : float
        Factor removed from the damping rate.

    Returns
    -------
    numpy.ndarray
        ``y[k]`` is the solution at ``min(time_v) + k * dt``; ``y[0]`` is
        the initial value 0.

    Raises
    ------
    RuntimeError
        If the integrator reports a failure.
    """
    def atom(t, y):
        return f(t) - gamma_0 * (1 - l) * y

    n = len(time_v)
    y = np.zeros(n)
    if n == 0:
        return y
    t0 = np.min(time_v)
    sol = ode(atom).set_integrator('dopri5')
    sol.set_initial_value(0, t0)
    # y[0] is the initial condition at t0.  Each later sample is integrated
    # forward from the previous one.  Targets are computed as t0 + k*dt
    # rather than accumulated, so rounding error does not build up.
    for k in range(1, n):
        t_k = t0 + k * dt
        y[k] = sol.integrate(t_k)[0]
        if not sol.successful():
            raise RuntimeError(f"ODE integration failed at t = {t_k}")
    return y


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Rates of the driven system (gamma_0) and of the pulse (gamma_p), and
    # the coupling factor l.  NOTE(review): 26.2 and 13.5 look like
    # lifetimes in the time unit of time_v — confirm the unit.
    gamma_0 = 1 / 26.2
    gamma_p = 1 / 13.5
    l = 0.03
    # Uniform time grid; my_int assumes its spacing equals dt.
    dt = 2
    time_v = np.arange(-400, 400, dt)

    plt.figure()
    plt.title('Difference of transmission')
    plt.plot(time_v,
             delta(time_v, gamma_0, gamma_p, l, exp_decaying, P_exp_decaying),
             label='decaying')
    plt.plot(time_v,
             delta(time_v, gamma_0, gamma_p, l, exp_rising, P_exp_rising),
             label='rising')
    plt.xlabel('t')
    plt.legend()

    # Integrate the transmission difference for each pulse shape and compare
    # the result with the closed-form populations.
    y_rise = my_int(
        lambda t: delta(t, gamma_0, gamma_p, l, exp_rising, P_exp_rising),
        time_v, dt, gamma_0, l)
    y_dec = my_int(
        lambda t: delta(t, gamma_0, gamma_p, l, exp_decaying, P_exp_decaying),
        time_v, dt, gamma_0, l)

    plt.figure()
    plt.title('Excitation comparison')
    plt.plot(time_v, y_rise, label='integrated (rising)')
    plt.plot(time_v, P_exp_rising(time_v, gamma_0, gamma_p, l),
             label='closed form (rising)')
    plt.plot(time_v, y_dec, label='integrated (decaying)')
    plt.plot(time_v, P_exp_decaying(time_v, gamma_0, gamma_p, l),
             label='closed form (decaying)')
    plt.xlabel('t')
    plt.legend()

    plt.show()
