"""Delay space spectrum estimation and filtering."""
from typing import TypeVar
import numpy as np
import scipy.linalg as la
from caput import config, fftw, memh5, mpiarray
from cora.util import units
from numpy.lib.recfunctions import structured_to_unstructured
from ..core import containers, io, task
from ..util import random, tools
from .delayopt import delay_power_spectrum_maxpost
# A specific subclass of a FreqContainer
FreqContainerType = TypeVar("FreqContainerType", bound=containers.FreqContainer)
# ---------------------
# Delay Filter Classes
# ---------------------
[docs]
class DelayFilter(task.SingleTask):
    """Remove delays less than a given threshold.

    This is performed by projecting the data onto the null space that is orthogonal
    to any mode at low delays.

    Note that for this task to work best the zero entries in the weights dataset
    should factorize in frequency-time for each baseline. A mostly optimal masking
    can be generated using the `draco.analysis.flagging.MaskFreq` task.

    Attributes
    ----------
    delay_cut : float
        Delay value to filter at in seconds.
        NOTE(review): this value is compared directly against a baseline term
        computed in microseconds in `process` — confirm the intended unit.
    za_cut : float
        Sine of the maximum zenith angle included in baseline-dependent delay
        filtering. Default is 1 which corresponds to the horizon (ie: filters out all
        zenith angles). Setting to zero turns off baseline dependent cut.
    extra_cut : float
        Increase the delay threshold beyond the baseline dependent term.
    weight_tol : float
        Maximum weight kept in the masked data, as a fraction of the largest weight
        in the original dataset.
    telescope_orientation : one of ('NS', 'EW', 'none')
        Determines if the baseline-dependent delay cut is based on the north-south
        component, the east-west component or the full baseline length. For
        cylindrical telescopes oriented in the NS direction (like CHIME) use 'NS'.
        The default is 'NS'.
    window : bool
        Apply the window function to the data when applying the filter.

    Notes
    -----
    The delay cut applied is `max(za_cut * baseline / c + extra_cut, delay_cut)`.
    """

    delay_cut = config.Property(proptype=float, default=0.1)
    za_cut = config.Property(proptype=float, default=1.0)
    extra_cut = config.Property(proptype=float, default=0.0)
    weight_tol = config.Property(proptype=float, default=1e-4)
    telescope_orientation = config.enum(["NS", "EW", "none"], default="NS")
    window = config.Property(proptype=bool, default=False)

    def setup(self, telescope):
        """Set the telescope needed to obtain baselines.

        Parameters
        ----------
        telescope : TransitTelescope
            The telescope object to use
        """
        self.telescope = io.get_telescope(telescope)

    def process(self, ss):
        """Filter out delays from a SiderealStream or TimeStream.

        Parameters
        ----------
        ss : containers.SiderealStream
            Data to filter.

        Returns
        -------
        ss_filt : containers.SiderealStream
            Filtered dataset.

        Raises
        ------
        RuntimeError
            If the null-space filter fails to converge for a baseline.
        """
        tel = self.telescope

        # Distribute over a baseline-like axis so each rank holds the full
        # frequency range for its subset of baselines
        ss.redistribute(["input", "prod", "stack"])

        freq = ss.freq[:]
        bandwidth = np.ptp(freq)

        ssv = ss.vis[:].view(np.ndarray)
        ssw = ss.weight[:].view(np.ndarray)

        # Feed indices of each product, used to construct baseline vectors
        ia, ib = structured_to_unstructured(ss.prodstack, dtype=np.int16).T
        baselines = tel.feedpositions[ia] - tel.feedpositions[ib]

        for lbi, bi in ss.vis[:].enumerate(axis=1):
            # Select the baseline length to use
            baseline = baselines[bi]
            if self.telescope_orientation == "NS":
                baseline = abs(baseline[1])  # Y baseline
            elif self.telescope_orientation == "EW":
                baseline = abs(baseline[0])  # X baseline
            else:
                baseline = np.linalg.norm(baseline)  # Norm

            # In micro seconds
            baseline_delay_cut = self.za_cut * baseline / units.c * 1e6 + self.extra_cut
            delay_cut = np.amax([baseline_delay_cut, self.delay_cut])

            # Calculate the number of samples needed to construct the delay null space.
            # `4 * tau_max * bandwidth` is the amount recommended in the DAYENU paper
            # and seems to work well here
            number_cut = int(4.0 * bandwidth * delay_cut + 0.5)

            # Flag frequencies and times with zero weight. This works much better if the
            # incoming weight can be factorized
            f_samp = (ssw[:, lbi] > 0.0).sum(axis=1)
            f_mask = (f_samp == f_samp.max()).astype(np.float64)

            t_samp = (ssw[:, lbi] > 0.0).sum(axis=0)
            t_mask = (t_samp == t_samp.max()).astype(np.float64)

            # The filter construction has occasionally failed to converge;
            # surface the failing baseline index for post-hoc debugging
            try:
                NF = null_delay_filter(
                    freq,
                    delay_cut,
                    f_mask,
                    num_delay=number_cut,
                    window=self.window,
                )
            except la.LinAlgError as e:
                raise RuntimeError(
                    f"Failed to converge while processing baseline {bi}"
                ) from e

            # Project out the low-delay modes, and zero the weights of any
            # frequency/time samples flagged above
            ssv[:, lbi] = np.dot(NF, ssv[:, lbi])
            ssw[:, lbi] *= f_mask[:, np.newaxis] * t_mask[np.newaxis, :]

        return ss
