import os
import xarray as xr
import numpy as np

# Set to True to print one line per location whose snow is capped (high) or
# zeroed (low) during the snow clean-up of each spinup state.
print_high_snow_removal = True
print_low_snow_removal = False

# Prefix of the per-date sfc_data directories: an 8-character date string
# (YYYYMMDD) is appended to it, e.g. ".../sfc_data/gfs.YYYYMMDD".
sfcdata_path = "/scratch1/NCEPDEV/stmp2/Helin.Wei/spinup/hr3/C768/sfc_data/gfs."
# Directory containing the ufs_land_restart_anal.*.nc spinup restart files.
spinup_path = "/scratch2/NCEPDEV/stmp1/Youlong.Xia/landDA/cycle_land/C768/DA_GHCN_HR4/mem000/restarts/vector"

# sfcdata_path is a directory *prefix*, not a directory itself, so list its
# parent and keep the sub-directories whose names start with that prefix.
sfcdata_dir, sfcdata_prefix = os.path.split(sfcdata_path)
dirs_to_process = sorted(
    d
    for d in os.listdir(sfcdata_dir)
    if d.startswith(sfcdata_prefix) and os.path.isdir(os.path.join(sfcdata_dir, d))
)
numdates = len(dirs_to_process)
print(f"Processing {numdates} dates")

# Static vegetation category for each spinup location; used later to find
# glacier points. Read into memory and close the file immediately.
with xr.open_dataset(
    "/scratch2/NCEPDEV/land/Michael.Barlage/forcing/C768/vector/ufs-land_C768_hr3_static_fields.nc"
) as vegetation_file:
    vegetation_type = vegetation_file['vegetation_category'].values

# Thresholds used to sanitise the spinup snow state before it is copied.
GLACIER_VEG_TYPE = 15            # vegetation_category value treated as glacier
GLACIER_MAX_SNOW_DEPTH = 2000.0  # snow_depth cap at glacier locations
MAX_SNOW_DEPTH = 10000.0         # snow_depth cap at every location
MIN_SNOW_DEPTH = 1.0             # positive depths below this are zeroed
MIN_SNOW_WATER_EQUIV = 0.01      # positive SWE below this is zeroed
NUM_TILES = 6                    # sfc_data.tile1.nc .. sfc_data.tile6.nc per date


def _cap_snow(snow_depth, snow_water_equiv, where, max_depth, label, verbose):
    """Cap ``snow_depth`` at ``max_depth`` wherever ``where`` is True.

    SWE at the same locations is multiplied by the same reduction factor, so
    the SWE/depth ratio is unchanged. Both arrays are modified in place.

    Returns:
        Number of locations that were capped.
    """
    idx = np.flatnonzero(where)
    factor = max_depth / snow_depth[idx]
    if verbose:
        for depth, fac in zip(snow_depth[idx], factor):
            print(f"Reducing {label} location with depth = {depth} by factor = {fac}")
    snow_depth[idx] = max_depth
    snow_water_equiv[idx] *= factor
    return idx.size


def _adjust_snow(snow_depth, snow_water_equiv, veg_type, verbose_high, verbose_low):
    """Cap excessive snow and remove trace snow, in place, at every location.

    The steps run in this order, each seeing the result of the previous one:
      1. glacier points deeper than GLACIER_MAX_SNOW_DEPTH are capped to it;
      2. any point deeper than MAX_SNOW_DEPTH is capped to it;
      3. points with 0 < depth < MIN_SNOW_DEPTH or 0 < SWE < MIN_SNOW_WATER_EQUIV
         get depth and SWE set to zero.

    Returns:
        (number of capped locations, number of zeroed locations)
    """
    # Only the first len(snow_depth) vegetation entries describe spinup points.
    is_glacier = veg_type[: snow_depth.size] == GLACIER_VEG_TYPE
    n_high = _cap_snow(
        snow_depth, snow_water_equiv,
        (snow_depth > GLACIER_MAX_SNOW_DEPTH) & is_glacier,
        GLACIER_MAX_SNOW_DEPTH, "glacier", verbose_high,
    )
    # Evaluated after the glacier cap, so capped glacier points never qualify.
    n_high += _cap_snow(
        snow_depth, snow_water_equiv,
        snow_depth > MAX_SNOW_DEPTH,
        MAX_SNOW_DEPTH, "non-glacier", verbose_high,
    )
    trace = ((snow_depth > 0) & (snow_depth < MIN_SNOW_DEPTH)) | (
        (snow_water_equiv > 0) & (snow_water_equiv < MIN_SNOW_WATER_EQUIV)
    )
    if verbose_low:
        for i in np.flatnonzero(trace):
            print(f"Removing location with SWE = {snow_water_equiv[i]} and depth = {snow_depth[i]}")
    snow_depth[trace] = 0.0
    snow_water_equiv[trace] = 0.0
    return n_high, int(np.count_nonzero(trace))


