import numpy as np
import matplotlib.pyplot as plt

# ========================================================
# PARAMETRES
# ========================================================

# EDP
q  = lambda u : u**2/2  # flux de l'équation
dq = lambda u : u       # q' 
inv_dq = lambda y : y   # (q')^(-1) car (x-x_0)/t = q'(sol) donc sol= (q')^(-1)((x-x_0)/t)

# données Pb de Riemann
# u_L, u_R = 0, 1 # detente
u_L, u_R = 1, 0 # choc

# bornes du domaine [x_L,x_R] et point de discontinuité x_0
x_L, x_0, x_R = -2, 0, 2

Nx = 50 # nb de mailles
T  = 1  # temps final

cfl = 0.999


# ========================================================
# DONNEE INITALE (Pb de Riemann) et SOLUTION EXACTE discrète
# ========================================================

# Pour t=0 on retrouve la donnée initiale de type Riemann
def exacte(t,xx):
    sigma = (q(u_R)-q(u_L))/(u_R-u_L) # vitesse d'un choc
    det = lambda t,x : u_L if (x-x_0)<=dq(u_L)*t else ( u_R if (x-x_0)>=dq(u_R)*t else inv_dq((x-x_0)/t) )
    if u_L >= u_R : # choc 
        return np.array([u_L if (x-x_0)<=sigma*t else u_R for x in xx])
    else : # detente
        return np.array([det(t,x) for x in xx])


# ========================================================
# BC : Neumann homogènes
# Shift des vecteurs avec conditions de Neumann
# ========================================================
v_plus  = lambda vv : np.concatenate( (vv[1:], vv[-1:]) ) 
v_moins = lambda vv : np.concatenate( (vv[:1], vv[:-1]) )

# ========================================================
# SCHEMA
# ========================================================

def Upwind_NON_Conservatif(uu,dt,choix):
    uup = v_plus(uu)
    uuc = np.copy(uu)
    uum = v_moins(uu)
    if choix=='j-1':
        c =  dq(uum)
    elif choix=='j-1/2':
        c = dq((uum+uuc)/2)
    elif choix=='j':
        c = dq(uuc)
    elif choix=='j+1/2':
        c = dq((uup+uuc)/2)
    elif choix=='j+1':
        c = dq(uup)
    return uuc - (dt/dx)*( (uuc-uum)*c*( c>0 ) + (uup-uuc)*c*( c<=0 ) ) 


# ========================================================
# FONCTION POUR L'AFFICHAGE à l'instant t
# ========================================================
def affichage(ax,t,uu):
    ax.plot(xx,exacte(0,xx),':',label='Initial')
    ax.plot(xx,exacte(t,xx),label='Exacte')
    ax.plot(xx,uu,'-o',label='Approchée')
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    ax.set_xlabel('$x$')
    ax.set_ylabel('$u$')
    ax.grid()
    

# ========================================================  
# MAIN
# ========================================================

L  = x_R-x_L   # longueur du domaine
xx = np.linspace(x_L,x_R,Nx+1) # points du maillage
dx = xx[1]-xx[0] # pas d'espace

t  = 0
uu = exacte(0,xx) # donnee initiale

# choix du pas en temps dt en fixant le nombre de Courant cfl valeur pour la condition CFL
dt = cfl*dx/np.amax(abs(dq(uu))) 
nt = 0 # nombre de pas de temps effectues

# affichage de la donnee initiale
fig,axes = plt.subplots(nrows=1,ncols=1,figsize=(20,10))
ax = axes

affichage(ax,t,uu)
plt.pause(0.5)

# marche en temps
while t+dt<=T:
    
    dt = cfl*dx/np.amax(abs(dq(uu))) 
    t  += dt
    nt += 1 # nb de pas en temps, pour affichage

    schema = 'Upwind_NON_Conservatif'
    # uu = eval(schema)(uu,'j-1')
    uu = eval(schema)(uu,dt,'j-1/2')
    # uu = eval(schema)(uu,'j')
    # uu = eval(schema)(uu,'j+1/2')
    # uu = eval(schema)(uu,'j+1')

    # schema = 'Upwind_Conservatif'
    # uu = eval(schema)(uu)

    # affichage de la solution approchee et exacte
    if nt%5==0 or abs(T-t)<=dt:
        print(f"{nt=}, {t=}")
        affichage(ax,t,uu)
        fig.suptitle(f'dt={dt:g}, dx={dx:g}, t={t:g}')
        plt.pause(0.001)
        ax.cla()

affichage(ax,t,uu)
fig.suptitle(f'dt={dt:g}, dx={dx:g}, t={t:g}')
plt.show()