import netCDF4
import matplotlib.pyplot as plt
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import os
from datetime import datetime
from datetime import timedelta
from dateutil.relativedelta import relativedelta

def ReadNETCDF(FilePath, VariName):
    """Read one variable from a netCDF file and return it flattened to 1-D.

    Parameters
    ----------
    FilePath : str
        Path of the netCDF file to open.
    VariName : str
        Name of the variable to read.

    Returns
    -------
    array
        The full contents of the variable, flattened with ``ravel()``.
        netCDF4 typically returns a masked array when fill values are
        present (auto-masking) -- callers should treat masked entries as
        missing.
    """
    # The context manager closes the file even if the variable lookup
    # raises; previously the handle was leaked on every call. The data is
    # read into memory by ``[:]``, so it stays valid after the file closes.
    with netCDF4.Dataset(FilePath) as f:
        MyArray = f.variables[VariName][:].ravel()
    return MyArray

def ReadNETCDF_2D(FilePath, VariName):
    """Read one variable from a netCDF file, keeping its original shape.

    Parameters
    ----------
    FilePath : str
        Path of the netCDF file to open.
    VariName : str
        Name of the variable to read.

    Returns
    -------
    array
        The full variable as stored in the file (not reshaped). netCDF4
        typically returns a masked array when fill values are present.
    """
    # Close the file deterministically; ``[:]`` has already loaded the data
    # into memory, so the returned array does not depend on the open handle.
    with netCDF4.Dataset(FilePath) as f:
        MyArray = f.variables[VariName][:]
    return MyArray

# Basin id selected from both basin masks; the plot below labels it as the
# Lena River Basin.
LENA_BASIN_ID = 2106


def _basin_mask(basin, basin_id):
    """Return the flat indices of cells whose basin id equals ``basin_id``.

    Masked (fill-value) entries in ``basin`` never match.
    """
    return np.flatnonzero(np.ma.filled(basin == basin_id, False))


def _masked_nanmean(values, cells):
    """Average ``values`` at the indices ``cells``, ignoring NaN and masked entries.

    Returns NaN (without a RuntimeWarning) when no valid value is selected.
    """
    # Masked entries become NaN so they are skipped like ordinary missing data.
    selected = np.ma.filled(np.ma.asarray(values)[cells].astype(float), np.nan)
    if selected.size == 0 or np.isnan(selected).all():
        return np.nan
    return float(np.nanmean(selected))


def _daily_basin_average(monthly_fields, cells, n_days):
    """Lay daily basin averages from consecutive monthly fields end to end.

    Each item of ``monthly_fields`` is a 2-D array whose rows are days and
    whose columns are grid cells. Every row contributes exactly one entry:
    its mean over ``cells``. Returns an array of length ``n_days``; days with
    no data stay NaN, and rows beyond ``n_days`` are ignored.
    """
    averages = np.full(n_days, np.nan)
    day = 0
    for field in monthly_fields:
        for row in field:
            if day >= n_days:
                return averages
            averages[day] = _masked_nanmean(row, cells)
            day += 1
    return averages


def _ims_monthly_fields(imsFilePath, start_date, n_months):
    """Yield the ``IMSscf`` array of each monthly IMS file, starting at ``start_date``'s month."""
    for im in range(n_months):
        month = (start_date + relativedelta(months=im)).strftime('%Y%m')
        yield ReadNETCDF_2D(imsFilePath + "/IMSscf.C96.4km." + month + ".nc", "IMSscf")


