diff --git a/docs/animation.md b/docs/animation.md index 35ad8d3..c512b4a 100644 --- a/docs/animation.md +++ b/docs/animation.md @@ -170,8 +170,8 @@ Some functionality such as `aspect` and `size` are not fully implemented yet. ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf") # Change the units of the coordinates -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) -ds = ds.epoch.rescale_coords(1e15, "fs", ["time"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"]) +ds = ds.epoch.rescale_coords("fs", ["time"]) ds["time"].attrs["long_name"] = "t" # Change units and name of the variable diff --git a/docs/epoch_workshop_2026/animating/animate_1D_density.py b/docs/epoch_workshop_2026/animating/animate_1D_density.py index 073aa3e..c47a570 100644 --- a/docs/epoch_workshop_2026/animating/animate_1D_density.py +++ b/docs/epoch_workshop_2026/animating/animate_1D_density.py @@ -6,9 +6,9 @@ ds = sdfxr.open_mfdataset(input_dir) # Convert the time to femtoseconds -ds = ds.epoch.rescale_coords(1e15, "fs", "time") +ds = ds.epoch.rescale_coords("fs", "time", 1e15) # Convert the x and y coords to microns -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6) anim = ds["Derived_Number_Density"].epoch.animate() diff --git a/docs/epoch_workshop_2026/animating/animate_1D_poynting_flux.py b/docs/epoch_workshop_2026/animating/animate_1D_poynting_flux.py index dc8bd0d..f63db10 100644 --- a/docs/epoch_workshop_2026/animating/animate_1D_poynting_flux.py +++ b/docs/epoch_workshop_2026/animating/animate_1D_poynting_flux.py @@ -8,9 +8,9 @@ ds = sdfxr.open_mfdataset(input_dir) # Convert the time to femtoseconds -ds = ds.epoch.rescale_coords(1e15, "fs", "time") +ds = ds.epoch.rescale_coords("fs", "time", 1e15) # Convert the x and y coords to microns -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6) # Calculate Poynting flux magnitude flux_magnitude = np.sqrt( diff --git a/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies.py b/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies.py index a5bb913..e395459 100644 --- a/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies.py +++ b/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies.py @@ -8,7 +8,7 @@ ) # Rescale coords to account for kilometers -ds = ds.epoch.rescale_coords(1e-3, "km", ["X_x_px_Left"]) +ds = ds.epoch.rescale_coords("km", ["X_x_px_Left"], 1e-3) # Sum phase-space of species "Left" and "Right" in "x_px" distribution function # NOTE: We only use the values from the right distribution function as if we inherit diff --git a/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies_alternative.py b/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies_alternative.py index e8fb86a..7b1eedb 100644 --- a/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies_alternative.py +++ b/docs/epoch_workshop_2026/animating/animate_2D_dist_fn_multispecies_alternative.py @@ -11,7 +11,7 @@ ) # Rescale coords to account for kilometers -ds = ds.epoch.rescale_coords(1e-3, "km", ["X_x_px_Left"]) +ds = ds.epoch.rescale_coords("km", ["X_x_px_Left"], 1e-3) # Sum phase-space of species "Left" and "Right" in "x_px" distribution function # NOTE: We only use the values from the right distribution function as if we inherit diff --git a/docs/epoch_workshop_2026/animating/animate_2D_poynting_flux.py b/docs/epoch_workshop_2026/animating/animate_2D_poynting_flux.py index 0f09b09..372dc97 100644 --- a/docs/epoch_workshop_2026/animating/animate_2D_poynting_flux.py +++ b/docs/epoch_workshop_2026/animating/animate_2D_poynting_flux.py @@ -8,9 +8,9 @@ ds = sdfxr.open_mfdataset(input_dir) # Convert the time to femtoseconds -ds = ds.epoch.rescale_coords(1e15, "fs", "time") +ds = ds.epoch.rescale_coords("fs", "time", 1e15) # Convert the x and y coords to microns -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6) # Calculate Poynting flux magnitude flux_magnitude = np.sqrt( diff --git a/docs/epoch_workshop_2026/live_demo.ipynb b/docs/epoch_workshop_2026/live_demo.ipynb index 6c6faf2..d511230 100644 --- a/docs/epoch_workshop_2026/live_demo.ipynb +++ b/docs/epoch_workshop_2026/live_demo.ipynb @@ -3800,7 +3800,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "2ef27498", "metadata": { "colab": { @@ -3833,7 +3833,7 @@ } ], "source": [ - "ds = ds.epoch.rescale_coords(1e6, \"µm\", [\"X_Grid_mid\", \"Y_Grid_mid\"])\n", + "ds = ds.epoch.rescale_coords(\"µm\", [\"X_Grid_mid\", \"Y_Grid_mid\"], 1e6)\n", "ds[\"Derived_Number_Density_Electron\"].isel(time=0).epoch.plot()" ] }, diff --git a/docs/epoch_workshop_2026/plotting/plot_1D_density.py b/docs/epoch_workshop_2026/plotting/plot_1D_density.py index 7220d40..94ef4df 100644 --- a/docs/epoch_workshop_2026/plotting/plot_1D_density.py +++ b/docs/epoch_workshop_2026/plotting/plot_1D_density.py @@ -8,7 +8,7 @@ input_dir = Path("datasets/1_1_drifting_bunch") ds = sdfxr.open_dataset(input_dir / "0000.sdf") -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6) ds["Derived_Number_Density"].epoch.plot() plt.tight_layout() diff --git a/docs/epoch_workshop_2026/plotting/plot_1D_poynting_flux.py b/docs/epoch_workshop_2026/plotting/plot_1D_poynting_flux.py index 71e99f6..c9135a2 100644 --- a/docs/epoch_workshop_2026/plotting/plot_1D_poynting_flux.py +++ b/docs/epoch_workshop_2026/plotting/plot_1D_poynting_flux.py @@ -10,7 +10,7 @@ ds = sdfxr.open_dataset(input_dir / "0020.sdf") # Convert the x and y coords to microns -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6) # Calculate Poynting flux magnitude flux_magnitude = np.sqrt( diff --git a/docs/epoch_workshop_2026/plotting/plot_2D_density.py b/docs/epoch_workshop_2026/plotting/plot_2D_density.py index 811924f..0d1a813 100644 --- a/docs/epoch_workshop_2026/plotting/plot_2D_density.py +++ b/docs/epoch_workshop_2026/plotting/plot_2D_density.py @@ -8,7 +8,7 @@ input_dir = Path("datasets/4_3_basic_target") ds = sdfxr.open_dataset(input_dir / "0000.sdf") -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6) ds["Derived_Number_Density"].epoch.plot() plt.tight_layout() diff --git a/docs/epoch_workshop_2026/plotting/plot_2D_dist_fn_multispecies.py b/docs/epoch_workshop_2026/plotting/plot_2D_dist_fn_multispecies.py index a56cea8..e73c311 100644 --- a/docs/epoch_workshop_2026/plotting/plot_2D_dist_fn_multispecies.py +++ b/docs/epoch_workshop_2026/plotting/plot_2D_dist_fn_multispecies.py @@ -8,7 +8,7 @@ input_dir = Path("datasets/2_1_two_stream_instability") ds = sdfxr.open_dataset(input_dir / "0000.sdf") -ds = ds.epoch.rescale_coords(1e-3, "km", ["X_x_px_Left"]) +ds = ds.epoch.rescale_coords("km", ["X_x_px_Left"], 1e-3) # Sum phase-space of species "Left" and "Right" in "x_px" distribution function # NOTE: We only use the values from the right distribution function as if we inherit diff --git a/docs/epoch_workshop_2026/plotting/plot_2D_poynting_flux.py b/docs/epoch_workshop_2026/plotting/plot_2D_poynting_flux.py index 99a6bb3..b46b952 100644 --- a/docs/epoch_workshop_2026/plotting/plot_2D_poynting_flux.py +++ b/docs/epoch_workshop_2026/plotting/plot_2D_poynting_flux.py @@ -10,7 +10,7 @@ ds = sdfxr.open_dataset(input_dir / "0001.sdf") # Convert the x and y coords to microns -ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) +ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6) # Calculate Poynting flux magnitude flux_magnitude = np.sqrt( diff --git a/docs/epoch_workshop_2026/plotting/plot_time_temperature.py b/docs/epoch_workshop_2026/plotting/plot_time_temperature.py index cdab64c..c9ec760 100644 --- a/docs/epoch_workshop_2026/plotting/plot_time_temperature.py +++ b/docs/epoch_workshop_2026/plotting/plot_time_temperature.py @@ -13,7 +13,7 @@ ds = sdfxr.open_mfdataset(input_dir) # Convert the time to femtoseconds -ds = ds.epoch.rescale_coords(1e15, "fs", "time") +ds = ds.epoch.rescale_coords("fs", "time", 1e15) # Averate temperature over all spatial cells at each time-step kB = 1.380649e-23 diff --git a/docs/epoch_workshop_2026/plotting/plot_x_px_scatter.py b/docs/epoch_workshop_2026/plotting/plot_x_px_scatter.py index 19df9ba..f12d97e 100644 --- a/docs/epoch_workshop_2026/plotting/plot_x_px_scatter.py +++ b/docs/epoch_workshop_2026/plotting/plot_x_px_scatter.py @@ -9,7 +9,7 @@ ds = sdfxr.open_dataset(input_dir / "0000.sdf", keep_particles=True) ds = ds.epoch.rescale_coords( - 1e6, "µm", ["X_Particles_subset_Refluxers_Electron", "Y_Target_mid"] + "µm", ["X_Particles_subset_Refluxers_Electron", "Y_Target_mid"], 1e6 ) x = ds["X_Particles_subset_Refluxers_Electron"] diff --git a/docs/unit_conversion.md b/docs/unit_conversion.md index 113c530..39709d1 100644 --- a/docs/unit_conversion.md +++ b/docs/unit_conversion.md @@ -30,7 +30,7 @@ plt.rcParams.update({ For simple scaling and unit relabelling of coordinates (e.g., converting meters to microns), the most straightforward approach is to use the [`xarray.Dataset.epoch.rescale_coords`](project:#sdf_xarray.dataset_accessor.EpochAccessor.rescale_coords) dataset accessor. This function scales the coordinate values by a given multiplier and updates the -`"units"` attribute in one step. +`"units"` attribute in one step. If the multiplier is not specified then the conversion parameter is inferred from the units and converted automatically using `pint` (see [](#unit-conversion-with-pint-xarray) for details). ### Rescaling grid coordinates @@ -41,7 +41,7 @@ We can use the [`xarray.Dataset.epoch.rescale_coords`](project:#sdf_xarray.datas fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6)) ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf") -ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"]) +ds_in_microns = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6) ds["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax1) ax1.set_title("Original X Coordinate (m)") @@ -55,7 +55,7 @@ fig.tight_layout() ### Rescaling time coordinate We can also use the [`xarray.Dataset.epoch.rescale_coords`](project:#sdf_xarray.dataset_accessor.EpochAccessor.rescale_coords) method to convert the time coordinate from -seconds (`s`) to femto-seconds (`fs`) by applying a multiplier of `1e15`. +seconds (`s`) to femto-seconds (`fs`). ```{code-cell} ipython3 ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf") @@ -63,7 +63,7 @@ ds["time"] ``` ```{code-cell} ipython3 -ds = ds.epoch.rescale_coords(1e15, "fs", "time") +ds = ds.epoch.rescale_coords("fs", "time") ds["time"] ``` diff --git a/pyproject.toml b/pyproject.toml index 98790c2..33b9c35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,8 @@ test = [ "matplotlib", "pooch>=1.8.2", "tqdm", + "pint", + "pint-xarray", ] [project.entry-points."xarray.backends"] diff --git a/src/sdf_xarray/dataset_accessor.py b/src/sdf_xarray/dataset_accessor.py index c9f5433..bc6e35d 100644 --- a/src/sdf_xarray/dataset_accessor.py +++ b/src/sdf_xarray/dataset_accessor.py @@ -19,37 +19,39 @@ def __init__(self, xarray_obj: xr.Dataset): def rescale_coords( self, - multiplier: float, unit_label: str, coord_names: str | list[str], + multiplier: float | None = None, ) -> xr.Dataset: """ - Rescales specified X and Y coordinates in the Dataset by a given multiplier - and updates the unit label attribute. + Rescales specified coordinates in a Dataset by a given unit. If the + multiplier is not specified then the coordinates are automatically + scaled using `pint `_, if the multiplier is specified then it will be + used to rescale the coordinate. Parameters ---------- - multiplier : float - The factor by which to multiply the coordinate values (e.g., 1e6 for meters to microns). unit_label : str The new unit label for the coordinates (e.g., "µm"). coord_names : str or list of str The name(s) of the coordinate variable(s) to rescale. If a string, only that coordinate is rescaled. If a list, all listed coordinates are rescaled. - - Returns - ------- - xr.Dataset - A new Dataset with the updated and rescaled coordinates. + multiplier : float or None + The factor by which to multiply the coordinate values (e.g., 1e6 + for meters to microns). If not specified then ``pint`` is used to + rescale the units automatically. Examples -------- - # Convert X, Y, and Z from meters to microns - >>> ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"]) - - # Convert only X to millimeters - >>> ds_in_mm = ds.epoch.rescale_coords(1000, "mm", coord_names="X_Grid") + >>> # Convert X, Y, and Z from meters to microns using pint + >>> ds_in_microns = ds.epoch.rescale_coords("µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"]) + >>> + >>> # Convert X, Y, and Z from meters to microns + >>> ds_in_microns = ds.epoch.rescale_coords("µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"], 1e6) + >>> + >>> # Convert time to femtoseconds + >>> ds_in_mm = ds.epoch.rescale_coords("fs", coord_names="time") """ ds = self._ds @@ -72,9 +74,19 @@ def rescale_coords( coord_original = ds[coord_name] - coord_rescaled = coord_original * multiplier - coord_rescaled.attrs = coord_original.attrs.copy() - coord_rescaled.attrs["units"] = unit_label + if multiplier is not None: + coord_rescaled = coord_original * multiplier + coord_rescaled.attrs = coord_original.attrs.copy() + coord_rescaled.attrs["units"] = unit_label + else: + coord_rescaled: xr.DataArray = ( + coord_original.pint.quantify(coord_original.attrs["units"]) + .pint.to(unit_label) + .pint.dequantify() + ) + # Ensure the unit label follows the same naming convension the + # user has specified and not the one given by pint + coord_rescaled.attrs["units"] = unit_label new_coords[coord_name] = coord_rescaled diff --git a/tests/test_epoch_dataset_accessor.py b/tests/test_epoch_dataset_accessor.py index 69b9e09..5eb3e97 100644 --- a/tests/test_epoch_dataset_accessor.py +++ b/tests/test_epoch_dataset_accessor.py @@ -50,6 +50,32 @@ def test_rescale_coords_X(): assert ds_rescaled["Z_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid" +def test_rescale_coords_X_auto(): + unit_label = "mm" + + with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds: + ds_rescaled = ds.epoch.rescale_coords( + unit_label=unit_label, + coord_names="X_Grid_mid", + ) + + expected_x = ds["X_Grid_mid"].values * 1e3 + assert np.allclose(ds_rescaled["X_Grid_mid"].values, expected_x) + assert ds_rescaled["X_Grid_mid"].attrs["units"] == unit_label + assert ds_rescaled["X_Grid_mid"].attrs["long_name"] == "X" + assert ds_rescaled["X_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid" + + assert np.allclose(ds_rescaled["Y_Grid_mid"].values, ds["Y_Grid_mid"].values) + assert ds_rescaled["Y_Grid_mid"].attrs["units"] == "m" + assert ds_rescaled["Y_Grid_mid"].attrs["long_name"] == "Y" + assert ds_rescaled["Y_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid" + + assert np.allclose(ds_rescaled["Z_Grid_mid"].values, ds["Z_Grid_mid"].values) + assert ds_rescaled["Z_Grid_mid"].attrs["units"] == "m" + assert ds_rescaled["Z_Grid_mid"].attrs["long_name"] == "Z" + assert ds_rescaled["Z_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid" + + def test_rescale_coords_X_Y(): multiplier = 1e2 unit_label = "cm" @@ -142,7 +168,7 @@ def test_rescale_coords_non_existent_coord(): def test_rescale_coords_time(): - multiplier = 1e-15 + multiplier = 1e15 unit_label = "fs" with open_mfdataset(TEST_FILES_DIR_3D.glob("*.sdf")) as ds: @@ -159,6 +185,22 @@ def test_rescale_coords_time(): assert ds_rescaled["time"].attrs["full_name"] == "time" +def test_rescale_coords_time_auto(): + unit_label = "fs" + + with open_mfdataset(TEST_FILES_DIR_3D.glob("*.sdf")) as ds: + ds_rescaled = ds.epoch.rescale_coords( + unit_label=unit_label, + coord_names="time", + ) + + expected_time = ds["time"].values * 1e15 + assert np.allclose(ds_rescaled["time"].values, expected_time) + assert ds_rescaled["time"].attrs["units"] == unit_label + assert ds_rescaled["time"].attrs["long_name"] == "Time" + assert ds_rescaled["time"].attrs["full_name"] == "time" + + def test_animate_multiple_accessor(): with open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf")) as ds: assert hasattr(ds, "epoch")