
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (30,10)
plt.rcParams.update({'font.size': 12})

# ========================================================
# Solution de l'EDP
# $\partial_t u + c \partial_x u = 0$
# avec condition initiale $u(0,x) = g(x)$
# sur le domaine $[0,L]$
# ========================================================

# ========================================================
## Paramètres
# ========================================================

# vitesse du transport
# TO DO tester avec des vitesses négatives
c = 1

# Le domaine spatial est [0,L]
L  = 8   

# nb de points de discrétisation spatiale
Nx = 51  
# maillage
xx = np.linspace(0,L,Nx) 
# pas d'espace
dx = xx[1]-xx[0] 

# Avec des conditions aux limites périodiques, on faira un nombre entier de tours.
# Le temps final est calculé pour faire nb_tours et retrouver la donnée initiale
nb_tours = 1
T  = nb_tours*L/abs(c)   


# paramètre pour la stabilité
# TO DO, tester avec des valeurs de cfl de 0.1 à 1.1 
# cfl = 0.1 # tres dissipatif
# cfl = 0.5
# cfl = 0.999
cfl = 1
# cfl = 1.1 # solution instable

# pas de temps : condition CFL
# pour l'équation de transport, le pas de temps est fixé une fois pour toute 
# ce ne sera pas le cas pour les autres équations où le flux dépend de la solution
dt = cfl*dx/abs(c) 






# ========================================================
# DONNEE INITALE & SOLUTION EXACTE
# ========================================================

# Si u(0,x) = g(x) alors u(t,x) = g(x-ct)
# Avec des conditions aux limites périodiques, on utilise
# np.roll() pour décaler les valeurs de uu de c*t/dx

def g(xx) :
    uu = np.zeros_like(xx)
    mask = xx<L/2
    uu[mask] = ( 1+np.sin(4*np.pi*xx[mask]/L-np.pi/2) )/2
    mask = (2*L/3<xx) & (xx<5*L/6)
    uu[mask] = 1
    return uu
    
exacte = lambda t,xx : np.roll(g(xx),int(c*t/dx))


# ========================================================
# LISTE DES SCHEMAS
# ========================================================

# Pour chaque schema, 
# en entrée uu vecteur sol à l'instant t^n, 
# en sortie uuNew vecteur sol à l'instant t^{n+1}

def Gauche(uu):
    uuc = np.copy(uu)
    uum = np.roll(uu,1) # uum[j] <- uu[j-1] avec cond. bord periodiques
    uuNew = uuc - c*dt/dx*(uuc-uum)
    return uuNew

def Droite(uu):
    return uu

def Upwind(uu):
    return uu

def Centre(uu):
    return uu

def Lax_Friedrichs(uu):
    return uu

def Lax_Wendroff(uu):
    return uu



# ========================================================
# ========================================================
## Main
# ========================================================
# ========================================================

# Liste de nom de fonctions
# Chaque fonction est un schema 
Liste_schemas = [ 
                Gauche, 
                Droite, 
                Upwind,
                Centre, 
                Lax_Friedrichs,
                Lax_Wendroff
                ]

# choisir la grille d'affichage des solutions
nrows, ncols = 2, 3


# ========================================================
# initialisation 
# ========================================================

# temps
t  = 0
temps = [ t ] # pour stocker les instants t^n pour les affichages

# nombre de pas de temps effectués
nt = 0 

# solution exacte à t=0 : donnee initiale
uu_init = exacte(0,xx)

# La solution exacte est une liste de vecteurs numpy, chaque vecteur est la solution
# exacte à un instant donné
sol_exacte = [ np.copy(uu_init) ] 

# dictionnaire des solutions : pour chaque schema, on sauvegarde une liste, chaque
# élément de la liste est un vecteur numpy solution à un instant donné
sol_approx = { schema.__name__ : [ np.copy(uu_init) ] for schema in Liste_schemas } 



# ========================================================
# marche en temps et sauvagarde des solutions
# ========================================================
while t<T+dt:
    t  += dt
    nt += 1 
    temps.append(t)
    uu_exacte = exacte(t,xx)
    sol_exacte.append(uu_exacte)
    for schema in Liste_schemas:
        schema_name = schema.__name__
        uu_old = sol_approx[schema_name][-1]
        uu_new = schema(uu_old)
        sol_approx[schema_name].append(uu_new)

print("Fin marche en temps")




# ========================================================
# Affichages
# ========================================================

# ax : repère pour l'affichage (dans une grille de nrows x ncols, avec numéro de 0 à nrows*ncols-1)
# t  : instant
# uu_* : solution exacte/approchée à l'instant t 

def affichage(ax, t, uu_exacte, uu_approx):
    ax.plot(xx, uu_exacte, label='Exacte')
    ax.plot(xx, uu_approx ,'-o', label='Approx')
    ax.legend()
    ax.set_xlabel('$x$')
    ax.set_ylabel('$u$')
    ax.grid()


# Initialisation 
fig,axes = plt.subplots(nrows=nrows,ncols=ncols)
fig.suptitle(f'{dt = :g}, {dx = :g}, alpha = {c*dt/dx:g}, t = {temps[0]:g}')
for schema, ax in zip(Liste_schemas,axes.reshape(-1)):
    schema_name = schema.__name__
    ax.set_title(schema_name)
    ax.plot(xx, uu_init, ':', label='t = 0')
    affichage(ax, temps[0], sol_exacte[0], sol_approx[schema_name][0])
plt.draw()
fig.tight_layout()
plt.pause(1)  


# Affichage de la solution à chaque instant
for nt in range(1,len(temps)):
    t = temps[nt]
    fig.suptitle(f'{dt = :g}, {dx = :g}, alpha = {c*dt/dx:g}, t = {temps[0]:g}')
    plt.pause(0.001)
    for schema, ax in zip(Liste_schemas, axes.reshape(-1)):
        schema_name = schema.__name__
        ax.cla()
        ax.set_title(schema_name)
        ax.plot(xx, uu_init, ':', label='t = 0')
        affichage(ax, t, sol_exacte[nt], sol_approx[schema_name][nt])
    fig.tight_layout()

plt.show()