def Computediff(VariName):
    """Plot daily basin-mean snow cover from IMS, the open-loop run and the DA run.

    For each day of the simulation window this averages, over the cells of
    basin ``LENA_BASIN_ID``:

    * the IMS observed snow cover fraction (``IMSscf``) from monthly files,
    * ``VariName`` from the open-loop restart file for that day,
    * ``VariName`` from the DA restart file for that day,

    then plots the three series and saves
    ``Lena_River_basin_ims_da_openloop.png`` in the current working directory.

    Reads the module-level ``BasinPath`` (model-grid basin mask) and
    ``imsBasinPath`` (IMS-grid basin mask).

    Parameters
    ----------
    VariName : str
        Variable to read from the restart files, e.g. ``"snow_cover_fraction"``.
    """
    print(VariName)
    DAFilePath = "//scratch2/NCEPDEV/land/Youlong.Xia/landDA/cycle_land/DA_era5/mem000/restarts/vector"
    OpenloopFilePath = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/cycle_land/openloop_era5/mem000/restarts/vector"
    imsFilePath = "/scratch2/NCEPDEV/land/data/DA/snow_ice_cover/IMS/CalSnowCover/4km/C96"

    RestartDate = datetime.strptime("2019-10-01_18-00-00", '%Y-%m-%d_%H-%M-%S')
    DeltaTime = 1  # days between consecutive restart files
    SimulationDuration = 244  # days: 2019-10-01 through 2020-05-31
    simuMonth = 8  # monthly IMS files (Oct-May) covering the same window

    # The IMS basin mask differs from the UFS land basin mask (different
    # latitude/longitude ordering), so each source is indexed with its own.
    ims_cells = _basin_mask(ReadNETCDF(imsBasinPath, "basins"), LENA_BASIN_ID)
    model_cells = _basin_mask(ReadNETCDF(BasinPath, "basins"), LENA_BASIN_ID)

    # Day t of the window maps to index t: each IMS day gets its own average
    # rather than sharing one value across the whole month.
    ims_avg = _daily_basin_average(
        _ims_monthly_fields(imsFilePath, RestartDate, simuMonth),
        ims_cells,
        SimulationDuration,
    )
    print("ims snow cover calculation completed")

    model_avg = np.full(SimulationDuration, np.nan)
    DA_avg = np.full(SimulationDuration, np.nan)
    for SimDay in range(SimulationDuration):
        print(SimDay)
        ThisRestartDate = RestartDate + timedelta(days=SimDay * DeltaTime)
        print(ThisRestartDate)

        # Read DA and open-loop restart files for this date.
        stamp = ThisRestartDate.strftime('%Y-%m-%d_%H-%M-%S')
        DAFile = DAFilePath + "/ufs_land_restart_back." + stamp + ".nc"
        openloopFile = OpenloopFilePath + "/ufs_land_restart_back." + stamp + ".nc"
        DA_avg[SimDay] = _masked_nanmean(ReadNETCDF(DAFile, VariName), model_cells)
        model_avg[SimDay] = _masked_nanmean(ReadNETCDF(openloopFile, VariName), model_cells)

    xx = np.arange(SimulationDuration)

    plt.figure(figsize=(12, 8))
    ax1 = plt.subplot(1, 1, 1)
    ax1.set_ylim(0.0, 1.0)
    ax1.plot(xx, model_avg, 'b-', label="openloop")
    ax1.plot(xx, DA_avg, 'r-', label="DA")
    ax1.plot(xx, ims_avg, 'k-', label="IMS")
    ax1.set_ylabel("Snow Cover Fraction", fontsize=20)
    ax1.set_xlabel("Simulation days since 01 October 2019", fontsize=20)
    ax1.set_title("Lena River Basin", pad=20, fontsize=20)
    ax1.legend()
    plt.tight_layout()
    plt.savefig('Lena_River_basin_ims_da_openloop.png')
    plt.close()

# Model-grid (UFS land) and IMS-grid basin masks. Computediff reads these
# module-level names directly, so they must stay at module scope.
BasinPath = "/scratch2/NCEPDEV/land/data/evaluation/basins/C96/GRDC_C96.nc"
imsBasinPath = "/scratch2/NCEPDEV/land/Zhichang.Guo/lsmask/IMS_C96_corners.nc"
# NOTE(review): PlotPath and RMSEplot are not used by Computediff, which
# saves its figure into the current working directory.
PlotPath = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/plots/C96/Average_DA-OpenLoopDiffsnap/"
RMSEplot = "/scratch2/NCEPDEV/land/Youlong.Xia/landDA/plots/C96/RMSE_DA-OpenLoopDiffsnap/"


# Variables to read from the restart files and plot.
VarList = ["snow_cover_fraction"]

# Guarded so that importing this module does not start the long run over
# hundreds of restart and IMS files.
if __name__ == "__main__":
    for var in VarList:
        Computediff(var)