[docs]
class DelayFilterBase(task.SingleTask):
    """Remove delays less than a given threshold.

    This is performed by projecting the data onto the null space that is orthogonal
    to any mode at low delays.

    Note that for this task to work best the zero entries in the weights dataset
    should factorize in frequency-time for each baseline. A mostly optimal masking
    can be generated using the `draco.analysis.flagging.MaskFreq` task.

    Attributes
    ----------
    delay_cut : float
        Delay value to filter at, in microseconds (it is returned directly by
        `_delay_cut`, which is documented to be in microseconds).
    window : bool
        Apply the window function to the data when applying the filter.
    axis : str
        The main axis to iterate over. The delay cut can be varied for each element
        of this axis. If not set, a suitable default is picked for the container
        type.
    dataset : str
        Apply the delay filter to this dataset. If not set, a suitable default
        is picked for the container type.

    Notes
    -----
    The delay cut applied to each element of the iteration axis is whatever
    `_delay_cut` returns, which for this base class is the constant `delay_cut`.
    Subclasses may override `_delay_cut` for per-element cuts.
    """

    delay_cut = config.Property(proptype=float, default=0.1)
    window = config.Property(proptype=bool, default=False)
    axis = config.Property(proptype=str, default=None)
    dataset = config.Property(proptype=str, default=None)

    def setup(self, telescope: io.TelescopeConvertible):
        """Set the telescope needed to obtain baselines.

        Parameters
        ----------
        telescope
            The telescope object to use
        """
        self.telescope = io.get_telescope(telescope)

    def _delay_cut(self, ss: FreqContainerType, axis: str, ind: int) -> float:
        """Return the delay cut to use for this entry in microseconds.

        Parameters
        ----------
        ss
            The container we are processing.
        axis
            The axis we are looping over.
        ind : int
            The (global) index along that axis.

        Returns
        -------
        float
            The delay cut in microseconds.
        """
        return self.delay_cut

    def process(self, ss: FreqContainerType) -> FreqContainerType:
        """Filter out delays from a SiderealStream or TimeStream.

        Parameters
        ----------
        ss
            Data to filter.

        Returns
        -------
        ss_filt
            Filtered dataset.

        Raises
        ------
        TypeError
            If `ss` is not a `FreqContainer`.
        ValueError
            If no default axis or dataset is known for the container type.
        RuntimeError
            If the null-space filter fails to converge for a baseline.
        """
        if not isinstance(ss, containers.FreqContainer):
            raise TypeError(
                f"Can only process FreqContainer instances. Got {type(ss)}."
            )

        _default_axis = {
            containers.SiderealStream: "stack",
            containers.HybridVisMModes: "m",
            containers.RingMap: "el",
            containers.GridBeam: "theta",
        }
        _default_dataset = {
            containers.SiderealStream: "vis",
            containers.HybridVisMModes: "vis",
            containers.RingMap: "map",
            containers.GridBeam: "beam",
        }

        axis = self.axis
        if self.axis is None:
            for cls, ax in _default_axis.items():
                if isinstance(ss, cls):
                    axis = ax
                    break
            else:
                # Fixed typo in the error message ("know" -> "known")
                raise ValueError(f"No default axis known for {type(ss)} container.")

        dset = self.dataset
        if self.dataset is None:
            for cls, dataset in _default_dataset.items():
                if isinstance(ss, cls):
                    dset = dataset
                    break
            else:
                # Fixed typo in the error message ("know" -> "known")
                raise ValueError(f"No default dataset known for {type(ss)} container.")

        ss.redistribute(axis)

        freq = ss.freq[:]
        bandwidth = np.ptp(freq)

        # Get views of the relevant datasets, but make sure that the weights have the
        # same number of axes as the visibilities (inserting length-1 axes as needed)
        ssv = ss.datasets[dset][:].view(np.ndarray)
        ssw = match_axes(ss.datasets[dset], ss.weight).view(np.ndarray)

        dist_axis_pos = list(ss.datasets[dset].attrs["axis"]).index(axis)
        freq_axis_pos = list(ss.datasets[dset].attrs["axis"]).index("freq")

        # Once we have selected elements of dist_axis the location of freq_axis_pos may
        # be one lower
        sel_freq_axis_pos = (
            freq_axis_pos if freq_axis_pos < dist_axis_pos else freq_axis_pos - 1
        )

        for lbi, bi in ss.datasets[dset][:].enumerate(axis=dist_axis_pos):
            # Extract the part of the array that we are processing, and
            # transpose/reshape to make a 2D array with frequency as axis=0
            vis_local = _take_view(ssv, lbi, dist_axis_pos)
            vis_2D = _move_front(vis_local, sel_freq_axis_pos, vis_local.shape)

            weight_local = _take_view(ssw, lbi, dist_axis_pos)
            weight_2D = _move_front(weight_local, sel_freq_axis_pos, weight_local.shape)

            # In micro seconds
            delay_cut = self._delay_cut(ss, axis, bi)

            # Calculate the number of samples needed to construct the delay null space.
            # `4 * tau_max * bandwidth` is the amount recommended in the DAYENU paper
            # and seems to work well here
            number_cut = int(4.0 * bandwidth * delay_cut + 0.5)

            # Flag frequencies and times (or all other axes) with zero weight. This
            # works much better if the incoming weight can be factorized
            f_samp = (weight_2D > 0.0).sum(axis=1)
            f_mask = (f_samp == f_samp.max()).astype(np.float64)

            t_samp = (weight_2D > 0.0).sum(axis=0)
            t_mask = (t_samp == t_samp.max()).astype(np.float64)

            # This has occasionally failed to converge, catch this and output enough
            # info to debug after the fact
            try:
                NF = null_delay_filter(
                    freq,
                    delay_cut,
                    f_mask,
                    num_delay=number_cut,
                    window=self.window,
                )
            except la.LinAlgError as e:
                raise RuntimeError(
                    f"Failed to converge while processing baseline {bi}"
                ) from e

            # Apply the filter and restore the original axis ordering
            vis_local[:] = _inv_move_front(
                np.dot(NF, vis_2D), sel_freq_axis_pos, vis_local.shape
            )
            weight_local[:] *= _inv_move_front(
                f_mask[:, np.newaxis] * t_mask[np.newaxis, :],
                sel_freq_axis_pos,
                weight_local.shape,
            )

        return ss
# -----------------------------
# Delay Transform Base Classes
# -----------------------------
[docs]
class DelayPowerSpectrumContainerMixin(GeneralInputContainerMixin):
    """Mixin for creating a delay power spectrum output container.

    Attributes
    ----------
    nsamp : int
        Number of samples to compute for each power spectrum. Default is 1.
    save_samples : bool
        When using a sampling-based power spectrum estimator, save out every
        sample in the chain. Otherwise, only save the final power spectrum.
        Default is False.
    save_spectrum_mask : bool
        Save a mask which flags spectra which have significant error, as
        determined by the estimator. Default is False.
    """

    nsamp = config.Property(proptype=int, default=1)
    save_samples = config.Property(proptype=bool, default=False)
    save_spectrum_mask = config.Property(proptype=bool, default=False)

    def _create_output(
        self,
        ss: containers.FreqContainer,
        delays: np.ndarray,
        coord_axes: list[str] | np.ndarray,
    ) -> containers.ContainerBase:
        """Create the output container for the delay power spectrum.

        If `coord_axes` is a list of strings then it is assumed to be a list of the
        names of the folded axes. If it's an array then assume it is the actual axis
        definition.
        """
        # Work out the baseline axis definition: an explicit array is used
        # directly, a single folded axis keeps its index map, and several
        # folded axes collapse to a plain integer count
        if isinstance(coord_axes, np.ndarray):
            bl = coord_axes
        else:
            bl = (
                ss.index_map[coord_axes[0]]
                if len(coord_axes) == 1
                else np.prod([len(ss.index_map[ax]) for ax in coord_axes])
            )

        out = containers.DelaySpectrum(
            baseline=bl,
            delay=delays,
            sample=self.nsamp,
            attrs_from=ss,
        )
        out.redistribute("baseline")
        out.spectrum[:] = 0.0

        # Preserve the index maps of the flattened axes, and record their
        # order so the combined axis can be unpacked when loading the spectrum
        if isinstance(coord_axes, list):
            for name in coord_axes:
                out.create_index_map(name, ss.index_map[name])
            out.attrs["baseline_axes"] = coord_axes

        if self.save_samples:
            out.add_dataset("spectrum_samples")

        # Optional mask recording whether the estimator converged per baseline
        if self.save_spectrum_mask:
            out.add_dataset("spectrum_mask")
            out.datasets["spectrum_mask"][:] = 0

        # Keep the input frequency axis around for later reference
        out.attrs["freq"] = ss.freq

        return out
[docs]
class DelaySpectrumContainerMixin(GeneralInputContainerMixin):
    """Mixin for creating a delay transform output container.

    Attributes
    ----------
    save_spectrum_mask : bool
        Save a mask which flags spectra which have significant error, as
        determined by the estimator. Default is False.
    """

    save_spectrum_mask = config.Property(proptype=bool, default=False)

    def _create_output(
        self, ss: containers.FreqContainer, delays: np.ndarray, coord_axes: list[str]
    ) -> containers.ContainerBase:
        """Create the output container for the delay transform."""
        # Total length of the flattened baseline-like axis
        nbase = np.prod([len(ss.index_map[ax]) for ax in coord_axes])

        out = containers.DelayTransform(
            baseline=nbase,
            sample=ss.index_map[self.sample_axis],
            delay=delays,
            attrs_from=ss,
            weight_boost=self.weight_boost,
        )
        out.redistribute("baseline")
        out.spectrum[:] = 0.0

        # Preserve the index maps of the flattened axes, and record their
        # order so the combined axis can be unpacked when loading the spectrum
        for name in coord_axes:
            out.create_index_map(name, ss.index_map[name])
        out.attrs["baseline_axes"] = coord_axes

        # Optional mask recording flagged samples and baselines
        if self.save_spectrum_mask:
            out.add_dataset("spectrum_mask")
            out.datasets["spectrum_mask"][:] = 0

        # Keep the input frequency axis around for later reference
        out.attrs["freq"] = ss.freq

        return out
# -------------------------------------
# Classes to compute a delay transform
# -------------------------------------
[docs]
class DelaySpectrumBase(DelaySpectrumContainerMixin, DelayTransformBase):
    """Base class for delay spectrum estimation (non-functional)."""

    def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
        """Estimate the delay spectrum via inverse FFT.

        Parameters
        ----------
        data_view : `caput.mpiarray.MPIArray`
            Data to transform.
        weight_view : `caput.mpiarray.MPIArray`
            Weights corresponding to `data_view`.
        out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
            Container for output delay spectrum or power spectrum.
        delays
            The delays to evaluate at.
        channel_ind
            The indices of the available frequency channels in the full set of channels.

        Returns
        -------
        out_cont : `containers.DelaySpectrum`
            Output delay spectrum.
        """
        nbase = out_cont.spectrum.global_shape[0]
        ndelay = len(delays)

        prior = self._get_prior(nbase)

        # Iterate over the combined baseline axis
        for lbi, bi in out_cont.spectrum[:].enumerate(axis=0):
            self.log.debug(f"Estimating the delay transform of baseline {bi}/{nbase}")

            data = data_view.local_array[lbi]
            weight = weight_view.local_array[lbi]

            # Apply data cuts
            t = self._cut_data(data, weight)
            if t is None:
                # Record this sample as bad
                if self.save_spectrum_mask:
                    out_cont.datasets["spectrum_mask"][bi] = 1
                continue
            data, weight, nzf, nzt = t

            # Estimate the delay transform using an estimator
            y_spec = self._estimator(data, weight, prior[lbi], ndelay, channel_ind[nzf])
            out_cont.spectrum[bi, nzt] = y_spec

            # Record missing samples in the spectrum mask
            if self.save_spectrum_mask:
                out_cont.datasets["spectrum_mask"][bi][~nzt] = 1

        return out_cont

    def _get_prior(self, nbase):
        """Get a power spectrum prior.

        Parameters
        ----------
        nbase : int
            Number of baselines

        Returns
        -------
        prior : list | np.ndarray
            Power spectrum prior.
        """
        # Fixed: this previously *returned* a NotImplementedError instance
        # instead of raising it, so a subclass missing an override would
        # silently receive an exception object as its prior
        raise NotImplementedError()

    def _estimator(self, data, weight, S, ndelay, channel_ind):
        """Use an estimator to calculate the delay spectrum.

        Returns
        -------
        dtransform : np.ndarray
            Estimated delay transform.
        """
        raise NotImplementedError()