def _land_cells(tile_path):
    """Return (rows, cols) index arrays of the cells with vtype != 0.

    np.nonzero yields row-major order, which is the order in which spinup
    locations are assigned to tile cells.
    """
    with xr.open_dataset(tile_path) as tile:
        mask = tile['vtype'][0, :, :].values != 0
    return np.nonzero(mask)


def _write_tile(tile_path, rows, cols, start, spin):
    """Copy spinup columns [start, start + len(rows)) into one tile's land cells.

    ``spin`` maps sfc_data variable names to spinup arrays whose last axis is
    the location axis. Only those variables are rewritten in the file.
    """
    # Pull the data into memory and release the read handle; in-memory edits
    # to an xarray dataset are never written back by close(), so the file is
    # reopened for writing explicitly below.
    with xr.open_dataset(tile_path) as tile:
        tile.load()
    stop = start + rows.size
    for name, values in spin.items():
        data = np.array(tile[name].values)  # writable copy
        # Ellipsis covers the optional level axis: (level, y, x) or (y, x).
        data[0][..., rows, cols] = values[..., start:stop]
        tile[name].values = data
    # mode="a" overwrites the existing variables and leaves the rest untouched.
    tile[list(spin)].to_netcdf(tile_path, mode="a")


for date_to_process in dirs_to_process:
    datestring = date_to_process[-8:]

    sfc_date = datestring
    spin_date = f"{datestring[:4]}-{datestring[4:6]}-{datestring[6:]}"
    print(f"Moving: {spin_date} to {sfc_date}")

    spinup_name = f"{spinup_path}/ufs_land_restart_anal.{spin_date}_00-00-00.nc"
    with xr.open_dataset(spinup_name) as spinup_file:
        spinsmc = spinup_file['soil_moisture_vol'][0, :, :].values
        spinslc = spinup_file['soil_liquid_vol'][0, :, :].values
        spinstc = spinup_file['temperature_soil'][0, :, :].values
        spinsnd = spinup_file['snow_depth'][0, :].values
        spinswe = spinup_file['snow_water_equiv'][0, :].values

    high_snow_removal, low_snow_removal = _adjust_snow(
        spinsnd, spinswe, vegetation_type,
        verbose_high=print_high_snow_removal,
        verbose_low=print_low_snow_removal,
    )

    num_in_spinup = spinstc.shape[1]
    print(f"Num in spinup: {num_in_spinup}")
    print(f"Num high_snow_removal: {high_snow_removal}")
    print(f"Num low_snow_removal: {low_snow_removal}")

    tile_paths = [
        f"{sfcdata_path}{sfc_date}/sfc_data.tile{itile}.nc"
        for itile in range(1, NUM_TILES + 1)
    ]
    # Validate the location count BEFORE writing, so a mismatch never leaves
    # some tiles updated and others not.
    land_cells = [_land_cells(path) for path in tile_paths]
    num_in_tiles = sum(rows.size for rows, _ in land_cells)
    if num_in_tiles != num_in_spinup:
        print("Number in tiles /= number in spinup")
        print(f"tiles: {num_in_tiles}, spinup: {num_in_spinup}")
        raise SystemExit(1)  # non-zero status so batch jobs see the failure

    spin = {
        'smc': spinsmc,
        'slc': spinslc,
        'stc': spinstc,
        'sheleg': spinswe,
        'snwdph': spinsnd,
    }
    nloc = 0
    for itile, (tile_path, (rows, cols)) in enumerate(zip(tile_paths, land_cells), start=1):
        print(f"Starting tile: {itile}")
        print(f"Num vtype /= 0: {rows.size}")
        _write_tile(tile_path, rows, cols, nloc, spin)
        nloc += rows.size
        print(f"Number of cumulative locs: {nloc}")

print("Processing completed.")
