# # Solution Exacte Saint-Venant

import numpy as np

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (15,8)


# ========================================================
# Shift des vecteurs avec conditions de Neumann
# ========================================================
v_plus  = lambda vv : np.concatenate( (vv[1:], vv[-1:]) ) 
v_moins = lambda vv : np.concatenate( (vv[:1], vv[:-1]) )


# ========================================================
# SV
# ========================================================

# WW = [h,hv]
# QQ = [hv,hv^2+gh^2/2]
Q0 = lambda WW : WW[1] 
Q1 = lambda WW : WW[1]**2/WW[0]+g*WW[0]**2/2  

lambda_0 = lambda h,v : v-np.sqrt(h*g)
lambda_1 = lambda h,v : v+np.sqrt(h*g)

# ===================================================================================
# ## APPROX
# ===================================================================================

# Flux de Lax-Friedrichs en variables conservatives  avec 2 équations
# g0 = lambda w0L, w1L, w0R, w1R :  ( Q0([w0L, w1L])+Q0([w0R, w1R]) - (w0R-w0L)*dx/dt )/2 
# g1 = lambda w0L, w1L, w0R, w1R :  ( Q1([w0L, w1L])+Q1([w0R, w1R]) - (w1R-w1L)*dx/dt )/2 

# Flux de Rusanov en variables conservatives avec 2 équations
def g0(w0L, w1L, w0R, w1R):
    l0L = np.max(abs(lambda_0(w0L,w1L/w0L)))
    l0R = np.max(abs(lambda_0(w0R,w1R/w0R)))
    l1L = np.max(abs(lambda_1(w0L,w1L/w0L)))
    l1R = np.max(abs(lambda_1(w0R,w1R/w0R)))
    A = max( [l0L,l0R,l1L,l1R] )
    return ( Q0([w0L, w1L])+Q0([w0R, w1R]) - (w0R-w0L)*A )/2 

def g1(w0L, w1L, w0R, w1R):
    l0L = np.max(abs(lambda_0(w0L,w1L/w0L)))
    l0R = np.max(abs(lambda_0(w0R,w1R/w0R)))
    l1L = np.max(abs(lambda_1(w0L,w1L/w0L)))
    l1R = np.max(abs(lambda_1(w0R,w1R/w0R)))
    A = max( [l0L,l0R,l1L,l1R] )
    return ( Q1([w0L, w1L])+Q1([w0R, w1R]) - (w1R-w1L)*A )/2 

# Schema VF avec 2 équations
def sol_approx(hh,vv,dx,dt):

    WW0, WW1 = hh, hh*vv

    WW0c, WW1c = np.copy(WW0), np.copy(WW1)
    WW0p, WW1p = v_plus(WW0), v_plus(WW1)
    WW0m, WW1m = v_moins(WW0), v_moins(WW1)

    # première équation
    flux0p = g0(WW0c,WW1c,WW0p,WW1p)
    flux0m = np.roll(flux0p,1)
    flux0m[0] = g0(WW0m[0],WW1m[0],WW0c[0],WW1c[0])
    WW0_new = WW0c - (dt/dx) * (  flux0p - flux0m  )
    
    # deuxième équation
    flux1p = g1(WW0c,WW1c,WW0p,WW1p)
    flux1m = np.roll(flux1p,1)
    flux1m[0] = g1(WW0m[0],WW1m[0],WW0c[0],WW1c[0])
    WW1_new = WW1c - (dt/dx) * (  flux1p - flux1m  )

    return WW0_new , WW1_new/WW0_new

# ===================================================================================
# ## MAIN
# ===================================================================================

# bornes du domaine [x_L,x_R] et point de discontinuité x_0
x_L, x_0, x_R = -10, 0, 10
Nx = 400 # nb de mailles
g  = 9.81 # gravité
cfl = 0.5

xx = np.linspace(x_L,x_R,Nx) 
dx = xx[1]-xx[0]

# Données Pbs de Riemann
IC = { '1-choc, 2-choc': [(1,1),(1,-1)], 
       '1-choc, 2-détente': [(1,1),(2,1)], 
       '1-détente, 2-choc': [(2,1),(1,1)], 
       '1-détente, 2-détente': [(1,-1),(1,1)], 
    #    'zone sèche, 2-détente': [(0,0),(1,1)], 
    #    '1-détente, zone sèche': [(1,1),(0,0)], 
    #    '1-détente, zone sèche, 2-détente': [(0.1,-3),(0.1,3)] 
       }


for i,key in enumerate(IC.keys()):
    
    val = IC[key]
    h_L, v_L = val[0]
    h_R, v_R = val[1]

    # Initialisation
    t = 0
    hh = np.zeros_like(xx)
    vv = np.zeros_like(xx)
    mask_L = xx<0
    mask_R = xx>=0
    hh[mask_L] , vv[mask_L] = h_L , v_L
    hh[mask_R] , vv[mask_R] = h_R , v_R
    
    fig,(ax0,ax1) = plt.subplots(nrows=1,ncols=2)

    # Marche en temps
    while t<1.0:
        
        l0 = np.max(abs(lambda_0(hh,vv)))
        l1 = np.max(abs(lambda_1(hh,vv)))
        dt = 0.999*dx/max( l0, l1 )
        t += dt
        
        hh,vv = sol_approx(hh,vv,dx,dt)
        
        ax0.cla()
        ax0.set_title(f"${h_L=:g}$, ${h_R=:g}$")
        ax0.plot(xx,hh,'-o')
        ax0.grid()
        
        ax1.cla()
        ax1.set_title(f"${v_L=:g}$, ${v_R=:g}$")
        ax1.plot(xx,vv,'-o')
        ax1.grid()
        
        plt.suptitle(f"Test : {key}, {t=:g}, {dt=:g}")
        plt.pause(0.0001)