[docs]
class DelaySpectrumFFT(DelaySpectrumBase):
    """Class to measure the delay spectrum of a general container via ifft."""

    def _get_prior(self, nbase):
        """Return an empty (no-op) prior for every baseline."""
        return [None for _ in range(nbase)]

    def _estimator(self, data, weight, S, ndelay, channel_ind):
        """Compute the delay transform of one data slice with an inverse FFT.

        Returns
        -------
        dtransform : np.ndarray
            Estimated delay transform.
        """
        window = self.window if self.apply_window else None
        spec = delay_spectrum_fft(data, ndelay, window)
        # Put zero delay at the centre of the delay axis
        return np.fft.fftshift(spec, axes=-1)
[docs]
class DelaySpectrumWienerFilter(DelaySpectrumBase):
    """Class to measure delay spectrum of general container via Wiener filtering.

    The spectrum is calculated by applying a Wiener filter to the input frequency
    spectrum, assuming an input model for the delay power spectrum of the signal and
    that the noise power is described by the weights of the input container. See
    https://arxiv.org/abs/2202.01242, Eq. A6 for details.
    """

    def setup(self, dps=None):
        """Set the delay power spectrum to use as the signal covariance.

        Parameters
        ----------
        dps : `containers.DelaySpectrum`
            Delay power spectrum for signal part of Wiener filter.
        """
        self.dps = dps
        super().setup()

    def _get_prior(self, nbase):
        """Use the stored delay power spectrum as the per-baseline signal prior."""
        return self.dps.spectrum[:].local_array

    def _estimator(self, data, weight, S, ndelay, channel_ind):
        """Apply a Wiener filter to obtain the delay transform of a data slice.

        Returns
        -------
        dtransform : np.ndarray
            Estimated delay transform.
        """
        window = self.window if self.apply_window else None
        filtered = delay_spectrum_wiener_filter(
            np.fft.fftshift(S),
            data,
            ndelay,
            weight,
            window=window,
            fsel=channel_ind,
            complex_timedomain=self.complex_timedomain,
        )
        # Put zero delay at the centre of the delay axis
        return np.fft.fftshift(filtered, axes=-1)
[docs]
class DelaySpectrumWienerFilterIteratePS(DelaySpectrumWienerFilter):
    """Class to estimate the delay spectrum using Wiener filtering.

    This class extends `DelaySpectrumWienerFilter` by allowing the
    delay power spectrum (`dps`) to be updated with each call to `process`
    instead of being fixed at `setup`. The updated `dps` is used to apply
    the Wiener filter to the input frequency spectrum.
    """

    def process(self, ss, dps):
        """Estimate the delay spectrum.

        Parameters
        ----------
        ss : `containers.FreqContainer`
            Data to transform. Must have a frequency axis.
        dps : `containers.DelaySpectrum`
            Delay power spectrum for signal part of Wiener filter.

        Returns
        -------
        out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
            Output delay spectrum or delay power spectrum.
        """
        # Swap in the latest power spectrum before delegating to the parent,
        # which reads it back via `_get_prior`
        self.dps = dps
        return super().process(ss)
# -------------------------------------------------------------
# Class to compute a delay power spectrum from a delay spectrum
# -------------------------------------------------------------
[docs]
class DelaySpectrumToPowerSpectrum(task.SingleTask):
    """Compute a delay power spectrum from a delay spectrum."""

    def process(self, dspec: containers.DelayTransform) -> containers.DelaySpectrum:
        """Get the delay power spectrum from a delay spectrum.

        Parameters
        ----------
        dspec
            Delay spectrum container.

        Returns
        -------
        pspec
            Delay power spectrum container.
        """
        dspec.redistribute("baseline")

        # Make the power spectrum container
        pspec = containers.DelaySpectrum(attrs_from=dspec, axes_from=dspec)
        pspec.redistribute("baseline")

        # If a spectrum mask exists, use it as the `where` selection for the
        # variance and create a matching mask on the output container
        if "spectrum_mask" in dspec.datasets:
            mask = dspec.datasets["spectrum_mask"][:].local_array
            # assumes the mask dataset is boolean-typed so `~` is a logical
            # NOT — TODO confirm against the container definition
            where = ~mask[..., np.newaxis]
            pspec.add_dataset("spectrum_mask")
            pspec.datasets["spectrum_mask"][:] = 0
        else:
            where = None

        ps = pspec.spectrum[:].local_array
        ds = dspec.spectrum[:].local_array

        # The power spectrum is the variance over the sample axis
        ps[:] = np.var(ds, axis=1, where=where)

        # Check for NaNs and mask them. This happens if an entire slice
        # along the variance axis is masked, and should correspond
        # to bad baselines. Don't bother if no mask was used.
        if where is not None:
            bad = np.isnan(ps)
            ps[bad] = 0.0
            pspec.datasets["spectrum_mask"][:].local_array[:] = np.any(bad, axis=-1)

        return pspec
# ---------------------------------------------------
# Classes to directly compute a delay power spectrum
# ---------------------------------------------------
[docs]
class DelayPowerSpectrumBase(DelayPowerSpectrumContainerMixin, DelayTransformBase):
    """Base class for delay power spectrum estimation (non-functional)."""

    def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
        """Estimate the delay spectrum or power spectrum.

        Parameters
        ----------
        data_view : `caput.mpiarray.MPIArray`
            Data to transform.
        weight_view : `caput.mpiarray.MPIArray`
            Weights corresponding to `data_view`.
        out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
            Container for output delay spectrum or power spectrum.
        delays
            The delays to evaluate at.
        channel_ind
            The indices of the available frequency channels in the full set of channels.

        Returns
        -------
        out_cont : `containers.DelayTransform` or `containers.DelaySpectrum`
            Output delay spectrum or delay power spectrum.
        """
        nbase = out_cont.spectrum.global_shape[0]
        ndelay = len(delays)

        # Set initial conditions for delay power spectrum
        prior = self._get_prior(nbase, ndelay, delays.dtype)

        # Iterate over all baselines and use the Gibbs sampler to estimate the spectrum
        for lbi, bi in out_cont.spectrum[:].enumerate(axis=0):
            self.log.debug(f"Delay transforming baseline {bi}/{nbase}")

            # Get the local selections
            data = data_view.local_array[lbi]
            weight = weight_view.local_array[lbi]

            # Apply the cuts to the data
            t = self._cut_data(data, weight)
            if t is None:
                # Record this sample as bad
                if self.save_spectrum_mask:
                    out_cont.datasets["spectrum_mask"][bi] = 1
                continue
            data, weight, nzf, _ = t

            # Delegate the actual estimation to the subclass implementation
            spec, samples, success = self._estimator(
                data, weight, prior[lbi], ndelay, channel_ind[nzf]
            )

            # Save out the resulting spectrum, samples, and mask
            out_cont.spectrum[bi] = spec

            if self.save_spectrum_mask and not success:
                out_cont.datasets["spectrum_mask"][bi] = 1

            if self.save_samples:
                # Zero first, then fill from the end in case the chain is
                # shorter than the allocated sample axis
                nsamp = len(samples)
                out_cont.datasets["spectrum_samples"][:, bi] = 0.0
                out_cont.datasets["spectrum_samples"][-nsamp:, bi] = np.array(samples)

        if self.save_spectrum_mask:
            # Record number of converged baselines for debugging info.
            n_conv = nbase - out_cont.datasets["spectrum_mask"][:].sum().allreduce()
            self.log.debug(f"{n_conv}/{nbase} unflagged baselines.")

        return out_cont

    def _get_prior(self, nbase, ndelay, dtype):
        """Get an initial estimate of the power spectrum.

        Parameters
        ----------
        nbase : int
            Number of baselines.
        ndelay : int
            Number of delay samples.
        dtype : type | np.dtype | str
            Datatype for the sample.
        """
        raise NotImplementedError()

    def _estimator(self, data, weight, S, ndelay, channel_ind):
        """Use an estimator to calculate the power spectrum of a data slice.

        Returns
        -------
        spec : np.ndarray
            Estimated power spectrum
        samples : list[np.ndarray]
            Chain of samples. This can be length-one depending
            on the estimator
        success : bool
            Whether or not the estimator thinks the
            result is reasonable.
        """
        raise NotImplementedError()
