import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (30,10)
plt.rcParams.update({'font.size': 12})

# ========================================================
# Solution de l'EDP
# $\partial_t u + \partial_x q(u) = 0$
# avec condition initiale de Riemann (discontinuité en x_0)
# sur le domaine $[x_L,x_R]$
# ========================================================

# ========================================================
## Paramètres
# =======================================================

q  = lambda u : u**2/2  # flux de l'équation, ici Burgers
dq = lambda u : u       # q' 
inv_dq = lambda y : y   # (q')^(-1) car (x-x_0)/t = q'(sol) donc sol= (q')^(-1)((x-x_0)/t)


# bornes du domaine [x_L,x_R] et point de discontinuité x_0
x_L, x_0, x_R = -2, 0, 2
L  = x_R-x_L   # longueur du domaine

# nb de mailles de discrétisation spatiale
Nx = 50 

# centres des cellules du maillage
xx = np.linspace(x_L,x_R,Nx) 
# pas d'espace
dx = xx[1]-xx[0] 


# temps final
T  = 1.5  


# paramètre pour la stabilité : condition CFL
# le pas de temps change au cours du temps
cfl = 0.99
# cfl = 0.4



# ========================================================
# DONNEE INITALE (Pb de Riemann) et SOLUTION EXACTE discrète
# ========================================================

# Pour t=0 on trouve la donnée initiale de type Riemann
def exacte(t,xx,u_L,u_R,x_0):
    uu = np.zeros_like(xx)
    # TO DO
    return uu





# ========================================================
# SCHEMAS : définition des flux numériques
# ========================================================
Lax_Friedrichs = lambda wL,wR : ( q(wL)+q(wR) - (wR-wL)*dx/dt )/2





# ========================================================
# BC : Neumann homogènes
# Shift des vecteurs avec conditions de Neumann
# ========================================================
v_plus  = lambda vv : np.concatenate( (vv[1:], vv[-1:]) ) 
v_moins = lambda vv : np.concatenate( (vv[:1], vv[:-1]) )


# ========================================================
# ========================================================
## Main
# ========================================================
# ========================================================

# Liste de nom de fonctions
# Chaque fonction est un schema 
Liste_schemas = [ 
    'Godunov',
    'Lax_Friedrichs',
    'Lax_Wendroff',
    'Upwind_Conservatif_1',
    'Upwind_Conservatif_2'
    ]


# ========================================================
# initialisation 
# ========================================================

# Données Pbs de Riemann
# uL, uR = 2, 1
uL, uR = -0.5, 1

# choix du schéma
schema = Lax_Friedrichs


schema_name = schema.__name__
g = lambda wL, wR :  schema(wL, wR) # flux numérique

# solution exacte à t=0 : donnee initiale
uu_init = np.ones_like(xx)*uL
mask = (xx>x_0)
uu_init[mask] = uR

# La solution  est une liste de vecteurs numpy, chaque vecteur est la solution à un instant donné
sol_exacte = [ np.copy(uu_init) ] 
sol_approx = [ np.copy(uu_init) ] 


# temps
t  = 0
temps = [ t ] # pour stocker les instants t^n pour les affichages

dt = cfl*dx/np.amax(abs(sol_approx[-1])) 


# ========================================================
# marche en temps et sauvagarde des solutions
# ========================================================
while t<T+dt:

    dt = cfl*dx/np.amax(abs(sol_approx[-1])) # il change en fct du schema !!!
    t  += dt
    temps.append(t)

    uuc = np.copy(sol_approx[-1])
    uup = v_plus(sol_approx[-1])
    uum = v_moins(sol_approx[-1])

    fluxp = g(uuc,uup)
    fluxm = v_moins(fluxp)

    uu_exacte = exacte(t,xx,uL,uR,x_0)
    uu_new = uuc - (dt/dx) * (  fluxp - fluxm  )
    sol_exacte.append(uu_exacte)
    sol_approx.append( uu_new )

print("Fin marche en temps")




# ========================================================
# FONCTION POUR L'AFFICHAGE à l'instant t
# ========================================================

def affichage(ax, t, uu_exacte, uu_approx):
    ax.plot(xx, uu_exacte, label='Exacte')
    ax.plot(xx, uu_approx ,'-o', label='Approx')
    ax.legend()
    ax.set_xlabel('$x$')
    ax.set_ylabel('$u$')
    ax.grid()
    ax.set_xlim(xx[0],xx[-1])
    ax.set_ylim(min(uL,uR)*1.5,max(uL,uR)*1.5)


# Initialisation 
fig,ax = plt.subplots(nrows=1,ncols=1)
fig.suptitle(f'{dt = :g}, {dx = :g}, {uL:g}, {uR:g}, t = {temps[0]:g}')
ax.set_title(schema_name)
ax.plot(xx, uu_init, ':', label='t = 0')
affichage(ax, temps[0], sol_exacte[0], sol_approx[0])

plt.draw()
fig.tight_layout()
plt.pause(1)  


# Affichage de la solution à chaque instant
for nt in range(1,len(temps)):
    t = temps[nt]
    fig.suptitle(f'{dt = :g}, {dx = :g}, {uL:g}, {uR:g}, t = {temps[0]:g}')
    plt.pause(0.0001)
    ax.cla()
    ax.set_title(schema_name)
    ax.plot(xx, uu_init, ':', label='t = 0')
    affichage(ax, t, sol_exacte[nt], sol_approx[nt])
    fig.tight_layout()
   


plt.show()
