Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/animation.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ Some functionality such as `aspect` and `size` are not fully implemented yet.
ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf")

# Change the units of the coordinates
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"])
ds = ds.epoch.rescale_coords(1e15, "fs", ["time"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"])
ds = ds.epoch.rescale_coords("fs", ["time"])
ds["time"].attrs["long_name"] = "t"

# Change units and name of the variable
Expand Down
4 changes: 2 additions & 2 deletions docs/epoch_workshop_2026/animating/animate_1D_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
ds = sdfxr.open_mfdataset(input_dir)

# Convert the time to femtoseconds
ds = ds.epoch.rescale_coords(1e15, "fs", "time")
ds = ds.epoch.rescale_coords("fs", "time", 1e15)
# Convert the x and y coords to microns
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6)

anim = ds["Derived_Number_Density"].epoch.animate()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
ds = sdfxr.open_mfdataset(input_dir)

# Convert the time to femtoseconds
ds = ds.epoch.rescale_coords(1e15, "fs", "time")
ds = ds.epoch.rescale_coords("fs", "time", 1e15)
# Convert the x and y coords to microns
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6)

# Calculate Poynting flux magnitude
flux_magnitude = np.sqrt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
)

# Rescale coords to account for kilometers
ds = ds.epoch.rescale_coords(1e-3, "km", ["X_x_px_Left"])
ds = ds.epoch.rescale_coords("km", ["X_x_px_Left"], 1e-3)

# Sum phase-space of species "Left" and "Right" in "x_px" distribution function
# NOTE: We only use the values from the right distribution function as if we inherit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

# Rescale coords to account for kilometers
ds = ds.epoch.rescale_coords(1e-3, "km", ["X_x_px_Left"])
ds = ds.epoch.rescale_coords("km", ["X_x_px_Left"], 1e-3)

# Sum phase-space of species "Left" and "Right" in "x_px" distribution function
# NOTE: We only use the values from the right distribution function as if we inherit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
ds = sdfxr.open_mfdataset(input_dir)

# Convert the time to femtoseconds
ds = ds.epoch.rescale_coords(1e15, "fs", "time")
ds = ds.epoch.rescale_coords("fs", "time", 1e15)
# Convert the x and y coords to microns
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6)

# Calculate Poynting flux magnitude
flux_magnitude = np.sqrt(
Expand Down
4 changes: 2 additions & 2 deletions docs/epoch_workshop_2026/live_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3800,7 +3800,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "2ef27498",
"metadata": {
"colab": {
Expand Down Expand Up @@ -3833,7 +3833,7 @@
}
],
"source": [
"ds = ds.epoch.rescale_coords(1e6, \"µm\", [\"X_Grid_mid\", \"Y_Grid_mid\"])\n",
"ds = ds.epoch.rescale_coords(\"µm\", [\"X_Grid_mid\", \"Y_Grid_mid\"], 1e6)\n",
"ds[\"Derived_Number_Density_Electron\"].isel(time=0).epoch.plot()"
]
},
Expand Down
2 changes: 1 addition & 1 deletion docs/epoch_workshop_2026/plotting/plot_1D_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
input_dir = Path("datasets/1_1_drifting_bunch")

ds = sdfxr.open_dataset(input_dir / "0000.sdf")
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6)

ds["Derived_Number_Density"].epoch.plot()
plt.tight_layout()
Expand Down
2 changes: 1 addition & 1 deletion docs/epoch_workshop_2026/plotting/plot_1D_poynting_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ds = sdfxr.open_dataset(input_dir / "0020.sdf")

# Convert the x and y coords to microns
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid"], 1e6)

# Calculate Poynting flux magnitude
flux_magnitude = np.sqrt(
Expand Down
2 changes: 1 addition & 1 deletion docs/epoch_workshop_2026/plotting/plot_2D_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
input_dir = Path("datasets/4_3_basic_target")

ds = sdfxr.open_dataset(input_dir / "0000.sdf")
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6)

ds["Derived_Number_Density"].epoch.plot()
plt.tight_layout()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
input_dir = Path("datasets/2_1_two_stream_instability")
ds = sdfxr.open_dataset(input_dir / "0000.sdf")

ds = ds.epoch.rescale_coords(1e-3, "km", ["X_x_px_Left"])
ds = ds.epoch.rescale_coords("km", ["X_x_px_Left"], 1e-3)

