# Regrid to regular

This notebook shows how to regrid ERA5 data from a reduced Gaussian grid to a regular lon-lat grid using linear interpolation along longitude.

In [None]:
import intake
import dask
import logging
from distributed import Client
import xarray as xr
#client.close()
#client=Client(silence_logs=logging.ERROR)

## User parameters:

- **catalog**: URL or path of the intake catalog. The one from git is always reachable but uses simplecache, which can cause problems with async access.
- **catalog_entry**: One of the ERA5 time series listed by `list(cat)`.
- **target_global_vars**: A list of variables that are defined on 1280 longitudes at the equator and should be interpolated.
- **openchunks**: Chunk sizes for all dimensions that are not related to lat and lon. Larger values mean fewer chunks but need more memory.
- **to_load_selection**: The selection of the time series to which the workflow is applied.

In [None]:
# Catalog location: the remote GitLab copy is active; a local checkout is kept
# below as a commented-out alternative.
catalog = "https://gitlab.dkrz.de/data-infrastructure-services/era5-kerchunks/-/raw/main/main.yaml"
# catalog = "/work/bm1344/DKRZ/git/era5-kerchunks/main.yaml"

# Name of the catalog entry to open.
catalog_entry = "surface_analysis_daily"

# Variables that are unstacked and interpolated.
target_global_vars = ["2t"]

# Chunk sizes used when opening the source, one key per non-lat/lon dimension.
openchunks = {
    "time": 4,
    # "level": 5,  # for 3d data
}

# Selection passed to .sel() to pick the part of the series to process.
to_load_selection = {"time": "2010"}

Open the catalog and load a single record that serves as the template for the dask `map_blocks` functions:

In [None]:
# Set dask's 'logging.distributed' config entry to 'error'.
dask.config.set({"logging.distributed": "error"})

# Open the intake catalog and turn the chosen entry into a lazily loaded
# dataset, chunked as configured in `openchunks`.
cat = intake.open_catalog(catalog)
dssource = cat[catalog_entry](chunks=openchunks).to_dask()

# Template record: index 0 along every dimension named in `openchunks`,
# restricted to the target variables and loaded eagerly.
template_source = (
    dssource[target_global_vars]
    .isel(**{dim: 0 for dim in openchunks})
    .load()
)

1. Unstack: Define the function and its template
- Select the equator longitudes used for interpolation
- Chunk the entire record (lon x lat) as a single chunk

In [None]:
def unstack(ds):
    """Split the flat ``value`` dimension of ``ds`` into ``lat`` and ``lon``.

    The ``value`` dimension is renamed to ``latlon``, given a multi-index
    built from the ``lat`` and ``lon`` variables, and then unstacked so that
    ``lat`` and ``lon`` become separate dimensions.

    Parameters
    ----------
    ds : object providing ``rename``, ``set_index`` and ``unstack``
        Presumably an xarray Dataset/DataArray with a ``value`` dimension and
        ``lat``/``lon`` variables along it -- confirm against the catalog.

    Returns
    -------
    The result of the final ``unstack("latlon")`` call.
    """
    renamed = ds.rename({'value': 'latlon'})
    indexed = renamed.set_index(latlon=("lat", "lon"))
    return indexed.unstack("latlon")

# Unstacked version of the single template record.
template_unstack = unstack(template_source)

# Longitudes that carry non-NaN values on the latitude nearest to 0.0; these
# are the target longitudes of the interpolation step.
equator_lons = (
    template_unstack[target_global_vars]
    .sel(lat=0.0, method="nearest")
    .dropna(dim="lon")["lon"]
)

In [None]:
# One chunk per dimension of the unstacked template, sized to the full length
# of that dimension's coordinate.
latlonchunks = dict(
    (dim, len(template_unstack[dim]))
    for dim in template_unstack.dims
)

# Chunk sizes of the source data along every dimension named in `openchunks`.
nolatlonchunks = dict(
    (dim, dssource[target_global_vars].chunksizes[dim])
    for dim in openchunks
)

In [None]:
# Rechunk the unstacked template with the full-length chunk sizes collected in
# `latlonchunks`.
template_unstack = template_unstack.chunk(**latlonchunks)

2. Interp: Linearly interpolate all NaNs along longitude and keep only the longitudes that are present at the latitude nearest to the equator.

In [None]:
def interp(ds):
    """Fill NaNs along ``lon`` linearly and keep only the ``equator_lons``.

    ``interpolate_na`` is called with ``method="linear"`` and
    ``period=360.0`` (per the xarray/numpy documentation this makes the
    interpolation periodic in longitude), then the result is reindexed onto
    the module-level ``equator_lons`` coordinate.

    Parameters
    ----------
    ds : object providing ``interpolate_na``
        Presumably the output of ``unstack`` with a ``lon`` dimension --
        confirm against the caller.

    Returns
    -------
    The reindexed result of the interpolation.
    """
    filled = ds.interpolate_na(dim="lon", method="linear", period=360.0)
    return filled.reindex(lon=equator_lons)

# Run the interpolation once on the unstacked template to obtain the template
# of the interpolation step.
template_interp = interp(template_unstack)

In [None]:
# Give both templates the full coordinates of every open dimension (taken from
# the source dataset) and chunk them like the source along those dimensions.
# The input tuple is built before either name is rebound, so each template is
# processed independently.
template_unstack, template_interp = (
    tmpl.expand_dims(**{dim: dssource[dim] for dim in openchunks}).chunk(nolatlonchunks)
    for tmpl in (template_unstack, template_interp)
)

## Define workflow here

In [None]:
# Lazily select the user-requested part of the target variables.
original = dssource[target_global_vars].sel(**to_load_selection)

# The map_blocks templates must describe exactly the same selection as the
# data they are applied to, so they are subset with `to_load_selection` as
# well instead of a hard-coded time range. Changing the user parameter then
# keeps data and templates consistent.
unstacked = original.map_blocks(
    unstack,
    template=template_unstack[target_global_vars].sel(**to_load_selection),
)
interped = unstacked.map_blocks(
    interp,
    template=template_interp.sel(**to_load_selection),
)

In [None]:
# dask.optimize returns a tuple with one entry per collection passed in;
# unpack the single optimized dataset.
(interped,) = dask.optimize(interped)

In [None]:
# Display the (still lazy) result of the workflow definition before computing it.
interped

## Run workflow here

In [None]:
from dask.diagnostics import ProgressBar

# Execute the whole workflow and show a progress bar while it runs.
with ProgressBar():
    t2 = interped.compute()

In [None]:
# Imported so that the `.hvplot` attribute used below is available.
import hvplot.xarray

# Image plot of the computed result with longitude on x and latitude on y.
t2.hvplot.image(x="lon", y="lat")