Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 40 additions & 15 deletions PyFHD/plotting/image.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
import numpy as np
from numpy.typing import NDArray
import matplotlib.pyplot as plt
import os
from pathlib import Path
from logging import Logger

from astropy.wcs import WCS
from astropy.io import fits
from astropy import units as u
from logging import Logger
import os
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
from numpy.typing import NDArray


def truncate_colormap(cmap, *, minval=0.0, maxval=1.0, nseg=100):
new_cmap = colors.LinearSegmentedColormap.from_list(
"trunc({n},{a:.2f},{b:.2f})".format(n=cmap.name, a=minval, b=maxval),
cmap(np.linspace(minval, maxval, nseg)),
)
return new_cmap


def quick_image(
image: NDArray[np.integer | np.floating | np.complexfloating],
xvals: NDArray[np.integer | np.floating] = None,
yvals: NDArray[np.integer | np.floating] = None,
*,
data_range: NDArray[np.integer | np.floating] = None,
data_min_abs: float = None,
xrange: NDArray[np.integer | np.floating] = None,
yrange: NDArray[np.integer | np.floating] = None,
data_aspect: float = None,
log: bool = False,
color_profile: str = "log_cut",
cmap: str | None = None,
xtitle: str = None,
ytitle: str = None,
title: str = None,
Expand Down Expand Up @@ -66,6 +78,8 @@ def quick_image(
color_profile : str, optional
Color bar profiles for logarithmic scaling.
"log_cut", "sym_log", "abs", by default "log_cut"
cmap : str, optional
Matplotlib colormap to use.
xtitle : str, optional
The title of the x-axis, by default None
ytitle : str, optional
Expand Down Expand Up @@ -184,13 +198,13 @@ def quick_image(

# Validate that 2-value inputs are only 2 values
if data_range is not None:
if not isinstance(data_range, np.ndarray) or len(data_range) != 2:
if not isinstance(data_range, np.ndarray | list) or len(data_range) != 2:
raise ValueError("data_range must be an array with exactly two values.")
if xrange is not None:
if not isinstance(xrange, np.ndarray) or len(xrange) != 2:
if not isinstance(xrange, np.ndarray | list) or len(xrange) != 2:
raise ValueError("xrange must be an array with exactly two values.")
if yrange is not None:
if not isinstance(yrange, np.ndarray) or len(yrange) != 2:
if not isinstance(yrange, np.ndarray | list) or len(yrange) != 2:
raise ValueError("yrange must be an array with exactly two values.")

# Apply logarithmic scaling if set. This modifies the image input directly
Expand All @@ -215,18 +229,19 @@ def quick_image(

data_color_range, data_n_colors = color_range(count_missing=count_missing)

# Find out-of-bounds values
wh_low = np.nonzero(image < data_range[0])
wh_high = np.nonzero(image > data_range[1])

# Scale image data to be in the color range
image = (image - data_range[0]) * (data_n_colors - 1) / (
data_range[1] - data_range[0]
) + data_color_range[0]
print(data_range, data_color_range, data_n_colors)

# Handle out-of-bounds values
wh_low = np.where(image < data_range[0])
if len(wh_low[0]) > 0:
if wh_low[0].size > 0:
image[wh_low] = data_color_range[0]
wh_high = np.where(image > data_range[1])
if len(wh_high[0]) > 0:
if wh_high[0].size > 0:
image[wh_high] = data_color_range[1]

# Handle missing values
Expand All @@ -238,11 +253,19 @@ def quick_image(
f"{tick * (data_range[1] - data_range[0]) / (data_n_colors - 1) + data_range[0]:.2g}"
for tick in cb_ticks
]
print(cb_ticks, cb_ticknames)

# Set up the plot
fig, ax = plt.subplots()
cmap = plt.get_cmap("viridis")
if cmap == "idl":
cmap = plt.get_cmap("Spectral_r")
cmap = truncate_colormap(cmap, minval=(20 / 255), maxval=1, nseg=256)
elif cmap is None:
if log and color_profile == "sym_log":
cmap = "RdBu"
else:
cmap = "viridis"
else:
cmap = plt.get_cmap(cmap)

# Set up the x and y ranges
extent = None
Expand Down Expand Up @@ -274,6 +297,7 @@ def quick_image(
vmin=0,
vmax=255,
alpha=alpha,
origin="lower",
)

# Add titles and labels
Expand Down Expand Up @@ -337,6 +361,7 @@ def quick_image(

def log_color_calc(
data: NDArray[np.integer | np.floating | np.complexfloating],
*,
data_range: NDArray[np.integer | np.floating] = None,
color_profile: str = "log_cut",
log_cut_val: float = None,
Expand Down
4 changes: 4 additions & 0 deletions docs/source/changelog/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
### Breaking Changes!

### New Features
* Added an option to `plotting.image.quick_image` to set the colormap to any
matplotlib colormap or to match the colormap used in the IDL version of `quick_image`.
* Added handling for `~` in paths in config yamls.

### Bug Fixes
* Fixed a couple small bugs in `plotting.image.quick_image` that caused the
linear scaling to be wrong and the image to be flipped vertically.
* Fixed checkpointing to actually work.
* Fixed a bug in the uvfits reader where it assumed the presence of "ra" and
"dec" header items which often present in MWA uvfits files but are non-standard.
Expand Down
91 changes: 91 additions & 0 deletions tests/test_plotting/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
import pytest
from scipy.signal import convolve2d

from PyFHD.plotting.image import quick_image


@pytest.fixture
def pyramid():
top_hat_arr = np.zeros((12, 12), dtype=float)
# set middle to 1, leave zero edges
top_hat_arr[1:-1, 1:-1] = np.ones((10, 10), dtype=float)

pyramid = convolve2d(top_hat_arr, top_hat_arr)
yield pyramid


@pytest.mark.github_actions
@pytest.mark.parametrize("file_type", ["png", "eps", "pdf"])
@pytest.mark.parametrize("file_is_path", [True, False])
def test_quick_image_pyramid(tmp_path, pyramid, file_type, file_is_path):
"""This is just a smoke test to make sure the code runs."""

savefile = tmp_path / f"pyramid.{file_type}"
cmap = "idl"
missing_value = None
log = False
color_profile = "log_cut"
data_range = None
xvals = None
yvals = None
xrange = None
yrange = None
title = "pyramid"
xtitle = "East (m)"
ytitle = "North (m)"
cb_title = "Height (m)"
note = None

if not file_is_path:
savepath = savefile
savefile = str(savefile)
# set parameters differently to access different parts of the code
cmap = "magma"
log = True
pyramid_shape = pyramid.shape
xvals = np.arange(pyramid_shape[0])
yvals = np.arange(pyramid_shape[1])

# set parameters differently to access different parts of the code
if file_type == "pdf":
title = None
xtitle = None
ytitle = None
cb_title = None
note = "foo"
xrange = [1, 21]
yrange = [1, 21]
missing_value = 0
elif file_type == "eps":
pyramid_max = pyramid.max()
nonzero_min = np.min(pyramid[pyramid > 0])
cmap = None
color_profile = "sym_log"
if log:
data_range = [-1 * pyramid_max, pyramid_max]
else:
data_range = [nonzero_min, pyramid_max - nonzero_min]

quick_image(
pyramid,
xvals=xvals,
yvals=yvals,
xrange=xrange,
yrange=yrange,
cmap=cmap,
log=log,
missing_value=missing_value,
color_profile=color_profile,
data_range=data_range,
title=title,
xtitle=xtitle,
ytitle=ytitle,
cb_title=cb_title,
note=note,
savefile=savefile,
)
if not file_is_path:
savefile = savepath

assert savefile.is_file()
Loading