# Sum phase-space of species "Left" and "Right" in "x_px" distribution function
# NOTE: We only use the values from the right distribution function as if we inherit
Expand Down
2 changes: 1 addition & 1 deletion docs/epoch_workshop_2026/plotting/plot_2D_poynting_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ds = sdfxr.open_dataset(input_dir / "0001.sdf")

# Convert the x and y coords to microns
ds = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"])
ds = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6)

# Calculate Poynting flux magnitude
flux_magnitude = np.sqrt(
Expand Down
2 changes: 1 addition & 1 deletion docs/epoch_workshop_2026/plotting/plot_time_temperature.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ds = sdfxr.open_mfdataset(input_dir)

# Convert the time to femtoseconds
ds = ds.epoch.rescale_coords(1e15, "fs", "time")
ds = ds.epoch.rescale_coords("fs", "time", 1e15)

# Averate temperature over all spatial cells at each time-step
kB = 1.380649e-23
Expand Down
2 changes: 1 addition & 1 deletion docs/epoch_workshop_2026/plotting/plot_x_px_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
ds = sdfxr.open_dataset(input_dir / "0000.sdf", keep_particles=True)

ds = ds.epoch.rescale_coords(
1e6, "µm", ["X_Particles_subset_Refluxers_Electron", "Y_Target_mid"]
"µm", ["X_Particles_subset_Refluxers_Electron", "Y_Target_mid"], 1e6
)

x = ds["X_Particles_subset_Refluxers_Electron"]
Expand Down
8 changes: 4 additions & 4 deletions docs/unit_conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ plt.rcParams.update({
For simple scaling and unit relabelling of coordinates (e.g., converting meters to microns),
the most straightforward approach is to use the [`xarray.Dataset.epoch.rescale_coords`](project:#sdf_xarray.dataset_accessor.EpochAccessor.rescale_coords) dataset accessor.
This function scales the coordinate values by a given multiplier and updates the
`"units"` attribute in one step.
`"units"` attribute in one step. If the multiplier is not specified then the conversion parameter is inferred from the units and converted automatically using `pint` (see [](#unit-conversion-with-pint-xarray) for details).

### Rescaling grid coordinates

Expand All @@ -41,7 +41,7 @@ We can use the [`xarray.Dataset.epoch.rescale_coords`](project:#sdf_xarray.datas
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf")
ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", ["X_Grid_mid", "Y_Grid_mid"])
ds_in_microns = ds.epoch.rescale_coords("µm", ["X_Grid_mid", "Y_Grid_mid"], 1e6)

ds["Derived_Number_Density_Electron"].isel(time=0).epoch.plot(ax=ax1)
ax1.set_title("Original X Coordinate (m)")
Expand All @@ -55,15 +55,15 @@ fig.tight_layout()
### Rescaling time coordinate

We can also use the [`xarray.Dataset.epoch.rescale_coords`](project:#sdf_xarray.dataset_accessor.EpochAccessor.rescale_coords) method to convert the time coordinate from
seconds (`s`) to femto-seconds (`fs`) by applying a multiplier of `1e15`.
seconds (`s`) to femto-seconds (`fs`).

```{code-cell} ipython3
ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf")
ds["time"]
```

```{code-cell} ipython3
ds = ds.epoch.rescale_coords(1e15, "fs", "time")
ds = ds.epoch.rescale_coords("fs", "time")
ds["time"]
```

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ test = [
"matplotlib",
"pooch>=1.8.2",
"tqdm",
"pint",
"pint-xarray",
]

[project.entry-points."xarray.backends"]
Expand Down
48 changes: 30 additions & 18 deletions src/sdf_xarray/dataset_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,39 @@ def __init__(self, xarray_obj: xr.Dataset):

def rescale_coords(
self,
multiplier: float,
unit_label: str,
coord_names: str | list[str],
multiplier: float | None = None,
) -> xr.Dataset:
"""
Rescales specified X and Y coordinates in the Dataset by a given multiplier
and updates the unit label attribute.
Rescales specified coordinates in a Dataset by a given unit. If the
multiplier is not specified then the coordinates are automatically
scaled using `pint <https://pint.readthedocs.io/en/stable>`_, if the multiplier is specified then it will be
used to rescale the coordinate.

Parameters
----------
multiplier : float
The factor by which to multiply the coordinate values (e.g., 1e6 for meters to microns).
unit_label : str
The new unit label for the coordinates (e.g., "µm").
coord_names : str or list of str
The name(s) of the coordinate variable(s) to rescale.
If a string, only that coordinate is rescaled.
If a list, all listed coordinates are rescaled.

Returns
-------
xr.Dataset
A new Dataset with the updated and rescaled coordinates.
multiplier : float or None
The factor by which to multiply the coordinate values (e.g., 1e6
for meters to microns). If not specified then ``pint`` is used to
rescale the units automatically.

Examples
--------
# Convert X, Y, and Z from meters to microns
>>> ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"])

# Convert only X to millimeters
>>> ds_in_mm = ds.epoch.rescale_coords(1000, "mm", coord_names="X_Grid")
>>> # Convert X, Y, and Z from meters to microns using pint
>>> ds_in_microns = ds.epoch.rescale_coords("µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"])
>>>
>>> # Convert X, Y, and Z from meters to microns
>>> ds_in_microns = ds.epoch.rescale_coords("µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"], 1e6)
>>>
>>> # Convert time to femtoseconds
>>> ds_in_mm = ds.epoch.rescale_coords("fs", coord_names="time")
"""

ds = self._ds
Expand All @@ -72,9 +74,19 @@ def rescale_coords(

coord_original = ds[coord_name]

coord_rescaled = coord_original * multiplier
coord_rescaled.attrs = coord_original.attrs.copy()
coord_rescaled.attrs["units"] = unit_label
if multiplier is not None:
coord_rescaled = coord_original * multiplier
coord_rescaled.attrs = coord_original.attrs.copy()
coord_rescaled.attrs["units"] = unit_label
else:
coord_rescaled: xr.DataArray = (
coord_original.pint.quantify(coord_original.attrs["units"])
.pint.to(unit_label)
.pint.dequantify()
)
# Ensure the unit label follows the same naming convension the
# user has specified and not the one given by pint
coord_rescaled.attrs["units"] = unit_label

new_coords[coord_name] = coord_rescaled

Expand Down
44 changes: 43 additions & 1 deletion tests/test_epoch_dataset_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,32 @@ def test_rescale_coords_X():
assert ds_rescaled["Z_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid"


def test_rescale_coords_X_auto():
unit_label = "mm"

with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds:
ds_rescaled = ds.epoch.rescale_coords(
unit_label=unit_label,
coord_names="X_Grid_mid",
)

expected_x = ds["X_Grid_mid"].values * 1e3
assert np.allclose(ds_rescaled["X_Grid_mid"].values, expected_x)
assert ds_rescaled["X_Grid_mid"].attrs["units"] == unit_label
assert ds_rescaled["X_Grid_mid"].attrs["long_name"] == "X"
assert ds_rescaled["X_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid"

assert np.allclose(ds_rescaled["Y_Grid_mid"].values, ds["Y_Grid_mid"].values)
assert ds_rescaled["Y_Grid_mid"].attrs["units"] == "m"
assert ds_rescaled["Y_Grid_mid"].attrs["long_name"] == "Y"
assert ds_rescaled["Y_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid"

assert np.allclose(ds_rescaled["Z_Grid_mid"].values, ds["Z_Grid_mid"].values)
assert ds_rescaled["Z_Grid_mid"].attrs["units"] == "m"
assert ds_rescaled["Z_Grid_mid"].attrs["long_name"] == "Z"
assert ds_rescaled["Z_Grid_mid"].attrs["full_name"] == "Grid/Grid_mid"


def test_rescale_coords_X_Y():
multiplier = 1e2
unit_label = "cm"
Expand Down Expand Up @@ -142,7 +168,7 @@ def test_rescale_coords_non_existent_coord():


def test_rescale_coords_time():
multiplier = 1e-15
multiplier = 1e15
unit_label = "fs"

with open_mfdataset(TEST_FILES_DIR_3D.glob("*.sdf")) as ds:
Expand All @@ -159,6 +185,22 @@ def test_rescale_coords_time():
assert ds_rescaled["time"].attrs["full_name"] == "time"


def test_rescale_coords_time_auto():
unit_label = "fs"

with open_mfdataset(TEST_FILES_DIR_3D.glob("*.sdf")) as ds:
ds_rescaled = ds.epoch.rescale_coords(
unit_label=unit_label,
coord_names="time",
)

expected_time = ds["time"].values * 1e15
assert np.allclose(ds_rescaled["time"].values, expected_time)
assert ds_rescaled["time"].attrs["units"] == unit_label
assert ds_rescaled["time"].attrs["long_name"] == "Time"
assert ds_rescaled["time"].attrs["full_name"] == "time"


def test_animate_multiple_accessor():
with open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf")) as ds:
assert hasattr(ds, "epoch")
Expand Down
Loading