[docs]
class DelayPowerSpectrumGibbs(DelayPowerSpectrumBase, random.RandomTask):
    """Use a Gibbs sampler to estimate the delay power spectrum.

    The spectrum returned is the median of the final half of the
    samples calculated.

    Attributes
    ----------
    initial_amplitude : float, optional
        The Gibbs sampler will be initialized with a flat power spectrum with
        this amplitude. Unused if maxpost=True (flat spectrum is a bad initial
        guess for the max-likelihood estimator). Default: 10.
    """

    initial_amplitude = config.Property(proptype=float, default=10.0)

    def _get_prior(self, nbase, ndelay, dtype):
        """Start every baseline from a flat spectrum at the configured amplitude."""
        return np.full((nbase, ndelay), self.initial_amplitude, dtype=dtype)

    def _estimator(self, data, weight, S, ndelay, channel_ind):
        """Run the Gibbs sampler and summarise the chain.

        Returns the median over the second half of the chain (dropping
        burn-in), shifted so zero delay sits at the centre.
        """
        window = self.window if self.apply_window else None
        samples = delay_power_spectrum_gibbs(
            data,
            ndelay,
            weight,
            S,
            window=window,
            fsel=channel_ind,
            niter=self.nsamp,
            rng=self.rng,
            complex_timedomain=self.complex_timedomain,
        )

        # NOTE: when nsamp < 2 the slice below keeps the whole chain
        nkeep = self.nsamp // 2
        spec = np.fft.fftshift(np.median(samples[-nkeep:], axis=0))

        return spec, samples, True
[docs]
class DelayPowerSpectrumNRML(DelayPowerSpectrumBase):
    """Use a NRML method to estimate the delay power spectrum.

    Attributes
    ----------
    maxpost_tol : float, optional
        The convergence tolerance used by scipy.optimize.minimize
        in the maximum likelihood estimator.
    """

    maxpost_tol = config.Property(proptype=float, default=1e-3)

    def _get_prior(self, nbase, ndelay, dtype):
        """Return an empty (no-op) prior for every baseline."""
        return [None for _ in range(nbase)]

    def _estimator(self, data, weight, S, ndelay, channel_ind):
        """Run the max-posterior optimizer and return its final iterate."""
        window = self.window if self.apply_window else None
        samples, success = delay_power_spectrum_maxpost(
            data,
            ndelay,
            weight,
            S,
            window=window,
            fsel=channel_ind,
            maxiter=self.nsamp,
            tol=self.maxpost_tol,
        )
        # The last entry in the iterate chain is the estimate; shift zero
        # delay to the centre of the axis
        spec = np.fft.fftshift(samples[-1])

        return spec, samples, success
[docs]
class DelayCrossPowerSpectrumEstimator(DelayPowerSpectrumGibbs, random.RandomTask):
    """A delay cross power spectrum estimator.

    This takes multiple compatible `FreqContainer`s as inputs and will return a
    `DelayCrossSpectrum` container with the full pair-wise cross power spectrum.
    """

    def _prepare_inputs(
        self, sslist: list[containers.FreqContainer]
    ) -> tuple[list[mpiarray.MPIArray], list[mpiarray.MPIArray], list[str]]:
        """Prepare all input containers for the transform.

        Parameters
        ----------
        sslist
            The input containers. All must share the same frequency axis and
            the same coordinate axes.

        Returns
        -------
        data_views, weight_views, coord_axes
            Per-container data and weight views, plus the common list of
            flattened axes.

        Raises
        ------
        ValueError
            If no containers are given, or if they disagree in frequencies
            or axes.
        """
        if len(sslist) == 0:
            raise ValueError("No datasets passed.")

        freq_ref = sslist[0].freq

        data_views = []
        weight_views = []
        coord_axes = None

        for ss in sslist:
            ss.redistribute("freq")

            # Fixed: use `.any()` — the previous `.all()` only rejected inputs
            # where *every* channel differed, silently accepting partially
            # mismatched frequency axes
            if (ss.freq != freq_ref).any():
                raise ValueError("Input containers must have the same frequencies.")

            # Fixed: `super()` already binds `self`; passing it explicitly
            # shifted all arguments by one position
            dv, wv, ca = super()._prepare_inputs(ss)

            if coord_axes is not None and coord_axes != ca:
                raise ValueError("Different axes found for the input containers.")

            data_views.append(dv)
            weight_views.append(wv)
            coord_axes = ca

        return data_views, weight_views, coord_axes

    def _create_output(
        self,
        ss: list[containers.FreqContainer],
        delays: np.ndarray,
        coord_axes: list[str],
    ) -> containers.ContainerBase:
        """Create the output container for the delay power spectrum.

        If `coord_axes` is a list of strings then it is assumed to be a list of the
        names of the folded axes. If it's an array then assume it is the actual axis
        definition.
        """
        ssref = ss[0]
        ndata = len(ss)

        # If only one axis is being collapsed, use that as the baseline axis
        # definition, otherwise just use integer indices
        if len(coord_axes) == 1:
            bl = ssref.index_map[coord_axes[0]]
        else:
            bl = np.prod([len(ssref.index_map[ax]) for ax in coord_axes])

        # Initialise the spectrum container
        delay_spec = containers.DelayCrossSpectrum(
            baseline=bl,
            dataset=ndata,
            delay=delays,
            sample=self.nsamp,
            attrs_from=ssref,
        )
        delay_spec.redistribute("baseline")
        delay_spec.spectrum[:] = 0.0

        # Copy the index maps for all the flattened axes into the output container,
        # and write out their order into an attribute so we can reconstruct this
        # easily when loading in the spectrum
        if isinstance(coord_axes, list):
            for ax in coord_axes:
                delay_spec.create_index_map(ax, ssref.index_map[ax])
            delay_spec.attrs["baseline_axes"] = coord_axes

        if self.save_samples:
            delay_spec.add_dataset("spectrum_samples")

        # Save the frequency axis of the input data as an attribute in the output
        # container
        delay_spec.attrs["freq"] = ssref.freq

        return delay_spec

    def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
        """Estimate the pair-wise cross power spectra with a Gibbs sampler.

        Parameters
        ----------
        data_view : list of `caput.mpiarray.MPIArray`
            Per-dataset data to transform.
        weight_view : list of `caput.mpiarray.MPIArray`
            Weights corresponding to `data_view`.
        out_cont : `containers.DelayCrossSpectrum`
            Container for the output cross power spectra.
        delays
            The delays to evaluate at.
        channel_ind
            The indices of the available frequency channels in the full set.

        Returns
        -------
        out_cont : `containers.DelayCrossSpectrum`
            Output delay cross power spectrum.
        """
        ndata = len(data_view)
        ndelay = len(delays)
        nbase = out_cont.spectrum.shape[-2]

        initial_S = self._get_prior(nbase, ndelay, delays.dtype)

        if initial_S.ndim == 2:
            # Expand the sample shape to match the number of datasets
            initial_S = (
                np.identity(ndata)[np.newaxis, ..., np.newaxis]
                * initial_S[:, np.newaxis, np.newaxis]
            )
        elif (initial_S.ndim != 4) or (initial_S.shape[1] != ndata):
            raise ValueError(
                f"Expected an initial sample with dimension 4 and {ndata} datasets. "
                f"Got sample with dimension {initial_S.ndim} and shape {initial_S.shape}."
            )

        # Initialize the random number generator we'll use
        rng = self.rng

        # Iterate over all baselines and use the Gibbs sampler to estimate the spectrum
        for lbi, bi in out_cont.spectrum[:].enumerate(axis=-2):
            self.log.debug(f"Delay transforming baseline {bi}/{nbase}")

            # Get the local selections for all datasets and combine into a single array
            data = np.array([d.local_array[lbi] for d in data_view])
            weight = np.array([w.local_array[lbi] for w in weight_view])

            # Apply the cuts to the data
            t = self._cut_data(data, weight)
            if t is None:
                continue
            data, weight, nzf, _ = t

            spec = delay_spectrum_gibbs_cross(
                data,
                ndelay,
                weight,
                initial_S[lbi],
                window=self.window if self.apply_window else None,
                fsel=channel_ind[nzf],
                niter=self.nsamp,
                rng=rng,
            )

            # Take an average over the last half of the delay spectrum samples
            # (presuming that removes the burn-in)
            spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0)
            out_cont.spectrum[..., bi, :] = np.fft.fftshift(spec_av)

            if self.save_samples:
                out_cont.datasets["spectrum_samples"][..., bi, :] = spec

        return out_cont
