import numpy as np
from math import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# import seaborn

plt.rcParams.update({'font.size': 23})

def plot(rho, figname):
    """Save a 3D chart of |Re(rho)| drawn as rows of flat bars.

    Each row of ``np.abs(np.real(rho))`` is drawn with ``Axes3D.bar`` as one
    set of four semi-transparent blue bars, placed in the plane whose y
    coordinate equals the row index. Note that the imaginary part of ``rho``
    is discarded before the absolute value is taken.

    The z axis is fixed to [-0.7, 0.7] and the current figure is written to
    ``figname + '.pdf'``. The figure is not closed afterwards.

    rho     -- array-like whose rows must broadcast against four bar positions
    figname -- output path without the '.pdf' suffix
    """
    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')
    axes.set_zlim3d(-0.7, 0.7)

    # The same four bar positions are reused for every row.
    bar_positions = np.arange(4)
    magnitudes = np.abs(np.real(rho))
    for row_index, row_heights in enumerate(magnitudes):
        axes.bar(
            bar_positions,
            row_heights,
            zs=row_index,
            zdir='y',
            color='b',
            alpha=0.6,
        )

    output_path = figname + '.pdf'
    plt.savefig(output_path)

def plot_real_imaginary(rho, figname, color='red', width=0.7, tilt=25, rotate=30):
    """Save side-by-side 3D bar charts of the real and imaginary parts of rho.

    The left panel shows ``np.real(rho)`` and the right panel shows
    ``np.imag(rho)``. Both panels are styled identically: z axis limited to
    [-0.7, 0.7] with ticks at -0.5, 0 and 0.5, grid switched off, and the
    x and y axes labelled HH, HV, VH, VV. The figure is written to
    ``figname + '_real_imaginary.pdf'`` and is not closed afterwards.

    rho     -- array-like; every row is drawn at four y positions, so each
               row must broadcast to length four
    figname -- output path without the '_real_imaginary.pdf' suffix
    color   -- bar colour forwarded to ``bar3d``
    width   -- footprint of each bar along both x and y
    tilt    -- first argument of ``view_init`` (elevation)
    rotate  -- second argument of ``view_init`` (azimuth)
    """
    labels = ['HH', 'HV', 'VH', 'VV']
    n_bars = len(labels)
    ticks = np.arange(n_bars)

    def _draw_panel(ax, values):
        # Draw one matrix component as 3D bars and apply the shared styling.
        ax.set_zlim3d(-0.7, 0.7)
        # Axis styling does not depend on the data, so it is applied once
        # per panel instead of once per row.
        ax.grid(False)
        ax.set_zticks([-0.5, 0, 0.5])

        footprint = np.ones(n_bars) * width
        for row_index, heights in enumerate(values):
            ax.bar3d(
                x=np.ones(n_bars) * row_index,
                y=np.arange(n_bars),
                z=np.zeros(n_bars),
                dx=footprint,
                dy=footprint,
                dz=heights,
                color=color,
            )

        ax.set_xticks(ticks)
        ax.set_xticklabels(labels, minor=False)
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels, minor=False)
        ax.view_init(tilt, rotate)
        # Trailing spaces are part of the labels and shift the rendered text.
        ax.set_zticklabels(['-0.5  ', '0 ', '0.5 '])

    fig = plt.figure(figsize=(12, 5))  # wide canvas to fit two panels
    _draw_panel(fig.add_subplot(121, projection='3d'), np.real(rho))
    _draw_panel(fig.add_subplot(122, projection='3d'), np.imag(rho))

    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(figname + '_real_imaginary.pdf')

def plot_real(rho, figname, color='red', width=0.7, tilt=25, rotate=30, title=''):
    """Save a titled 3D bar chart of the real part of rho.

    Every row of ``np.real(rho)`` is drawn with ``bar3d`` as a line of bars at
    x equal to the row index. The z axis is limited to [-0.7, 0.7] with ticks
    at -0.5, 0 and 0.5, the grid is switched off, and the x and y axes are
    labelled HH, HV, VH, VV. The figure is written to
    ``figname + '_real.pdf'`` and is not closed afterwards.

    rho     -- array-like; every row is drawn at four y positions, so each
               row must broadcast to length four
    figname -- output path without the '_real.pdf' suffix
    color   -- bar colour forwarded to ``bar3d``
    width   -- footprint of each bar along both x and y
    tilt    -- first argument of ``view_init`` (elevation)
    rotate  -- second argument of ``view_init`` (azimuth)
    title   -- axes title, rendered at font size 30
    """
    labels = ['HH', 'HV', 'VH', 'VV']
    n_bars = len(labels)
    ticks = np.arange(n_bars)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.set_zlim3d(-0.7, 0.7)
    # Axis styling does not depend on the data, so it is applied once here
    # instead of being repeated for every row.
    ax.grid(False)
    ax.set_zticks([-0.5, 0, 0.5])

    footprint = np.ones(n_bars) * width
    for row_index, heights in enumerate(np.real(rho)):
        ax.bar3d(
            x=np.ones(n_bars) * row_index,
            y=np.arange(n_bars),
            z=np.zeros(n_bars),
            dx=footprint,
            dy=footprint,
            dz=heights,
            color=color,
        )

    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, minor=False)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, minor=False)
    ax.view_init(tilt, rotate)
    # Trailing spaces are part of the labels and shift the rendered text.
    ax.set_zticklabels(['-0.5  ', '0 ', '0.5 '])
    ax.set_title(title, fontsize=30)

    # Save this specific figure rather than whatever pyplot considers current.
    fig.savefig(figname + '_real.pdf')

if __name__ == '__main__':
    # loadtxt yields real-valued arrays; viewing them as complex128
    # reinterprets every pair of adjacent floats in a row as one complex
    # number (real part first), halving the number of columns.
    rho_before = np.loadtxt('rho_before.dat').view(np.complex128)
    rho_after = np.loadtxt('rho_after.dat').view(np.complex128)

    # Optional export of the separate components (currently disabled):
    # np.savetxt('rho_before_real.dat', np.real(rho_before))
    # np.savetxt('rho_before_imaginary.dat', np.imag(rho_before))
    # np.savetxt('rho_after_real.dat', np.real(rho_after))
    # np.savetxt('rho_after_imaginary.dat', np.imag(rho_after))

    # One real-part chart per measurement, each with its own title.
    plot_real(rho=rho_before, figname='rho_before', title='No Circulator')
    plot_real(rho=rho_after, figname='rho_after', title='Circulators in Channel')

    # Difference plot (currently disabled):
    # plot_real_imaginary(rho_after-rho_before, 'rho_diff')
