import numpy as np
import matplotlib.pyplot as plt


# ========================================================
# DONNEE INITALE (Pb de Riemann) et SOLUTION EXACTE discrète
# ========================================================

# Pour t=0 on trouve la donnée initiale de type Riemann
def exacte(t,xx):
    sigma = (q(u_R)-q(u_L))/(u_R-u_L) # vitesse d'un choc
    det = lambda t,x : u_L if (x-x_0)<=dq(u_L)*t else ( u_R if (x-x_0)>=dq(u_R)*t else inv_dq((x-x_0)/t) )
    if u_L >= u_R : # choc 
        return np.array([u_L if (x-x_0)<=sigma*t else u_R for x in xx])
    else : # detente
        return np.array([det(t,x) for x in xx])


# ========================================================
# BC : Neumann homogènes
# Shift des vecteurs avec conditions de Neumann
# ========================================================
v_plus  = lambda vv : np.concatenate( (vv[1:], vv[-1:]) ) 
v_moins = lambda vv : np.concatenate( (vv[:1], vv[:-1]) )



# ========================================================
# SCHEMAS : définition des flux numériques
# ========================================================
Lax_Friedrichs = lambda wL,wR : ( q(wL)+q(wR) - (wR-wL)*dx/dt )/2

Lax_Wendroff = lambda wL,wR : ( q(wL)+q(wR) - dt/dx * dq((wL+wR)/2) *(q(wR)-q(wL)) )/2

Upwind_Conservatif_1 = lambda wL,wR :  q(wL)*(dq((wL+wR)/2)>=0) + q(wR)*(dq((wL+wR)/2)<0)

Upwind_Conservatif_2 = lambda wL,wR :  q(wL)*((q(wR)-q(wL))*(wR-wL)>=0) + q(wR)*((q(wR)-q(wL))*(wR-wL)<0)

def Godunov (wL,wR) : 
    flux = np.ones_like(wL)
    mask_L_1 = (wL>=0) & (wR>=0)
    flux[mask_L_1]=q(wL[mask_L_1])
    mask_L_2 = (wL>=0) & (0>=wR) & (wR+wL>0)
    flux[mask_L_2]=q(wL[mask_L_2])
    mask_R_1 = (wL<=0) & (wR<=0)
    flux[mask_R_1]=q(wR[mask_R_1])
    mask_R_2 = (wL>=0) & (0>=wR) & (wR+wL<0)
    flux[mask_R_2]=q(wR[mask_R_2])
    mask_0 = (wL<0) & (0<wR)
    flux[mask_0]=q(0)
    return flux



# ========================================================
# FONCTION POUR L'AFFICHAGE à l'instant t
# ========================================================
def affichage(ax,t,uu):
    ax.plot(xx,exacte(0,xx),':',label='Initial')
    ax.plot(xx,exacte(t,xx),label='Exacte')
    ax.plot(xx,uu,'-o',label='Approx')
    #ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.legend()
    ax.set_xlabel('$x$')
    ax.set_ylabel('$u$')
    ax.grid()
    ax.set_xlim(xx[0],xx[-1])
    ax.set_ylim(min(u_L,u_R)*1.5,max(u_L,u_R)*1.5)
    

# ========================================================  
# MAIN
# ========================================================

# EDP
q  = lambda u : u**2/2  # flux de l'équation, ici Burgers
dq = lambda u : u       # q' 
inv_dq = lambda y : y   # (q')^(-1) car (x-x_0)/t = q'(sol) donc sol= (q')^(-1)((x-x_0)/t)


# bornes du domaine [x_L,x_R] et point de discontinuité x_0
x_L, x_0, x_R = -2, 0, 2

Nx = 50 # nb de mailles
T  = 1.5  # temps final

cfl = 0.99
# cfl = 0.4

# Données Pbs de Riemann
IC = { 'choc': (2,-1), 
       'détente': (-0.5,1) }

Liste_schemas = ['Godunov','Lax_Friedrichs','Lax_Wendroff','Upwind_Conservatif_1','Upwind_Conservatif_2']

L  = x_R-x_L   # longueur du domaine
xx = np.linspace(x_L,x_R,Nx) # centres des cellules du maillage
dx = xx[1]-xx[0] # pas d'espace

for i,key in enumerate(IC.keys()):
    u_L, u_R = IC[key]

    # initialisation
    uu = exacte(0,xx) # donnee initiale
    sol = { schema:np.copy(uu) for schema in Liste_schemas }
    dt = cfl*dx/np.amax(abs(uu)) 
    fig,axes = plt.subplots(nrows=2,ncols=3,figsize=(20,10))
    for schema,ax in zip(Liste_schemas,axes.reshape(-1)):
        ax.set_title(schema)
        affichage(ax,0,sol[schema])
    print("nt=0, t=0")
    fig.suptitle(f'dx={dx:g},  t=0')
    plt.pause(1)

    for schema,ax in zip(Liste_schemas,axes.reshape(-1)):
        print(f"\n{schema=}")
        t  = 0
        nt = 0 # nombre de pas de temps effectues

        # marche en temps
        while t+dt<=T:
        
            dt = cfl*dx/np.amax(abs(sol[schema])) # il change en fct du schema !!!
            t  += dt
            nt += 1 # nb de pas en temps, pour affichage

            uuc = np.copy(sol[schema])
            uup = v_plus(sol[schema])
            uum = v_moins(sol[schema])
            g = lambda wL, wR :  eval(schema)(wL, wR)
            fluxp = g(uuc,uup)
            fluxm = np.roll(fluxp,1)
            fluxm[0] = g(uum[0],uuc[0])
            # sol[schema] = uuc - (dt/dx) * (  g(uuc,uup) - g(uum,uuc)  )
            sol[schema] = uuc - (dt/dx) * (  fluxp - fluxm  )
            # affichage de la solution approchee et exacte
            if nt%5==0 or abs(T-t)<=dt:
                print(f"{nt=}, {t=}")
                fig.suptitle(f'dx={dx:g}')
                plt.pause(0.0001)
                ax.cla()
                ax.set_title(f"{schema}, dt={dt:g}, t={t:g}")
                affichage(ax,t,sol[schema])


plt.show()