# # Solution Exacte Saint-Venant

import numpy as np

import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (15,8)

from scipy.optimize import fsolve

# ===================================================================================
# ## Fonctions utiles
# ===================================================================================

# NOTE: every helper below reads the module-level gravity constant `g`,
# which must be defined before the first call.

# Shock speeds
def sigma_1(h_L,v_L,h_s,v_s):
    """Return the speed of the 1-shock between (h_L, v_L) and a state of height h_s.

    v_s is part of the signature but is not used by the formula.
    """
    return v_L-h_s*np.sqrt( g/2 * (h_L+h_s)/(h_L*h_s) )


def sigma_2(h_R,v_R,h_s,v_s):
    """Return the speed of the 2-shock between a state of height h_s and (h_R, v_R).

    v_s is part of the signature but is not used by the formula.
    """
    return v_R+h_s*np.sqrt( g/2 * (h_R+h_s)/(h_R*h_s) )


# Rarefaction (characteristic) speeds
def lambda_1(h,v):
    """Return the first characteristic speed v - sqrt(g*h)."""
    return v-np.sqrt(g*h)


def lambda_2(h,v):
    """Return the second characteristic speed v + sqrt(g*h)."""
    return v+np.sqrt(g*h)


# Solution inside the 1-rarefaction
def h_det_1(x,t,h_L,v_L):
    """Return the height at (x, t) inside the 1-rarefaction issued from (h_L, v_L).

    Divides by t, so t must be non-zero for non-empty x.
    """
    return (v_L+2*np.sqrt(g*h_L)-x/t)**2/(9*g)


def v_det_1(x,t,h_L,v_L):
    """Return the velocity at (x, t) inside the 1-rarefaction issued from (h_L, v_L).

    Divides by t, so t must be non-zero for non-empty x.
    """
    return (v_L+2*np.sqrt(g*h_L)+2*x/t)/3


# Solution inside the 2-rarefaction
def h_det_2(x,t,h_R,v_R):
    """Return the height at (x, t) inside the 2-rarefaction ending at (h_R, v_R).

    Divides by t, so t must be non-zero for non-empty x.
    """
    return (-v_R+2*np.sqrt(g*h_R)+x/t)**2/(9*g)


def v_det_2(x,t,h_R,v_R):
    """Return the velocity at (x, t) inside the 2-rarefaction ending at (h_R, v_R).

    Divides by t, so t must be non-zero for non-empty x.
    """
    return (v_R-2*np.sqrt(g*h_R)+2*x/t)/3


# Riemann invariants, used as front speeds in the dry cases
def Inv_1(h,v):
    """Return v + 2*sqrt(g*h)."""
    return v+2*np.sqrt(g*h)


def Inv_2(h,v):
    """Return v - 2*sqrt(g*h)."""
    return v-2*np.sqrt(g*h)

# Intermediate (star) state
def star(h_L,v_L,h_R,v_R):
    """Return the intermediate state (h_s, v_s) between the two waves.

    When both sides are wet (h_L > 0 and h_R > 0) and
    v_R - v_L < 2*sqrt(g)*(sqrt(h_R) + sqrt(h_L)), h_s is the root found by
    scipy's fsolve, started from (h_L + h_R)/2, of the scalar wave-curve
    equation below; v_s is then read off the curve through (h_R, v_R).
    In every other case the function returns (0, 0).

    Reads the module-level gravity constant g.
    """
    wet = (h_L>0) and (h_R>0) and ( v_R-v_L < 2*np.sqrt(g)*(np.sqrt(h_R)+np.sqrt(h_L)) )
    if not wet:
        return 0, 0

    def slope(depth,ref):
        # Two expressions depending on whether depth lies below or above the
        # reference height; they coincide when depth == ref.
        if depth <= ref:
            return 2*np.sqrt(g)/(np.sqrt(depth)+np.sqrt(ref))
        return np.sqrt( g/2 * (depth+ref)/(depth*ref) )

    def residual(depth):
        # Vanishes when the curves through the left and right states meet.
        return v_R-v_L+(depth-h_L)*slope(depth,h_L)+(depth-h_R)*slope(depth,h_R)

    h_s = fsolve(residual,(h_L+h_R)/2)[0]
    v_s = v_R+(h_s-h_R)*slope(h_s,h_R)
    return h_s, v_s

