Source code for gunpowder.nodes.scan

import logging
import multiprocessing
from abc import ABC

import numpy as np
import tqdm

from gunpowder.array import Array
from gunpowder.batch import Batch
from gunpowder.coordinate import Coordinate
from gunpowder.graph import Graph
from gunpowder.producer_pool import ProducerPool
from gunpowder.roi import Roi

from .batch_filter import BatchFilter

logger = logging.getLogger(__name__)


class ScanCallback(ABC):
    """Base class for :class:`Scan` callbacks.

    Implement any of ``start``, ``update``, and ``stop`` in a subclass to
    create your own callback.
    """

    def start(self, num_total):
        """Called once before :class:`Scan` starts scanning over chunks.

        Args:

            num_total (int):

                The total number of chunks to process.
        """
        pass

    def update(self, num_processed):
        """Called periodically by :class:`Scan` while processing chunks.

        Args:

            num_processed (int):

                The number of chunks already processed.
        """
        pass

    def stop(self):
        """Called once after :class:`Scan` scanned over all chunks."""
        pass
class TqdmCallback(ScanCallback):
    """A default callback that uses ``tqdm`` to show a progress bar."""

    def start(self, num_total):
        logger.info("scanning over %d chunks", num_total)
        self.progress_bar = tqdm.tqdm(desc="Scan, chunks processed", total=num_total)
        self.num_processed = 0

    def update(self, num_processed):
        self.progress_bar.update(num_processed - self.num_processed)
        self.num_processed = num_processed

    def stop(self):
        self.progress_bar.close()
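# Example: a minimal custom callback that logs progress instead of drawing a
# progress bar. This is an illustrative sketch only; the class name and the
# ``log_every`` parameter are arbitrary, and any subset of ``start``,
# ``update``, and ``stop`` may be overridden.
class ExampleLoggingCallback(ScanCallback):
    """Logs every ``log_every``-th processed chunk."""

    def __init__(self, log_every=10):
        self.log_every = log_every
        self.num_total = 0

    def start(self, num_total):
        self.num_total = num_total
        logger.info("scanning over %d chunks", num_total)

    def update(self, num_processed):
        # report only occasionally to keep logs short
        if num_processed % self.log_every == 0 or num_processed == self.num_total:
            logger.info("processed %d/%d chunks", num_processed, self.num_total)

    def stop(self):
        logger.info("done scanning")


# pass an instance via Scan(reference, progress_callback=ExampleLoggingCallback())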
class Scan(BatchFilter):
    """Iteratively requests batches of size ``reference`` from upstream
    providers in a scanning fashion, until all requested ROIs are covered. If
    the batch request to this node is empty, it will scan the complete
    upstream ROIs (and return nothing). Otherwise, it scans only the requested
    ROIs and returns a batch assembled from the smaller requests. In either
    case, the upstream requests will be contained in the downstream requested
    ROI or upstream ROIs.

    See also :class:`Hdf5Write`.

    Args:

        reference (:class:`BatchRequest`):

            A reference :class:`BatchRequest`. This request will be shifted in
            a scanning fashion over the upstream ROIs of the requested arrays
            or points.

        num_workers (``int``, optional):

            If set to >1, upstream requests are made in parallel with that
            number of workers.

        cache_size (``int``, optional):

            If multiple workers are used, how many batches to hold at most.

        progress_callback (:class:`ScanCallback`, optional):

            A callback instance to get updated from this node while processing
            chunks. See :class:`ScanCallback` for details. The default is a
            callback that shows a ``tqdm`` progress bar.
    """

    def __init__(
        self, reference, num_workers=1, cache_size=50, progress_callback=None
    ):
        self.reference = reference.copy()
        self.num_workers = num_workers
        self.cache_size = cache_size
        self.workers = None
        self.batch = None
        if progress_callback is None:
            self.progress_callback = TqdmCallback()
        else:
            self.progress_callback = progress_callback

    def setup(self):
        if self.num_workers > 1:
            self.request_queue = multiprocessing.Queue(maxsize=0)
            self.workers = ProducerPool(
                [self._worker_get_chunk for _ in range(self.num_workers)],
                queue_size=self.cache_size,
            )
            self.workers.start()

    def teardown(self):
        if self.num_workers > 1:
            self.workers.stop()

    def provide(self, request):
        empty_request = len(request) == 0
        if empty_request:
            scan_spec = self.spec
        else:
            scan_spec = request

        stride = self._get_stride()
        shift_roi = self._get_shift_roi(scan_spec)

        shifts = self._enumerate_shifts(shift_roi, stride)
        num_chunks = len(shifts)

        if self.progress_callback is not None:
            self.progress_callback.start(num_chunks)

        # the batch to return
        self.batch = Batch()

        if self.num_workers > 1:
            for shift in shifts:
                shifted_reference = self._shift_request(self.reference, shift)
                self.request_queue.put(shifted_reference)

            for i in range(num_chunks):
                chunk = self.workers.get()

                if not empty_request:
                    self._add_to_batch(request, chunk)

                if self.progress_callback is not None:
                    self.progress_callback.update(i + 1)
                logger.debug("processed chunk %d/%d", i + 1, num_chunks)

        else:
            for i, shift in enumerate(shifts):
                shifted_reference = self._shift_request(self.reference, shift)
                chunk = self._get_chunk(shifted_reference)

                if not empty_request:
                    self._add_to_batch(request, chunk)

                if self.progress_callback is not None:
                    self.progress_callback.update(i + 1)
                logger.debug("processed chunk %d/%d", i + 1, num_chunks)

        if self.progress_callback is not None:
            self.progress_callback.stop()

        batch = self.batch
        self.batch = None

        logger.debug("returning batch %s", batch)

        return batch

    def _get_stride(self):
        """Get the maximal amount by which ``reference`` can be moved, such
        that it tiles the space."""

        stride = None

        # get the least common multiple of all voxel sizes, we have to stride
        # at least that far
        lcm_voxel_size = self.spec.get_lcm_voxel_size(
            self.reference.array_specs.keys()
        )

        # that's just the minimal size in each dimension
        for key, reference_spec in self.reference.items():
            shape = reference_spec.roi.shape

            for d in range(len(lcm_voxel_size)):
                assert shape[d] >= lcm_voxel_size[d], (
                    "Shape of reference ROI %s for %s is smaller than least "
                    "common multiple of voxel size %s"
                    % (reference_spec.roi, key, lcm_voxel_size)
                )

            if stride is None:
                stride = shape
            else:
                stride = Coordinate((min(a, b) for a, b in zip(stride, shape)))

        return stride

    def _get_shift_roi(self, spec):
        """Get the minimal and maximal shift (as a ROI) to apply to
        ``self.reference``, such that it is still fully contained in ``spec``.
        """

        total_shift_roi = None

        # get individual shift ROIs and intersect them
        for key, reference_spec in self.reference.items():
            logger.debug(
                "getting shift roi for %s with spec %s", key, reference_spec
            )

            if key not in spec:
                logger.debug("skipping, %s not in upstream spec", key)
                continue
            if spec[key].roi is None:
                logger.debug("skipping, %s has no ROI", key)
                continue

            logger.debug("upstream ROI is %s", spec[key].roi)

            for r, s in zip(reference_spec.roi.shape, spec[key].roi.shape):
                assert s is None or r <= s, (
                    "reference %s with ROI %s does not fit into provided "
                    "upstream %s" % (key, reference_spec.roi, spec[key].roi)
                )

            # we have a reference ROI
            #
            #   [--------) [9]
            #   3        12
            #
            # and a spec ROI
            #
            #   [---------------) [16]
            #   16               32
            #
            # min and max shifts of reference are
            #
            #   [--------) [9]
            #   16       25
            #   [--------) [9]
            #   23       32
            #
            # therefore, all possible ways to shift the reference such that it
            # is contained in the spec are at least 16-3=13 and at most
            # 23-3=20 (inclusive)
            #
            #   [-------) [8]
            #   13      21
            #
            # 1. the starting point is beginning of spec - beginning of
            #    reference
            # 2. the length is length of spec - length of reference + 1

            # 1. get the starting point of the shift ROI
            shift_begin = spec[key].roi.begin - reference_spec.roi.begin

            # 2. get the shape of the shift ROI
            shift_shape = spec[key].roi.shape - reference_spec.roi.shape + 1

            # create a ROI...
            shift_roi = Roi(shift_begin, shift_shape)

            logger.debug("shift ROI for %s is %s", key, shift_roi)

            # ...and intersect it with previous shift ROIs
            if total_shift_roi is None:
                total_shift_roi = shift_roi
            else:
                total_shift_roi = total_shift_roi.intersect(shift_roi)
                if total_shift_roi.empty:
                    raise RuntimeError(
                        "There is no location where the ROIs of the "
                        "reference %s are contained in the request/upstream "
                        "ROIs %s." % (self.reference, spec)
                    )

            logger.debug(
                "intersected with total shift ROI this yields %s", total_shift_roi
            )

        if total_shift_roi is None:
            raise RuntimeError(
                "None of the upstream ROIs are bounded (all ROIs are None). "
                "Scan needs at least one bounded upstream ROI."
            )

        return total_shift_roi

    def _enumerate_shifts(self, shift_roi, stride):
        """Produces a sequence of shift coordinates starting at the beginning
        of ``shift_roi``, progressing with ``stride``. The maximum shift
        coordinate in any dimension will be the last point inside the shift
        roi in this dimension."""

        min_shift = shift_roi.offset
        max_shift = max(min_shift, Coordinate(m - 1 for m in shift_roi.end))

        shift = np.array(min_shift)
        shifts = []

        dims = len(min_shift)

        logger.debug("enumerating possible shifts of %s in %s", stride, shift_roi)

        while True:
            logger.debug("adding %s", shift)
            shifts.append(Coordinate(shift))

            if (shift == max_shift).all():
                break

            # count up dimensions
            for d in range(dims):
                if shift[d] >= max_shift[d]:
                    if d == dims - 1:
                        break
                    shift[d] = min_shift[d]
                else:
                    shift[d] += stride[d]
                    # snap to last possible shift, don't overshoot
                    if shift[d] > max_shift[d]:
                        shift[d] = max_shift[d]
                    break

        return shifts

    def _shift_request(self, request, shift):
        shifted = request.copy()
        for _, spec in shifted.items():
            spec.roi = spec.roi.shift(shift)

        return shifted

    def _worker_get_chunk(self):
        request = self.request_queue.get()
        return self._get_chunk(request)

    def _get_chunk(self, request):
        return self.get_upstream_provider().request_batch(request)

    def _add_to_batch(self, spec, chunk):
        if self.batch.get_total_roi() is None:
            self.batch = self._setup_batch(spec, chunk)

        self.batch.profiling_stats.merge_with(chunk.profiling_stats)

        for array_key, array in chunk.arrays.items():
            if array_key not in spec:
                continue
            self._fill(
                self.batch.arrays[array_key].data,
                array.data,
                spec.array_specs[array_key].roi,
                array.spec.roi,
                self.spec[array_key].voxel_size,
            )

        for graph_key, graphs in chunk.graphs.items():
            if graph_key not in spec:
                continue
            self._fill_points(
                self.batch.graphs[graph_key],
                graphs,
                spec.graph_specs[graph_key].roi,
                graphs.spec.roi,
            )

    def _setup_batch(self, batch_spec, chunk):
        """Allocate a batch matching the sizes of ``batch_spec``, using
        ``chunk`` as template."""

        batch = Batch()

        for array_key, spec in batch_spec.array_specs.items():
            roi = spec.roi
            voxel_size = self.spec[array_key].voxel_size

            # get the 'non-spatial' shape of the chunk-batch
            # and append the shape of the request to it
            array = chunk.arrays[array_key]
            shape = array.data.shape[: -roi.dims]
            shape += roi.shape // voxel_size

            spec = self.spec[array_key].copy()
            spec.roi = roi
            logger.info("allocating array of shape %s for %s", shape, array_key)
            batch.arrays[array_key] = Array(
                data=np.zeros(shape, dtype=spec.dtype), spec=spec
            )

        for graph_key, spec in batch_spec.graph_specs.items():
            roi = spec.roi
            spec = self.spec[graph_key].copy()
            spec.roi = roi
            batch.graphs[graph_key] = Graph(nodes=[], edges=[], spec=spec)

        logger.debug("setup batch to fill %s", batch)

        return batch

    def _fill(self, a, b, roi_a, roi_b, voxel_size):
        logger.debug("filling %s into %s", roi_b, roi_a)

        roi_a = roi_a // voxel_size
        roi_b = roi_b // voxel_size

        common_roi = roi_a.intersect(roi_b)
        if common_roi.empty:
            return

        common_in_a_roi = common_roi - roi_a.offset
        common_in_b_roi = common_roi - roi_b.offset

        slices_a = common_in_a_roi.get_bounding_box()
        slices_b = common_in_b_roi.get_bounding_box()

        if len(a.shape) > len(slices_a):
            slices_a = (slice(None),) * (len(a.shape) - len(slices_a)) + slices_a
            slices_b = (slice(None),) * (len(b.shape) - len(slices_b)) + slices_b

        a[slices_a] = b[slices_b]

    def _fill_points(self, a, b, roi_a, roi_b):
        """Take points from ``b`` and add them to ``a``. Nodes marked
        temporary must be ignored. Temporary nodes are nodes that were created
        during processing. Since it is impossible to know, in general, whether
        a node created during processing of a subgraph was assigned an id that
        is already used by the full graph, we cannot include temporary nodes
        and still assume that no node id refers to multiple distinct nodes.
        """
        logger.debug("filling points of %s into points of %s", roi_b, roi_a)

        common_roi = roi_a.intersect(roi_b)
        if common_roi is None:
            return

        for node in b.nodes:
            if not node.temporary and roi_a.contains(node.location):
                a.add_node(node)

        for e in b.edges:
            bu = b.node(e.u)
            bv = b.node(e.v)
            if (
                not bu.temporary
                and not bv.temporary
                and a.contains(bu.id)
                and a.contains(bv.id)
            ):
                a.add_edge(e)
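# Usage sketch: how Scan is typically wired into a pipeline. This is an
# illustrative example only; the file name "raw.hdf5", the dataset path
# "volumes/raw", and the chunk/request shapes are arbitrary assumptions.
if __name__ == "__main__":
    import gunpowder as gp

    raw = gp.ArrayKey("RAW")

    # the reference request defines the chunk size that Scan tiles over the
    # upstream ROI
    reference = gp.BatchRequest()
    reference[raw] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (64, 64, 64)))

    pipeline = gp.Hdf5Source("raw.hdf5", {raw: "volumes/raw"}) + gp.Scan(
        reference, num_workers=4
    )

    with gp.build(pipeline):
        # a non-empty request is assembled chunk by chunk from upstream
        request = gp.BatchRequest()
        request[raw] = gp.ArraySpec(roi=gp.Roi((0, 0, 0), (128, 128, 128)))
        batch = pipeline.request_batch(request)

        # an empty request scans the complete upstream ROI and returns an
        # empty batch, which is useful when an upstream node such as
        # Hdf5Write persists each chunk as it passes through
        pipeline.request_batch(gp.BatchRequest())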