Source code for draco.analysis.sourcestack

"""Source Stack Analysis Tasks."""

import numpy as np
from caput import config, pipeline
from cora.util import units
from mpi4py import MPI

from ..core import containers, task
from ..util.random import RandomTask
from ..util.tools import invert_no_zero

# Constants
NU21 = units.nu21
C = units.c


class SourceStack(task.SingleTask):
    """Stack the product of `draco.analysis.BeamForm` across sources.

    For this to work, BeamForm must have been run with `collapse_ha = True`
    (the default).

    Attributes
    ----------
    freqside : int
        Number of frequency bins to keep on each side of source bin
        when stacking. Default: 50.
    single_source_bin_index : int, optional
        Only stack on sources in frequency bin with this index.
        Useful for isolating stacking signal from a narrow frequency range.
        Default: None.
    """

    # Number of frequencies to keep on each side of source RA
    freqside = config.Property(proptype=int, default=50)

    # Only consider sources within frequency channel with this index
    single_source_bin_index = config.Property(proptype=int, default=None)
    def process(self, formed_beam):
        """Receive a formed beam object and stack across sources.

        Parameters
        ----------
        formed_beam : `containers.FormedBeam` object
            Formed beams to stack over sources.

        Returns
        -------
        stack : `containers.FrequencyStack` object
            The stack of sources.
        """
        # Get communicator
        comm = formed_beam.comm

        # Ensure formed_beam is distributed in sources
        formed_beam.redistribute("object_id")

        # Local shape and offset
        loff = formed_beam.beam.local_offset[0]
        lshape = formed_beam.beam.local_shape[0]

        # Frequency axis
        freq = formed_beam.freq
        nfreq = len(freq)

        # Polarisation axis
        pol = formed_beam.pol
        npol = len(pol)

        # Frequency of sources in MHz
        source_freq = NU21 / (formed_beam["redshift"]["z"][loff : loff + lshape] + 1.0)

        # Size of source stack array
        self.nstack = 2 * self.freqside + 1

        # Construct frequency offset axis (for stack container)
        self.stack_axis = np.copy(
            formed_beam.frequency[
                int(nfreq / 2) - self.freqside : int(nfreq / 2) + self.freqside + 1
            ]
        )
        self.stack_axis["centre"] = (
            self.stack_axis["centre"] - self.stack_axis["centre"][self.freqside]
        )

        # Get f_mask and source_indices
        freqdiff = freq[np.newaxis, :] - source_freq[:, np.newaxis]

        # Stack axis bin edges to digitize each source at, in either increasing
        # or decreasing order depending on order of frequencies
        if self.stack_axis["centre"][0] > self.stack_axis["centre"][-1]:
            stackbins = self.stack_axis["centre"] + 0.5 * self.stack_axis["width"]
            stackbins = np.append(
                stackbins,
                self.stack_axis["centre"][-1] - 0.5 * self.stack_axis["width"][-1],
            )
        else:
            stackbins = self.stack_axis["centre"] - 0.5 * self.stack_axis["width"]
            stackbins = np.append(
                stackbins,
                self.stack_axis["centre"][-1] + 0.5 * self.stack_axis["width"][-1],
            )

        # Index of each frequency in stack axis, for each source
        source_indices = np.digitize(freqdiff, stackbins) - 1

        # Indices to be processed in full frequency axis for each source
        f_mask = (source_indices >= 0) & (source_indices < self.nstack)

        # Only sources in the frequency range of the data.
        source_mask = (np.sum(f_mask, axis=1) > 0).astype(bool)

        # If desired, also restrict to sources within a specific channel.
        # This works because the frequency axis is not distributed between
        # ranks.
        if self.single_source_bin_index is not None:
            fs = formed_beam.index_map["freq"][self.single_source_bin_index]
            restricted_chan_mask = np.abs(source_freq - fs["centre"]) < (
                0.5 * fs["width"]
            )
            source_mask *= restricted_chan_mask

        # Container to hold the stack
        if npol > 1:
            stack = containers.FrequencyStackByPol(
                freq=self.stack_axis, pol=pol, attrs_from=formed_beam
            )
        else:
            stack = containers.FrequencyStack(
                freq=self.stack_axis, attrs_from=formed_beam
            )

        # Loop over polarisations
        for pp, pstr in enumerate(pol):
            fb = formed_beam.beam[:, pp].view(np.ndarray)
            fw = formed_beam.weight[:, pp].view(np.ndarray)

            # Source stack array.
            source_stack = np.zeros(self.nstack, dtype=np.float64)
            source_weight = np.zeros(self.nstack, dtype=np.float64)

            count = 0  # Source counter

            # For each source in the range of this process
            for lq in range(lshape):
                if not source_mask[lq]:
                    # Source not in the data redshift range
                    continue

                count += 1

                # Indices and slice for frequencies included in the stack.
                f_indices = np.arange(nfreq, dtype=np.int32)[f_mask[lq]]
                f_slice = np.s_[f_indices[0] : f_indices[-1] + 1]

                source_stack += np.bincount(
                    source_indices[lq, f_slice],
                    weights=fw[lq, f_slice] * fb[lq, f_slice],
                    minlength=self.nstack,
                )

                source_weight += np.bincount(
                    source_indices[lq, f_slice],
                    weights=fw[lq, f_slice],
                    minlength=self.nstack,
                )

            # Gather source stack for all ranks. Each contains the sum
            # over a different subset of sources.
            source_stack_full = np.zeros(
                comm.size * self.nstack, dtype=source_stack.dtype
            )
            source_weight_full = np.zeros(
                comm.size * self.nstack, dtype=source_weight.dtype
            )

            # Gather all ranks
            comm.Allgather(source_stack, source_stack_full)
            comm.Allgather(source_weight, source_weight_full)

            # Determine the index for the output container
            oslc = (pp, slice(None)) if npol > 1 else slice(None)

            # Sum across ranks
            stack.weight[oslc] = np.sum(
                source_weight_full.reshape(comm.size, self.nstack), axis=0
            )
            stack.stack[oslc] = np.sum(
                source_stack_full.reshape(comm.size, self.nstack), axis=0
            ) * invert_no_zero(stack.weight[oslc])

            # Gather all ranks of count. Report number of sources stacked
            full_count = comm.reduce(count, op=MPI.SUM, root=0)

            if comm.rank == 0:
                self.log.info(f"Number of sources stacked for pol {pstr}: {full_count}")

        return stack
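
# --- Illustrative sketch (not part of the draco module) ----------------------
# A standalone toy version of the binning core of SourceStack.process: each
# source's spectrum is shifted so the source frequency lands in the central
# bin of a zero-centred offset axis, then co-added with its weights via
# np.bincount. All values below (frequency axis, sources, weights) are made
# up for illustration; run this as a separate script.
#
# import numpy as np
#
# freq = np.linspace(700.0, 600.0, 101)  # descending toy axis, MHz
# width = np.abs(freq[1] - freq[0])
# freqside = 10
# nstack = 2 * freqside + 1
#
# # Offset axis centred on zero, mirroring `stack_axis` above
# offsets = freq[len(freq) // 2 - freqside : len(freq) // 2 + freqside + 1]
# offsets = offsets - offsets[freqside]
#
# # Bin edges for a descending axis (the first branch of the `if` above)
# stackbins = np.append(offsets + 0.5 * width, offsets[-1] - 0.5 * width)
#
# source_freq = np.array([650.0, 640.0])  # two toy sources
# fb = np.ones((2, freq.size))            # formed-beam values
# fw = np.ones((2, freq.size))            # inverse-variance weights
#
# # Index of each frequency in the offset axis, for each source
# freqdiff = freq[np.newaxis, :] - source_freq[:, np.newaxis]
# source_indices = np.digitize(freqdiff, stackbins) - 1
# f_mask = (source_indices >= 0) & (source_indices < nstack)
#
# source_stack = np.zeros(nstack)
# source_weight = np.zeros(nstack)
# for lq in range(source_freq.size):
#     sel = f_mask[lq]
#     source_stack += np.bincount(
#         source_indices[lq, sel], weights=fw[lq, sel] * fb[lq, sel], minlength=nstack
#     )
#     source_weight += np.bincount(
#         source_indices[lq, sel], weights=fw[lq, sel], minlength=nstack
#     )
#
# # Weighted mean over sources; with unit data and weights every hit bin is 1
# stack_vals = source_stack / np.where(source_weight == 0.0, 1.0, source_weight)
# ------------------------------------------------------------------------------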
class RandomSubset(task.SingleTask, RandomTask):
    """Take a large mock catalog and draw `number` catalogs of a given `size`.

    Attributes
    ----------
    number : int
        Number of catalogs to construct.
    size : int
        Number of objects in each catalog.
    """

    number = config.Property(proptype=int)
    size = config.Property(proptype=int)

    def __init__(self):
        super().__init__()

        self.catalog_ind = 0
    def setup(self, catalog):
        """Set the full mock catalog.

        Parameters
        ----------
        catalog : containers.SourceCatalog or containers.FormedBeam
            The mock catalog to draw from.
        """
        # If the catalog is distributed, then we need to make sure that it
        # is distributed over an axis other than the object_id axis.
        if catalog.distributed:
            axis_size = {
                key: len(val)
                for key, val in catalog.index_map.items()
                if key != "object_id"
            }

            if len(axis_size) > 0:
                self.distributed_axis = max(axis_size, key=axis_size.get)
                self.log.info(
                    f"Distributing over the {self.distributed_axis} axis "
                    "to take random subsets of objects."
                )
                catalog.redistribute(self.distributed_axis)
            else:
                raise ValueError(
                    "The catalog that was provided is distributed "
                    "over the object_id axis. Unable to take a "
                    "random subset over object_id."
                )
        else:
            self.distributed_axis = None

        if "tag" in catalog.attrs:
            self.base_tag = f"{catalog.attrs['tag']}_mock_{{:05d}}"
        else:
            self.base_tag = "mock_{:05d}"

        self.catalog = catalog
    def process(self):
        """Draw a new random catalog.

        Returns
        -------
        new_catalog : containers.SourceCatalog or containers.FormedBeam
            A catalog of the same type as the input catalog, with a random
            set of objects.
        """
        if self.catalog_ind >= self.number:
            raise pipeline.PipelineStopIteration

        objects = self.catalog.index_map["object_id"]
        num_cat = len(objects)

        # NOTE: We need to be very careful here, the RNG is initialised at first
        # access and this is a collective operation. So we need to ensure all
        # ranks do it even though only rank=0 is going to use the RNG in this task.
        rng = self.rng

        # Generate a random selection of objects on rank=0 and broadcast to all
        # other ranks
        if self.comm.rank == 0:
            ind = np.sort(rng.choice(num_cat, self.size, replace=False))
        else:
            ind = np.zeros(self.size, dtype=np.int64)

        self.comm.Bcast(ind, root=0)

        # Create new container
        new_catalog = self.catalog.__class__(
            object_id=objects[ind],
            attrs_from=self.catalog,
            axes_from=self.catalog,
            comm=self.catalog.comm,
        )

        for name in self.catalog.datasets.keys():
            if name not in new_catalog.datasets:
                new_catalog.add_dataset(name)

        if self.distributed_axis is not None:
            new_catalog.redistribute(self.distributed_axis)

        new_catalog.attrs["tag"] = self.base_tag.format(self.catalog_ind)

        # Loop over all datasets and if they have an object_id axis, select the
        # relevant objects along that axis
        for name, dset in self.catalog.datasets.items():
            if dset.attrs["axis"][0] == "object_id":
                new_catalog.datasets[name][:] = dset[:][ind]
            else:
                new_catalog.datasets[name][:] = dset[:]

        self.catalog_ind += 1

        return new_catalog
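
# --- Illustrative sketch (not part of the draco module) ----------------------
# A standalone toy version of the draw-and-broadcast pattern in
# RandomSubset.process: rank 0 draws a sorted random subset of indices
# without replacement and broadcasts it, so every rank selects the same
# objects. The catalog is replaced by a plain index range; num_cat, size and
# the seed are illustrative values. Run under MPI, e.g.
# `mpirun -n 4 python sketch.py`.
#
# import numpy as np
# from mpi4py import MPI
#
# comm = MPI.COMM_WORLD
# num_cat, size = 1000, 10
#
# if comm.rank == 0:
#     rng = np.random.default_rng(2021)
#     # Sorted and without replacement, matching the task above; cast to a
#     # fixed dtype so the buffers agree on every rank
#     ind = np.sort(rng.choice(num_cat, size, replace=False)).astype(np.int64)
# else:
#     ind = np.zeros(size, dtype=np.int64)
#
# # Collective call: afterwards every rank holds the identical index array
# comm.Bcast(ind, root=0)
# print(f"rank {comm.rank}: {ind}")
# ------------------------------------------------------------------------------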
class GroupSourceStacks(task.SingleTask):
    """Accumulate many frequency stacks into a single container.

    Attributes
    ----------
    ngroup : int
        The number of frequency stacks to accumulate into a single container.
    """

    ngroup = config.Property(proptype=int, default=100)
    def setup(self):
        """Create a list to be populated by the process method."""
        self.stack = []
        self.nmock = 0
        self.counter = 0

        self._container_lookup = {
            containers.FrequencyStack: containers.MockFrequencyStack,
            containers.FrequencyStackByPol: containers.MockFrequencyStackByPol,
            containers.MockFrequencyStack: containers.MockFrequencyStack,
            containers.MockFrequencyStackByPol: containers.MockFrequencyStackByPol,
        }
    def process(self, stack):
        """Add a FrequencyStack to the list.

        As soon as the list contains `ngroup` items, they will be collapsed
        into a single container, which is then output by the task.

        Parameters
        ----------
        stack : containers.FrequencyStack, containers.FrequencyStackByPol,
                containers.MockFrequencyStack, containers.MockFrequencyStackByPol
            The frequency stack to accumulate.

        Returns
        -------
        out : containers.MockFrequencyStack or containers.MockFrequencyStackByPol
            The previous `ngroup` frequency stacks accumulated into a single
            container.
        """
        self.stack.append(stack)

        if "mock" in stack.index_map:
            self.nmock += stack.index_map["mock"].size
        else:
            self.nmock += 1

        self.log.info(
            f"Collected frequency stack. Current size is {len(self.stack):d}."
        )

        if (len(self.stack) % self.ngroup) == 0:
            return self._reset()

        return None
    def process_finish(self):
        """Return whatever frequency stacks are currently in the list.

        Returns
        -------
        out : containers.MockFrequencyStack or containers.MockFrequencyStackByPol
            The remaining frequency stacks accumulated into a single container.
        """
        if len(self.stack) > 0:
            return self._reset()

        return None
    def _reset(self):
        """Combine all frequency stacks currently in the list into a new container.

        Then empty the list, reset the stack counter, and increment the group
        counter.
        """
        self.log.info(
            f"We have accumulated {self.nmock:d} mock realizations. "
            f"Saving to file. [group {self.counter:03d}]"
        )

        mock = np.arange(self.nmock, dtype=np.int64)

        # Create the output container
        OutputContainer = self._container_lookup[self.stack[0].__class__]

        out = OutputContainer(
            mock=mock, axes_from=self.stack[0], attrs_from=self.stack[0]
        )

        counter_str = f"{self.counter:03d}"

        # Update tag using the hierarchy that a group contains multiple mocks,
        # and a supergroup contains multiple groups.
        if "tag" in out.attrs:
            tag = out.attrs["tag"].split("_")

            if "group" in tag:
                ig = max(ii for ii, tt in enumerate(tag) if tt == "group")
                tag[ig] = "supergroup"
                tag[ig + 1] = counter_str
            elif "mock" in tag:
                im = max(ii for ii, tt in enumerate(tag) if tt == "mock")
                tag[im] = "group"
                tag[im + 1] = counter_str
            else:
                tag.append(f"group_{counter_str}")

            out.attrs["tag"] = "_".join(tag)
        else:
            out.attrs["tag"] = f"group_{counter_str}"

        for name in self.stack[0].datasets.keys():
            if name not in out.datasets:
                out.add_dataset(name)

        # Loop over mock stacks and save to output container
        for name, odset in out.datasets.items():
            mock_count = 0

            for ss, stack in enumerate(self.stack):
                dset = stack.datasets[name]

                if dset.attrs["axis"][0] == "mock":
                    data = dset[:]
                else:
                    data = dset[np.newaxis, ...]

                for mdata in data:
                    odset[mock_count] = mdata[:]
                    mock_count += 1

        # Reset the class attributes
        self.stack = []
        self.nmock = 0
        self.counter += 1

        return out
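
# --- Illustrative sketch (not part of the draco module) ----------------------
# A standalone helper mirroring the tag hierarchy handled in _reset above:
# grouping mocks turns "..._mock_00007" into "..._group_000", and grouping
# those groups again turns "..._group_000" into "..._supergroup_001". The
# function name and example tags are invented for illustration.
#
# def regroup_tag(tag_str: str, counter: int) -> str:
#     counter_str = f"{counter:03d}"
#     tag = tag_str.split("_")
#     if "group" in tag:
#         # Replace the last "group" label and its counter
#         ig = max(ii for ii, tt in enumerate(tag) if tt == "group")
#         tag[ig] = "supergroup"
#         tag[ig + 1] = counter_str
#     elif "mock" in tag:
#         # Replace the last "mock" label and its counter
#         im = max(ii for ii, tt in enumerate(tag) if tt == "mock")
#         tag[im] = "group"
#         tag[im + 1] = counter_str
#     else:
#         tag.append(f"group_{counter_str}")
#     return "_".join(tag)
#
# print(regroup_tag("catalog_mock_00007", 0))  # catalog_group_000
# print(regroup_tag("catalog_group_000", 1))   # catalog_supergroup_001
# ------------------------------------------------------------------------------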