# ===================================================================================
# Solution exacte
# ===================================================================================
def _fill_fan(hh,vv,xx,t,mask,h_fan,v_fan,h_0,v_0):
    """Write a rarefaction-fan profile into hh and vv on the points selected by mask.

    h_fan and v_fan are called as f(x, t, h_0, v_0).  Nothing is evaluated when
    the mask is empty, so x/t is never computed at t == 0 (the strict fan masks
    are all empty then).
    """
    if mask.any():
        hh[mask] = h_fan(xx[mask],t,h_0,v_0)
        vv[mask] = v_fan(xx[mask],t,h_0,v_0)


def sol_exacte_riemann(xx,t,h_L,v_L,h_R,v_R,h_s,v_s):
    """Evaluate the exact Riemann solution at abscissas xx and time t.

    The initial discontinuity is located at x = 0: every wave front sits at
    x = speed * t.

    Parameters
    ----------
    xx : array_like
        Abscissas; converted to a float array so that the result is never
        truncated to an integer dtype.
    t : float
        Time at which the solution is evaluated.
    h_L, v_L : float
        Left state (height, velocity).
    h_R, v_R : float
        Right state (height, velocity).
    h_s, v_s : float
        Intermediate state, e.g. as returned by star(); h_s == 0 selects the
        dry configurations.

    Returns
    -------
    hh, vv : numpy.ndarray
        Height and velocity at each abscissa, same shape as xx.

    Notes
    -----
    Left to right the solution is: left state, 1-wave, middle state, 2-wave,
    right state.  A shock is treated as a wave of zero width.  The middle
    region includes both of its bounds so that no abscissa is left
    unassigned; the left and right states are written last and therefore win
    on a front that coincides with them (in particular at t == 0 the point
    x == 0 receives the right state).
    """
    xx = np.asarray(xx,dtype=float)
    hh = np.zeros_like(xx)
    vv = np.zeros_like(xx)

    # The 1-wave spans speeds [s1_a, s1_b], the 2-wave spans [s2_a, s2_b];
    # (h_m, v_m) is the state between them.
    h_m, v_m = h_s, v_s
    if h_s > max(h_L,h_R) > 0 :
        # 1-shock, 2-shock
        s1_a = s1_b = sigma_1(h_L,v_L,h_s,v_s)
        s2_a = s2_b = sigma_2(h_R,v_R,h_s,v_s)
    elif 0 < h_L < h_s <= h_R :
        # 1-shock, 2-rarefaction
        s1_a = s1_b = sigma_1(h_L,v_L,h_s,v_s)
        s2_a, s2_b = lambda_2(h_s,v_s), lambda_2(h_R,v_R)
    elif 0 < h_R < h_s <= h_L :
        # 1-rarefaction, 2-shock
        s1_a, s1_b = lambda_1(h_L,v_L), lambda_1(h_s,v_s)
        s2_a = s2_b = sigma_2(h_R,v_R,h_s,v_s)
    elif 0 < h_s <= min(h_L,h_R) :
        # 1-rarefaction, 2-rarefaction
        s1_a, s1_b = lambda_1(h_L,v_L), lambda_1(h_s,v_s)
        s2_a, s2_b = lambda_2(h_s,v_s), lambda_2(h_R,v_R)
    elif h_L == 0 :
        # Dry left side: only a 2-rarefaction; the left state extends up to
        # the tail of that fan, so the 1-wave collapses onto it.
        s2_a, s2_b = Inv_2(h_R,v_R), lambda_2(h_R,v_R)
        s1_a = s1_b = s2_a
        h_m, v_m = h_L, v_L
    elif h_R == 0 :
        # Dry right side: only a 1-rarefaction; the right state starts at the
        # end of that fan, so the 2-wave collapses onto it.
        s1_a, s1_b = lambda_1(h_L,v_L), Inv_1(h_L,v_L)
        s2_a = s2_b = s1_b
        h_m, v_m = h_R, v_R
    else :
        # 1-rarefaction, dry zone, 2-rarefaction
        s1_a, s1_b = lambda_1(h_L,v_L), Inv_1(h_L,v_L)
        s2_a, s2_b = Inv_2(h_R,v_R), lambda_2(h_R,v_R)
        h_m, v_m = 0, 0

    # Front positions at time t
    x1_a, x1_b = s1_a*t, s1_b*t
    x2_a, x2_b = s2_a*t, s2_b*t

    # Middle state first, bounds included (closes the gap on the 2-fan tail)
    mask_m = (xx >= x1_b) & (xx <= x2_a)
    hh[mask_m] = h_m
    vv[mask_m] = v_m

    # Rarefaction fans (their masks are empty for shocks)
    _fill_fan(hh,vv,xx,t,(xx > x1_a) & (xx < x1_b),h_det_1,v_det_1,h_L,v_L)
    _fill_fan(hh,vv,xx,t,(xx > x2_a) & (xx < x2_b),h_det_2,v_det_2,h_R,v_R)

    # Outer constant states last
    mask_L = (xx <= x1_a)
    hh[mask_L] = h_L
    vv[mask_L] = v_L
    mask_R = (xx >= x2_b)
    hh[mask_R] = h_R
    vv[mask_R] = v_R

    return hh,vv


    
