#!/usr/bin/env python3

import h5py
import re
import matplotlib.pyplot as plt
import numpy as np
import sys

def extract_data( h5file: str, part: str, pattern: str = r'^d3plot_stp\d{3}$' ):
    """Collect time and energy scalars for one part from every matching step group.

    For each top-level group of *h5file* whose name matches *pattern*, read the
    first element of the ``time``, ``kinetic energy(<part>)``,
    ``internal energy(<part>)`` and ``total energy(<part>)`` datasets stored
    under ``/<group>/fields03/grid/scalar``.

    Args:
        h5file: path to the HDF5 file to open read-only.
        part: part label embedded in the dataset names, e.g. ``"PART_0"``.
        pattern: regular expression selecting the step groups by name.

    Returns:
        Array with one row per matching group, in the file's group iteration
        order. The columns are time, kinetic, internal and total energy. When
        no group matches, an empty array of shape ``(0, 4)`` is returned, so
        column indexing ``data[:, i]`` still works.

    Raises:
        KeyError: if a matching group lacks one of the expected datasets.
    """
    step_re = re.compile(pattern)
    rows = []
    with h5py.File( h5file, 'r' ) as f:
        for group_name in f.keys():
            if not step_re.match(group_name):
                continue
            base = f"/{group_name}/fields03/grid/scalar"
            paths = (
                f"{base}/time",
                f"{base}/kinetic energy({part})",
                f"{base}/internal energy({part})",
                f"{base}/total energy({part})",
            )
            row = []
            for path in paths:
                try:
                    dset = f[path]
                except KeyError as err:
                    raise KeyError(f"dataset '{path}' not found in {h5file}") from err
                # Index the dataset directly: this reads a single element
                # instead of loading the whole dataset into memory first.
                row.append(dset[0])
            rows.append(row)
    if not rows:
        # np.array([]) would be 1-D (shape (0,)) and break data[:, i] indexing.
        return np.empty((0, 4))
    return np.array(rows)

def plot( data, part ):
    """Plot total, internal and kinetic energy against time for one part.

    Columns 0..3 of *data* are read as time, kinetic, internal and total
    energy, matching the row layout built by extract_data. The figure is
    written to ``energy_<part>.png`` (PNG, 300 dpi) and then shown.
    """
    time_axis = data[:, 0]

    # (column index, legend label, extra line options). Only the total-energy
    # curve gets point markers. The order fixes the legend order.
    series = (
        (3, 'Total', {'marker': 'o', 'color': 'b'}),
        (2, 'Internal', {'color': 'g'}),
        (1, 'Kinetic', {'color': 'm'}),
    )

    fig, ax = plt.subplots(figsize=(12, 8))
    for column, label, options in series:
        ax.plot(time_axis, data[:, column], linestyle='-', label=label, **options)

    ax.set_xlabel('time [s]')
    ax.set_ylabel('energy')
    ax.set_title(f'Energies over time [{part}]')
    ax.grid(True)
    ax.legend()

    fig.savefig(f'energy_{part}.png', format='png', dpi=300)
    plt.show()
    

if __name__ == "__main__":
    # Usage: <script> h5-file part
    # The part argument is appended to "PART_" to form the dataset label
    # (e.g. "0" -> "PART_0").
    try:
        # Compare the argument count by value; `is not` would test object
        # identity (and is a SyntaxWarning against an int literal).
        if len(sys.argv) != 3:
            raise ValueError(f"Usage: {sys.argv[0]} h5-file part, part = 0|1|ALL")

        # PARSE COMMAND LINE
        h5file = sys.argv[1]
        part = f"PART_{sys.argv[2]}"

        # EXTRACT DATA
        data = extract_data( h5file, part )
        if data.size == 0:
            raise ValueError(f"no matching d3plot step groups found in {h5file}")

        # PLOT DATA
        plot(data, part)
    except (ValueError, KeyError, OSError) as e:
        # Bad arguments, a missing/unreadable file (OSError) or a missing
        # dataset (KeyError) are reported on stderr. The non-zero exit status
        # lets shell callers detect the failure.
        print( f"An error has occurred: {e}", file=sys.stderr )
        sys.exit(1)
        
