import yaml
import netCDF4 as nc
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np

def plot_netcdf_data(nc_file, config_file):
    """Plot selected NetCDF variables as scatter maps and save one image.

    The YAML configuration may define the following keys (all optional):

    * ``variables_to_plot`` -- a variable name or a list of names; each
      one found in the file gets its own map panel, stacked vertically.
    * ``output_filename`` -- path of the saved image
      (default ``'diag_plot.png'``).
    * ``colormap`` -- matplotlib colormap name (default ``'plasma'``).

    Point locations are read from the ``Latitude`` and ``Longitude``
    variables of the NetCDF file. Requested variables that are not in the
    file are reported with a warning and skipped. If nothing remains to
    plot, a warning is printed and no image is written.

    Args:
        nc_file: Path of the NetCDF file to read.
        config_file: Path of the YAML configuration file.

    Raises:
        KeyError: If the file has no ``Latitude`` or ``Longitude`` variable.
        OSError: If either input file cannot be opened.
    """
    # 1. Load User Configuration
    with open(config_file, 'r') as f:
        # An empty YAML document loads as None; treat it as "all defaults".
        config = yaml.safe_load(f) or {}

    # `or []` also covers an explicit null value for the key.
    vars_to_plot = config.get('variables_to_plot') or []
    if isinstance(vars_to_plot, str):
        # A bare scalar would otherwise be iterated character by character.
        vars_to_plot = [vars_to_plot]
    output_file = config.get('output_filename', 'diag_plot.png')
    cmap_choice = config.get('colormap', 'plasma')

    # 2. Open NetCDF File (the context manager closes it even on errors)
    with nc.Dataset(nc_file) as ds:
        # Resolve the plottable variables first so that the figure gets
        # exactly one panel per variable and no blank maps.
        available = []
        for var_name in vars_to_plot:
            if var_name in ds.variables:
                available.append(var_name)
            else:
                print(f"Warning: {var_name} not found in file. Skipping.")

        if not available:
            # plt.subplots() rejects nrows=0, so there is nothing to draw.
            print("Warning: no requested variables found; nothing to plot.")
            return

        # Extract Coordinates
        lats = ds.variables['Latitude'][:]
        lons = ds.variables['Longitude'][:]

        # The sensor attribute is optional; fall back to a placeholder
        # instead of failing after all the plotting work is done.
        sensor = getattr(ds, 'Satellite_Sensor', 'unknown')

        # Create the figure; squeeze=False always yields a 2-D axes array,
        # so the single-panel case needs no special handling.
        num_vars = len(available)
        fig, axes = plt.subplots(
            nrows=num_vars, ncols=1,
            figsize=(12, 6 * num_vars),
            subplot_kw={'projection': ccrs.PlateCarree()},
            squeeze=False,
        )

        try:
            # 3. Plotting Loop
            for ax, var_name in zip(axes[:, 0], available):
                data = ds.variables[var_name][:]

                # Add map features
                ax.set_global()
                ax.add_feature(cfeature.COASTLINE, linewidth=0.5)
                ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5)
                ax.gridlines(
                    draw_labels=True, dms=True, x_inline=False, y_inline=False
                )

                # Scatter plot for observation data
                sc = ax.scatter(
                    lons, lats, c=data,
                    transform=ccrs.PlateCarree(),
                    cmap=cmap_choice, s=1, alpha=0.7
                )

                # Add colorbar and title
                fig.colorbar(
                    sc, ax=ax, orientation='vertical', shrink=0.7,
                    label=var_name
                )
                ax.set_title(
                    f"Field: {var_name} (Satellite: {sensor})", loc='left'
                )

            # 4. Save
            fig.tight_layout()
            fig.savefig(output_file, dpi=200)
        finally:
            # Release the figure so repeated calls do not accumulate memory.
            plt.close(fig)

    print(f"Successfully saved plot to {output_file}")

if __name__ == "__main__":
    # Script entry point. The input paths are hard-coded; edit them here
    # to point at a different diagnostic file or plot configuration.
    CONFIG_NAME = "plot_config.yaml"
    NC_FILENAME = "diag_ssmis_f17_ges.2026010712"

    plot_netcdf_data(nc_file=NC_FILENAME, config_file=CONFIG_NAME)
    