# ===================================================================================
# ## DISPLAY (intended for use with interact)
# ===================================================================================

def inter(fig,ax1,ax2,key,xx,t,h_L,v_L,h_R,v_R,h_s,v_s,hh,vv):
    """Draw one frame: height profile on ax1, velocity profile on ax2.

    Parameters
    ----------
    fig : figure owning ax1 and ax2; it receives the overall title.
    ax1, ax2 : axes used for the height and the velocity respectively.
    key : label of the test case, shown in the figure title.
    xx : abscissas of the plotted points.
    t : time of the frame, shown in the figure title.
    h_L, v_L, h_R, v_R, h_s, v_s : left, right and intermediate states; they
        set the y-limits (with a 0.1 margin) and appear in the axes titles.
    hh, vv : height and velocity values plotted against xx.

    The axes are not cleared here; the caller is expected to do so between
    frames.
    """
    margin = 0.1
    heights = [h_L,h_R,h_s]
    velocities = [v_L,v_R,v_s]
    ax1.set_ylim([min(heights)-margin,max(heights)+margin])
    ax2.set_ylim([min(velocities)-margin,max(velocities)+margin])
    ax1.plot(xx,hh,'-o')
    ax2.plot(xx,vv,'-o')
    ax1.set_title(f"${h_L=:g}$, ${h_s=:g}$, ${h_R=:g}$")
    ax2.set_title(f"${v_L=:g}$, ${v_s=:g}$, ${v_R=:g}$")
    ax1.grid()
    ax2.grid()
    # Title the figure that was passed in rather than whichever one pyplot
    # currently considers active.
    fig.suptitle(f"Test : {key}, {t=:g}")

# ===================================================================================
# ## MAIN
# ===================================================================================

# Domain bounds [x_L, x_R] and location x_0 of the initial discontinuity
x_L, x_0, x_R = -10, 0, 10
Nx = 100 # number of grid points (samples of np.linspace)
g  = 9.81 # gravity
xx = np.linspace(x_L,x_R,Nx) 

# Riemann problem data: test name -> [(h_L, v_L), (h_R, v_R)]
IC = { '1-choc, 2-choc': [(1,1),(1,-1)], 
       '1-choc, 2-détente': [(1,1),(2,1)], 
       '1-détente, 2-choc': [(2,1),(1,1)], 
       '1-détente, 2-détente': [(1,-1),(1,1)], 
       'zone sèche, 2-détente': [(0,0),(1,1)], 
       '1-détente, zone sèche': [(1,1),(0,0)], 
       '1-détente, zone sèche, 2-détente': [(0.1,-3),(0.1,3)] }


# One figure per test case, animated over 11 times between t = 0 and t = 1.5
for key, ((h_L, v_L), (h_R, v_R)) in IC.items():
    h_s, v_s = star(h_L,v_L,h_R,v_R)
    fig,(ax1,ax2) = plt.subplots(nrows=1,ncols=2)
    for t in np.linspace(0,1.5,11):
        # The exact solution is centred on x = 0, so shift the grid by x_0
        # to place the discontinuity where the domain definition says it is.
        hh,vv = sol_exacte_riemann(xx-x_0,t,h_L,v_L,h_R,v_R,h_s,v_s)
        inter(fig,ax1,ax2,key,xx,t,h_L,v_L,h_R,v_R,h_s,v_s,hh,vv)
        plt.pause(0.1)
        # Clear both axes so the next frame starts from an empty plot
        ax1.cla()
        ax2.cla()