# Raise a deprecation warning
[docs]
class DelayPowerSpectrumStokesIEstimator(DelayPowerSpectrumGibbs):
    """Deprecated."""

    def setup(self, requires=None):
        """Always raise: this task has been removed from the pipeline."""
        msg = (
            "`DelayPowerSpectrumStokesIEstimator` is deprecated. "
            "Use `draco.transform.StokesI` to generate Stokes I "
            "visibilities, then use `DelayPowerSpectrumGibbs` "
            "or `DelayPowerSpectrumNRML`."
        )
        raise DeprecationWarning(msg)
[docs]
class DelayPowerSpectrumGeneralEstimator(DelayPowerSpectrumGibbs):
    """Deprecated."""

    def setup(self, requires=None):
        """Always raise: this task has been removed from the pipeline."""
        msg = (
            "`DelayPowerSpectrumGeneralEstimator` is deprecated. "
            "Use `DelayPowerSpectrumGibbs` or `DelayPowerSpectrumNRML`."
        )
        raise DeprecationWarning(msg)
# -------------------------------------
# Functions to create Fourier matrices
# -------------------------------------
[docs]
def fourier_matrix_r2c(N, fsel=None):
    """Generate a Fourier matrix to represent a real to complex FFT.

    Parameters
    ----------
    N : integer
        Length of timestream that we are transforming to. Must be even.
    fsel : array_like, optional
        Indexes of the frequency channels to include in the transformation
        matrix. By default, assume all channels.

    Returns
    -------
    Fr : np.ndarray
        An array performing the Fourier transform from a real time series to
        frequencies packed as alternating real and imaginary elements,
    """
    freqs = np.arange(N // 2 + 1) if fsel is None else np.array(fsel)

    # Phase angle for every (channel, time) pair
    arg = 2 * np.pi * freqs[:, np.newaxis] * np.arange(N)[np.newaxis, :] / N

    # Interleave: even rows carry the real part, odd rows the imaginary part
    Fr = np.zeros((2 * freqs.shape[0], N), dtype=np.float64)
    Fr[0::2] = np.cos(arg)
    Fr[1::2] = -np.sin(arg)

    return Fr
[docs]
def fourier_matrix_c2r(N, fsel=None):
    """Generate a Fourier matrix to represent a complex to real FFT.

    Parameters
    ----------
    N : integer
        Length of timestream that we are transforming to. Must be even.
    fsel : array_like, optional
        Indexes of the frequency channels to include in the transformation
        matrix. By default, assume all channels.

    Returns
    -------
    Fr : np.ndarray
        An array performing the Fourier transform from frequencies packed as
        alternating real and imaginary elements, to the real time series.
    """
    freqs = (np.arange(N // 2 + 1) if fsel is None else np.array(fsel))[np.newaxis, :]

    # The DC and Nyquist bins occur once in the half-spectrum packing; every
    # other bin stands in for a conjugate pair and needs a factor of two
    scale = np.where((freqs == 0) | (freqs == N // 2), 1.0, 2.0) / N

    arg = 2 * np.pi * np.arange(N)[:, np.newaxis] * freqs / N

    # Interleave: even columns multiply real parts, odd columns imaginary parts
    Fr = np.zeros((N, 2 * freqs.shape[1]), dtype=np.float64)
    Fr[:, 0::2] = np.cos(arg) * scale
    Fr[:, 1::2] = -np.sin(arg) * scale

    return Fr
[docs]
def fourier_matrix_c2c(N, fsel=None):
    """Generate a Fourier matrix to represent a complex to complex FFT.

    These Fourier conventions match `numpy.fft.fft()`.

    Parameters
    ----------
    N : integer
        Length of timestream that we are transforming to.
    fsel : array_like, optional
        Indices of the frequency channels to include in the transformation
        matrix. By default, assume all channels.

    Returns
    -------
    F : np.ndarray
        An array performing the Fourier transform from a complex time series to
        frequencies, with both input and output packed as alternating real and
        imaginary elements.
    """
    freqs = (np.arange(N) if fsel is None else np.array(fsel))[:, np.newaxis]

    # Phase angle for every (channel, time) pair
    arg = 2 * np.pi * np.arange(N)[np.newaxis, :] * freqs / N
    c = np.cos(arg)
    s = np.sin(arg)

    # Both axes interleave real (even) and imaginary (odd) components:
    # Re(out) = Re(in) cos + Im(in) sin, Im(out) = Im(in) cos - Re(in) sin
    F = np.zeros((2 * freqs.shape[0], 2 * N), dtype=np.float64)
    F[0::2, 0::2] = c
    F[0::2, 1::2] = s
    F[1::2, 0::2] = -s
    F[1::2, 1::2] = c

    return F
[docs]
def fourier_matrix(N: int, fsel: np.ndarray | None = None) -> np.ndarray:
"""Generate a Fourier matrix to represent a real to complex FFT.
Parameters
----------
N : integer
Length of timestream that we are transforming to. Must be even.
fsel : array_like, optional
Indexes of the frequency channels to include in the transformation
matrix. By default, assume all channels.
Returns
-------
Fr : np.ndarray
An array performing the Fourier transform from a real time series to
frequencies packed as alternating real and imaginary elements,
"""
if fsel is None:
fa = np.arange(N)
else:
fa = np.array(fsel)
fa = fa[:, np.newaxis]
ta = np.arange(N)[np.newaxis, :]
return np.exp(-2.0j * np.pi * ta * fa / N)
def _complex_to_alternating_real(array):
"""View complex numbers as an array with alternating real and imaginary components.
Parameters
----------
array : array_like
Input array of complex numbers.
Returns
-------
out : array_like
Output array of alternating real and imaginary components. These components are
expanded along the last axis, such that if `array` has `N` complex elements in
its last axis, `out` will have `2N` real elements.
"""
return array.astype(np.complex128, order="C").view(np.float64)
def _alternating_real_to_complex(array):
"""View real numbers as complex, interpreted as alternating real and imag. components.
Parameters
----------
array : array_like
Input array of real numbers. Last axis must have even number of elements.
Returns
-------
out : array_like
Output array of complex numbers, derived from compressing the last axis (if
`array` has `N` real elements in the last axis, `out` will have `N/2` complex
elements).
"""
return array.astype(np.float64, order="C").view(np.complex128)
# ----------------------------------------------------------------
# Implementation of delay transform and power spectrum algorithms
# ----------------------------------------------------------------
def _compute_delay_spectrum_inputs(data, N, Ni, fsel, window, complex_timedomain):
    """Compute quantities needed for Gibbs sampling and/or Wiener filtering.

    These quantities are needed by both :func:`delay_power_spectrum_gibbs` and
    :func:`delay_spectrum_wiener_filter`, so we compute them in this separate routine.
    """
    total_freq = N if complex_timedomain else N // 2 + 1

    if fsel is None:
        fsel = np.arange(total_freq)

    # Pick the Fourier matrix matching the assumed time domain
    if complex_timedomain:
        F = fourier_matrix_c2c(N, fsel)
    else:
        F = fourier_matrix_r2c(N, fsel)

    # Unpack into interleaved real/imaginary rows, with frequency axis first
    data = _complex_to_alternating_real(data).T.copy()

    if window is not None:
        # Evaluate the apodisation at the selected channel positions, then
        # duplicate each weight for the real and imaginary rows
        w = tools.window_generalised(fsel * 1.0 / total_freq, window=window)
        w = np.repeat(w, 2)

        # Apply to both the projection matrix and the data
        F *= w[:, np.newaxis]
        data *= w[:, np.newaxis]

    # Identify channels that are strictly real when the delay spectrum is
    # assumed real (zero and Nyquist); none are strictly real otherwise
    if complex_timedomain:
        is_real_freq = np.zeros_like(fsel).astype(bool)
    else:
        is_real_freq = (fsel == 0) | (fsel == N // 2)

    # Inverse-noise for the interleaved real/imaginary components: complex
    # channels split their variance between the two parts (so each inverse is
    # doubled), while strictly-real channels have no imaginary part at all
    Ni_r = np.zeros(2 * Ni.shape[0])
    Ni_r[0::2] = np.where(is_real_freq, Ni, Ni * 2)
    Ni_r[1::2] = np.where(is_real_freq, 0.0, Ni * 2)

    # Noise-weighted transpose of the Fourier matrix, and its Gram matrix
    # (both reused repeatedly by the callers)
    FTNih = F.T * Ni_r[np.newaxis, :] ** 0.5
    FTNiF = np.dot(FTNih, FTNih.T)

    # Pre-whiten the data once to avoid doing it repeatedly
    data = data * Ni_r[:, np.newaxis] ** 0.5

    return data, FTNih, FTNiF
[docs]
def delay_power_spectrum_gibbs(
    data,
    N,
    Ni,
    initial_S,
    window="nuttall",
    fsel=None,
    niter=20,
    rng=None,
    complex_timedomain=False,
):
    """Estimate the delay power spectrum by Gibbs sampling.

    This routine estimates the spectrum at the `N` delay samples conjugate to
    an input frequency spectrum with ``N/2 + 1`` channels (if the delay spectrum is
    assumed real) or `N` channels (if the delay spectrum is assumed complex).
    A subset of these channels can be specified using the `fsel` argument.

    Parameters
    ----------
    data : np.ndarray[:, freq]
        Data to estimate the delay spectrum of.
    N : int
        The length of the output delay spectrum. There are assumed to be `N/2 + 1`
        total frequency channels if assuming a real delay spectrum, or `N` channels
        for a complex delay spectrum.
    Ni : np.ndarray[freq]
        Inverse noise variance.
    initial_S : np.ndarray[delay]
        The initial delay power spectrum guess.
    window : one of {'nuttall', 'blackman_nuttall', 'blackman_harris', None}, optional
        Apply an apodisation function. Default: 'nuttall'.
    fsel : np.ndarray[freq], optional
        Indices of channels that we have data at. By default assume all channels.
    niter : int, optional
        Number of Gibbs samples to generate.
    rng : np.random.Generator, optional
        A generator to use to produce the random samples.
    complex_timedomain : bool, optional
        If True, assume input data arose from a complex timestream. If False, assume
        input data arose from a real timestream, such that the first and last frequency
        channels have purely real values. Default: False.

    Returns
    -------
    spec : list
        List of spectrum samples.
    """
    # Get reference to RNG
    if rng is None:
        rng = random.default_rng()

    spec = []

    # BUGFIX: resolve the default channel selection *in this scope*. The
    # helper below applies the same default internally, but `len(fsel)` is
    # used further down to choose the sampling method, and previously raised
    # a TypeError whenever the default `fsel=None` was used.
    if fsel is None:
        fsel = np.arange(N if complex_timedomain else N // 2 + 1)

    # Pre-whiten and apply frequency window to data, and compute F^dagger N^{-1/2}
    # and F^dagger N^{-1} F
    data, FTNih, FTNiF = _compute_delay_spectrum_inputs(
        data, N, Ni, fsel, window, complex_timedomain
    )

    # Set the initial guess for the delay power spectrum.
    S_samp = initial_S

    def _draw_signal_sample_f(S):
        # Draw a random sample of the signal (delay spectrum) assuming a Gaussian model
        # with a given delay power spectrum `S`. Do this using the perturbed Wiener
        # filter approach

        # This method is fastest if the number of frequencies is larger than the number
        # of delays we are solving for. Typically this isn't true, so we probably want
        # `_draw_signal_sample_t`

        # Construct the Wiener covariance
        if complex_timedomain:
            # If delay spectrum is complex, extend S to correspond to the individual
            # real and imaginary components of the delay spectrum, each of which have
            # power spectrum equal to 0.5 times the power spectrum of the complex
            # delay spectrum, if the statistics are circularly symmetric
            S = 0.5 * np.repeat(S, 2)
        Si = 1.0 * tools.invert_no_zero(S)
        Ci = np.diag(Si) + FTNiF

        # Draw random vectors that form the perturbations
        if complex_timedomain:
            # If delay spectrum is complex, draw for real and imaginary components
            # separately
            w1 = rng.standard_normal((2 * N, data.shape[1]))
        else:
            w1 = rng.standard_normal((N, data.shape[1]))
        w2 = rng.standard_normal(data.shape)

        # Construct the random signal sample by forming a perturbed vector and
        # then doing a matrix solve
        y = np.dot(FTNih, data + w2) + Si[:, np.newaxis] ** 0.5 * w1

        return la.solve(Ci, y, assume_a="pos")

    def _draw_signal_sample_t(S):
        # This method is fastest if the number of delays is larger than the number of
        # frequencies. This is usually the regime we are in.

        # Construct various dependent matrices
        if complex_timedomain:
            # If delay spectrum is complex, extend S to correspond to the individual
            # real and imaginary components of the delay spectrum, each of which have
            # power spectrum equal to 0.5 times the power spectrum of the complex
            # delay spectrum, if the statistics are circularly symmetric
            S = 0.5 * np.repeat(S, 2)
        Sh = S**0.5
        Rt = Sh[:, np.newaxis] * FTNih
        R = Rt.T.conj()

        # Draw random vectors that form the perturbations
        if complex_timedomain:
            # If delay spectrum is complex, draw for real and imaginary components
            # separately
            w1 = rng.standard_normal((2 * N, data.shape[1]))
        else:
            w1 = rng.standard_normal((N, data.shape[1]))
        w2 = rng.standard_normal(data.shape)

        # Perform the solve step (rather than explicitly using the inverse)
        y = data + w2 - np.dot(R, w1)
        Ci = np.identity(2 * Ni.shape[0]) + np.dot(R, Rt)
        x = la.solve(Ci, y, assume_a="pos")

        return Sh[:, np.newaxis] * (np.dot(Rt, x) + w1)

    def _draw_ps_sample(d):
        # Draw a random delay power spectrum sample assuming the signal is Gaussian and
        # we have a flat prior on the power spectrum.
        # This means drawing from a inverse chi^2.
        if complex_timedomain:
            # If delay spectrum is complex, combine real and imaginary components
            # stored in d, such that variance below is variance of complex spectrum
            d = d[0::2] + 1.0j * d[1::2]
        S_hat = d.var(axis=1)

        df = d.shape[1]
        chi2 = rng.chisquare(df, size=d.shape[0])

        return S_hat * df / chi2

    # Select the method to use for the signal sample based on how many frequencies
    # versus delays there are
    _draw_signal_sample = (
        _draw_signal_sample_f if (len(fsel) > 0.25 * N) else _draw_signal_sample_t
    )

    # Perform the Gibbs sampling iteration for a given number of loops and
    # return the power spectrum output of them.
    for ii in range(niter):
        d_samp = _draw_signal_sample(S_samp)
        S_samp = _draw_ps_sample(d_samp)

        spec.append(S_samp)

    return spec
[docs]
def delay_spectrum_gibbs_cross(
    data: np.ndarray,
    N: int,
    Ni: np.ndarray,
    initial_S: np.ndarray,
    window: str = "nuttall",
    fsel: np.ndarray | None = None,
    niter: int = 20,
    rng: np.random.Generator | None = None,
) -> list[np.ndarray]:
    """Estimate the delay power spectrum by Gibbs sampling.

    This routine estimates the spectrum at the `N` delay samples conjugate to
    an input frequency spectrum with ``N/2 + 1`` channels (if the delay spectrum is
    assumed real) or `N` channels (if the delay spectrum is assumed complex).
    A subset of these channels can be specified using the `fsel` argument.

    Parameters
    ----------
    data
        A 3D array of [dataset, sample, freq]. The delay cross-power spectrum of these
        will be calculated.
    N
        The length of the output delay spectrum. There are assumed to be `N/2 + 1`
        total frequency channels if assuming a real delay spectrum, or `N` channels
        for a complex delay spectrum.
    Ni
        Inverse noise variance as a 3D [dataset, sample, freq] array.
    initial_S
        The initial delay cross-power spectrum guess. A 3D array of [data1, data2,
        delay].
    window : one of {'nuttall', 'blackman_nuttall', 'blackman_harris', None}, optional
        Apply an apodisation function. Default: 'nuttall'.
    fsel
        Indices of channels that we have data at. By default assume all channels.
    niter
        Number of Gibbs samples to generate.
    rng
        A generator to use to produce the random samples.

    Returns
    -------
    spec : list
        List of cross-power spectrum samples.
    """
    # Get reference to RNG
    if rng is None:
        rng = random.default_rng()

    spec = []

    nd, nsamp, Nf = data.shape

    # Default to all channels; otherwise the selection must match the number
    # of frequency channels actually present in the data
    if fsel is None:
        fsel = np.arange(Nf)
    elif len(fsel) != Nf:
        raise ValueError(
            "Length of frequency selection must match frequencies passed. "
            f"{len(fsel)} != {data.shape[-1]}"
        )

    # Construct the Fourier matrix
    F = fourier_matrix(N, fsel)

    if nd == 0:
        raise ValueError("Need at least one set of data")

    # We want the sample axis to be last
    data = data.transpose(0, 2, 1)

    # Window the frequency data
    if window is not None:
        # Construct the window function
        x = fsel * 1.0 / N
        w = tools.window_generalised(x, window=window)

        # Apply to the projection matrix and the data
        F *= w[:, np.newaxis]
        data *= w[:, np.newaxis]

    # Create the transpose of the Fourier matrix weighted by the noise
    # (this is used multiple times)
    # This is packed as a single freq -> delay projection per dataset
    FTNih = F.T[np.newaxis, :, :] * Ni[:, np.newaxis, :] ** 0.5

    # This should be an array for each dataset i of F_i^H N_i^{-1} F_i
    # (block-diagonal in the dataset indices)
    FTNiF = np.zeros((nd, N, nd, N), dtype=np.complex128)
    for ii in range(nd):
        FTNiF[ii, :, ii] = FTNih[ii] @ FTNih[ii].T.conj()

    # Pre-whiten the data to save doing it repeatedly
    data *= Ni[:, :, np.newaxis] ** 0.5

    # Set the initial guess for the delay power spectrum.
    S_samp = initial_S

    def _draw_signal_sample_f(S):
        # Draw a random sample of the signal (delay spectrum) assuming a Gaussian model
        # with a given delay power spectrum `S`. Do this using the perturbed Wiener
        # filter approach

        # This method is fastest if the number of frequencies is larger than the number
        # of delays we are solving for. Typically this isn't true, so we probably want
        # `_draw_signal_sample_t`

        # Invert the [data1, data2] covariance at each delay, and take its
        # upper-triangular Cholesky factor for generating the perturbations
        Si = np.empty_like(S)
        Sh = np.empty((N, nd, nd), dtype=S.dtype)
        for ii in range(N):
            inv = la.inv(S[:, :, ii])
            Si[:, :, ii] = inv
            Sh[ii, :, :] = la.cholesky(S[:, :, ii], lower=False)

        # Wiener inverse covariance: add S^-1 onto the (block) diagonal of
        # F^H N^-1 F, one delay-diagonal per dataset pair
        Ci = FTNiF.copy()
        for ii in range(nd):
            for jj in range(nd):
                Ci[ii, :, jj] += np.diag(Si[ii, jj])

        w1 = random.standard_complex_normal((N, nd, nsamp), rng=rng)
        w2 = random.standard_complex_normal(data.shape, rng=rng)

        # Construct the random signal sample by forming a perturbed vector and
        # then doing a matrix solve
        y = FTNih @ (data + w2)

        # Add S^{-1/2} w1 via a triangular solve against the Cholesky factor
        for ii in range(N):
            w1s = la.solve_triangular(
                Sh[ii],
                w1[ii],
                overwrite_b=True,
                lower=False,
                check_finite=False,
            )
            y[:, ii] += w1s

        # NOTE: Other combinations that you might think would work don't appear to
        # be stable. Don't try these:
        # y[:, ii] += Si[:, :, ii] @ Sh[:, :, ii] @ w1[:, ii]
        # y[:, ii] += Shi[:, :, ii] @ w1[:, ii]

        # Solve the full (nd * N) system via a Cholesky factorisation
        cf = la.cho_factor(
            Ci.reshape(nd * N, nd * N),
            overwrite_a=True,
            check_finite=False,
        )

        return la.cho_solve(
            cf,
            y.reshape(nd * N, nsamp),
            overwrite_b=True,
            check_finite=False,
        ).reshape(nd, N, nsamp)

    def _draw_signal_sample_t(S):
        # This method is fastest if the number of delays is larger than the number of
        # frequencies. This is usually the regime we are in.
        raise NotImplementedError("Drawing samples in the time basis not yet written.")

    def _draw_ps_sample(d):
        # Draw a random delay power spectrum sample assuming the signal is Gaussian and
        # we have a flat prior on the power spectrum.
        # This means drawing from a inverse chi^2.

        # Estimate the sample covariance
        S = np.empty((nd, nd, N), dtype=np.complex128)
        for ii in range(N):
            S[:, :, ii] = np.cov(d[:, ii], bias=True)

        # Then in place draw a sample of the true covariance from the posterior which
        # is an inverse Wishart
        for ii in range(N):
            Si = la.inv(S[:, :, ii])
            Si_samp = random.complex_wishart(Si, nsamp, rng=rng) / nsamp
            S[:, :, ii] = la.inv(Si_samp)

        return S

    # Select the method to use for the signal sample based on how many frequencies
    # versus delays there are. At the moment only the _f method is implemented.
    _draw_signal_sample = _draw_signal_sample_f

    # Perform the Gibbs sampling iteration for a given number of loops and
    # return the power spectrum output of them.
    try:
        for ii in range(niter):
            d_samp = _draw_signal_sample(S_samp)
            S_samp = _draw_ps_sample(d_samp)

            spec.append(S_samp)
    except la.LinAlgError as e:
        # A singular covariance somewhere in the chain; surface it with the
        # samples collected so far discarded
        raise RuntimeError("Exiting earlier as singular") from e

    return spec
[docs]
def delay_spectrum_fft(data, N, window="nuttall"):
    """Estimate the delay transform from an input frequency spectrum by IFFT.

    This routine makes no attempt to account for data masking, and only
    supports complex to complex fft. The input array is not modified.

    Parameters
    ----------
    data : np.ndarray[nsample, freq]
        Data to estimate the delay spectrum of.
    N : int
        The length of the output delay spectrum. There are assumed to be `N/2 + 1`
        total frequency channels if assuming a real delay spectrum, or `N` channels
        for a complex delay spectrum.
    window : one of {'nuttall', 'blackman_nuttall', 'blackman_harris', None}, optional
        Apply an apodisation function. Default: 'nuttall'.

    Returns
    -------
    y_spec : np.ndarray[nsample, ndelay]
        Delay spectrum for each element of the `sample` axis.
    """
    if window is not None:
        wx = np.arange(N) / N
        w = tools.window_generalised(wx, window=window)[np.newaxis]
        # BUGFIX: multiply out-of-place. The previous in-place `data *= w`
        # silently overwrote the caller's array with windowed values.
        data = data * w

    return fftw.ifft(data, axes=-1)
[docs]
def delay_spectrum_wiener_filter(
    delay_PS, data, N, Ni, window="nuttall", fsel=None, complex_timedomain=False
):
    """Estimate the delay spectrum from an input frequency spectrum by Wiener filtering.

    This routine estimates the spectrum at the `N` delay samples conjugate to
    an input frequency spectrum with ``N/2 + 1`` channels (if the delay spectrum is
    assumed real) or `N` channels (if the delay spectrum is assumed complex).
    A subset of these channels can be specified using the `fsel` argument.

    Parameters
    ----------
    delay_PS : np.ndarray[ndelay]
        Delay power spectrum to use for the signal covariance in the Wiener filter.
    data : np.ndarray[nsample, freq]
        Data to estimate the delay spectrum of.
    N : int
        The length of the output delay spectrum. There are assumed to be `N/2 + 1`
        total frequency channels if assuming a real delay spectrum, or `N` channels
        for a complex delay spectrum.
    Ni : np.ndarray[freq]
        Inverse noise variance.
    fsel : np.ndarray[freq], optional
        Indices of channels that we have data at. By default assume all channels.
    window : one of {'nuttall', 'blackman_nuttall', 'blackman_harris', None}, optional
        Apply an apodisation function. Default: 'nuttall'.
    complex_timedomain : bool, optional
        If True, assume input data arose from a complex timestream. If False, assume
        input data arose from a real timestream, such that the first and last frequency
        channels have purely real values. Default: False.

    Returns
    -------
    y_spec : np.ndarray[nsample, ndelay]
        Delay spectrum for each element of the `sample` axis.
    """
    # Whitened data plus the noise-weighted Fourier operators
    # (F^dagger N^{-1/2} and F^dagger N^{-1} F)
    data, FTNih, FTNiF = _compute_delay_spectrum_inputs(
        data, N, Ni, fsel, window, complex_timedomain
    )

    # Project the whitened spectrum into delay space
    y = np.dot(FTNih, data)

    # Inverse of the signal (delay) power spectrum
    Si = tools.invert_no_zero(delay_PS)
    if complex_timedomain:
        # For a complex delay spectrum each of the interleaved real and
        # imaginary components carries half the power (assuming circular
        # symmetry), so their inverse variances are doubled
        Si = 2.0 * np.repeat(Si, 2)

    # Form the Wiener inverse covariance C^{-1} = S^{-1} + F^dagger N^{-1} F
    # by adding S^{-1} onto the diagonal in place
    np.einsum("ii->i", FTNiF)[:] += Si

    # A Cholesky-based solve is pretty much always faster than a standard one
    # for this symmetric positive-definite system
    chol = la.cho_factor(FTNiF, check_finite=False, lower=False)

    # Solve for the Wiener-filtered spectrum and transpose to [sample, delay]
    y_spec = la.cho_solve(chol, y, check_finite=False).T

    if complex_timedomain:
        y_spec = _alternating_real_to_complex(y_spec)

    return y_spec
[docs]
def null_delay_filter(
    freq,
    delay_cut,
    mask,
    num_delay=200,
    tol=1e-8,
    window=True,
    type_="high",
    lapack_driver="gesvd",
):
    """Take frequency data and null out any delays below some value.

    Parameters
    ----------
    freq : np.ndarray[freq]
        Frequencies we have data at.
    delay_cut : float
        Delay cut to apply.
    mask : np.ndarray[freq]
        Frequencies to mask out.
    num_delay : int, optional
        Number of delay values to use.
    tol : float, optional
        Cut off value for singular values.
    window : bool, optional
        Apply a window function to the data while filtering.
    type_ : str, optional
        Whether to apply a high-pass or low-pass filter. Options are
        `high` or `low`. Default is `high`.
    lapack_driver : str, optional
        Which lapack driver to use in the SVD. Options are 'gesvd' or 'gesdd'.
        'gesdd' is generally faster, but seems to experience convergence issues.
        Default is 'gesvd'.

    Returns
    -------
    filter : np.ndarray[freq, freq]
        The filter as a 2D matrix.
    """
    if type_ not in {"high", "low"}:
        raise ValueError(f"Filter type must be one of [high, low]. Got {type_}")

    # Apodisation across the band, normalised to the full frequency span
    x = (freq - freq.min()) / np.ptp(freq)
    w = tools.window_generalised(x, window="nuttall")

    # Dense grid of delays spanning the region to be removed/retained
    delays = np.linspace(-delay_cut, delay_cut, num_delay)

    # Masked Fourier modes, one column per delay in the cut region
    F = mask[:, np.newaxis] * np.exp(
        2.0j * np.pi * delays[np.newaxis, :] * freq[:, np.newaxis]
    )
    if window:
        F *= w[:, np.newaxis]

    # Use an SVD to figure out the set of significant modes spanning the delays
    # we are wanting to get rid of.
    # NOTE: we've experienced some convergence failures in here which ultimately seem
    # to be the fault of MKL (see https://github.com/scipy/scipy/issues/10032 and links
    # therein). This seems to be limited to the `gesdd` LAPACK routine, so we can get
    # around it by switching to `gesvd`.
    u, sig, _ = la.svd(F, full_matrices=False, lapack_driver=lapack_driver)
    nmodes = np.sum(sig > tol * sig.max())

    # High-pass nulls the in-cut modes; low-pass keeps only them
    modes = u[:, :nmodes] if type_ == "high" else u[:, nmodes:]

    # Projection orthogonal to the selected modes, with masking (and
    # optionally windowing) applied on the input side
    proj = np.identity(len(freq)) - np.dot(modes, modes.T.conj())
    proj *= mask[np.newaxis, :]

    if window:
        proj *= w[np.newaxis, :]

    return proj
# ----------------------------------------
# Helper functions for array manipulation
# ----------------------------------------
[docs]
def match_axes(dset1, dset2):
    """Make sure that dset2 has the same set of axes as dset1.

    Sometimes the weights are missing axes (usually where the entries would all be
    the same), we need to map these into one another and expand the weights to the
    same size as the visibilities. This assumes that the vis/weight axes are in the
    same order when present

    Parameters
    ----------
    dset1
        The dataset with more axes.
    dset2
        The dataset with a subset of axes. For the moment these are assumed to be in
        the same order.

    Returns
    -------
    dset2_view
        A view of dset2 with length-1 axes inserted to match the axes missing from
        dset1.
    """
    target_axes = dset1.attrs["axis"]
    present_axes = dset2.attrs["axis"]

    # Build an index that keeps every axis dset2 has, and inserts a length-1
    # axis wherever dset1 has one that dset2 lacks
    expander = []
    for ax in target_axes:
        expander.append(slice(None) if ax in present_axes else np.newaxis)

    return dset2[:][tuple(expander)]
[docs]
def flatten_axes(
    dset: memh5.MemDatasetDistributed,
    axes_to_keep: list[str],
    match_dset: memh5.MemDatasetDistributed | None = None,
) -> tuple[mpiarray.MPIArray, list[str]]:
    """Move the specified axes of the dataset to the back, and flatten all others.

    Optionally this will add length-1 axes to match the axes of another dataset.

    Parameters
    ----------
    dset
        The dataset to reshape.
    axes_to_keep
        The names of the axes to keep.
    match_dset
        An optional dataset to match the shape of.

    Returns
    -------
    flat_array
        The MPIArray representing the re-arranged dataset. Distributed along the
        flattened axis.
    flat_axes
        The names of the flattened axes from slowest to fastest varying.
    """
    # Find the relevant axis positions
    data_axes = list(dset.attrs["axis"])

    # Check that the requested datasets actually exist
    for axis in axes_to_keep:
        if axis not in data_axes:
            raise ValueError(f"Specified {axis=} not present in dataset.")

    # If specified, add extra axes to match the shape of the given dataset
    # NOTE(review): `dset_full` is constructed here but never used below —
    # everything after this reads `dset[:]` and `data_axes` from the original
    # dataset, so the `match_dset` branch currently has no effect. Presumably
    # `data_array`/`data_axes` were meant to come from `dset_full`/`match_dset`
    # in this case; confirm intent before relying on `match_dset`.
    if match_dset and tuple(dset.attrs["axis"]) != tuple(match_dset.attrs["axis"]):
        dset_full = np.empty_like(match_dset[:])
        dset_full[:] = match_axes(match_dset, dset)

    axes_ind = [data_axes.index(axis) for axis in axes_to_keep]

    # Get an MPIArray and make sure it is distributed along one of the preserved axes
    data_array = dset[:]
    if data_array.axis not in axes_ind:
        data_array = data_array.redistribute(axes_ind[0])

    # Create a view of the dataset with the relevant axes at the back,
    # and all others moved to the front (retaining their relative order)
    other_axes = [ax for ax in range(len(data_axes)) if ax not in axes_ind]
    data_array = data_array.transpose(other_axes + axes_ind)

    # Get the explicit shape of the axes that will remain, but set the distributed one
    # to None (as will be needed for MPIArray.reshape)
    remaining_shape = list(data_array.shape)
    remaining_shape[data_array.axis] = None
    new_ax_len = np.prod(remaining_shape[: -len(axes_ind)])
    remaining_shape = remaining_shape[-len(axes_ind) :]

    # Reshape the MPIArray, and redistribute over the flattened axis
    data_array = data_array.reshape((new_ax_len, *remaining_shape))
    data_array = data_array.redistribute(axis=0)

    other_axes_names = [data_axes[ax] for ax in other_axes]

    return data_array, other_axes_names
def _move_front(arr: np.ndarray, axis: int, shape: tuple) -> np.ndarray:
# Move the specified axis to the front and flatten to give a 2D array
new_arr = np.moveaxis(arr, axis, 0)
return new_arr.reshape(shape[axis], -1)
def _inv_move_front(arr: np.ndarray, axis: int, shape: tuple) -> np.ndarray:
# Move the first axis back to it's original position and return the original shape,
# i.e. reverse the above operation
rshape = (shape[axis],) + shape[:axis] + shape[(axis + 1) :]
new_arr = arr.reshape(rshape)
new_arr = np.moveaxis(new_arr, 0, axis)
return new_arr.reshape(shape)
def _take_view(arr: np.ndarray, ind: int, axis: int) -> np.ndarray:
# Like np.take but returns a view (instead of a copy), but only supports a scalar
# index
sl = (slice(None),) * axis
return arr[(*sl, ind)]