From a760a5b417a36ecd3498b6abe459402f8379ee6a Mon Sep 17 00:00:00 2001 From: jwong-nd Date: Fri, 23 Jun 2023 11:57:48 -0700 Subject: [PATCH 1/7] Adding support for Zarr stitching/fusion --- coarse_registration.py | 120 +++++++++++ processor/warp.py | 71 +++---- setup.cfg | 1 + stitch_elastic.py | 4 +- zarr_processor.py | 440 +++++++++++++++++++++++++++++++++++++++++ zarr_utils.py | 114 +++++++++++ 6 files changed, 703 insertions(+), 47 deletions(-) create mode 100644 coarse_registration.py create mode 100644 zarr_processor.py create mode 100644 zarr_utils.py diff --git a/coarse_registration.py b/coarse_registration.py new file mode 100644 index 0000000..5a015bf --- /dev/null +++ b/coarse_registration.py @@ -0,0 +1,120 @@ +import functools as ft +import gc +import jax +import numpy as np +import tensorstore as ts + +from sofima import flow_field + + +QUERY_R_ORTHO = 100 +QUERY_OVERLAP_OFFSET = 0 # Overlap = 'starting line' in neighboring tile +QUERY_R_OVERLAP = 100 + +SEARCH_OVERLAP = 300 # Boundary - overlap = 'starting line' in search tile +SEARCH_R_ORTHO = 100 + + +@ft.partial(jax.jit) +def _estimate_relative_offset_zyx(base, + kernel + ) -> list[float, float, float]: + # Calculate FFT: left = base, right = kernel + xc = flow_field.masked_xcorr(base, kernel, use_jax=True, dim=3) + xc = xc.astype(np.float32) + xc = xc[None, ...] + + # Find strongest peak in FFT, pass in FFT image center + r = flow_field._batched_peaks(xc, + ((xc.shape[1] + 1) // 2, (xc.shape[2] + 1) // 2, xc.shape[3] // 2), + min_distance=2, + threshold_rel=0.5) + + # r returns a list, relative offset is here + relative_offset_xyz = r[0][0:3] + return [relative_offset_xyz[2], relative_offset_xyz[1], relative_offset_xyz[0]] + + +def _estimate_h_offset_zyx(left_tile: ts.TensorStore, + right_tile: ts.TensorStore + ) -> tuple[list[float], float]: + tile_size_xyz = left_tile.shape + mz = tile_size_xyz[2] // 2 + my = tile_size_xyz[1] // 2 + + # Search Space, fixed + left = left_tile[tile_size_xyz[0]-SEARCH_OVERLAP:, + my-SEARCH_R_ORTHO:my+SEARCH_R_ORTHO, + mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T + + # Query Patch, scanned against search space + right = right_tile[QUERY_OVERLAP_OFFSET:QUERY_OVERLAP_OFFSET + QUERY_R_OVERLAP*2, + my-QUERY_R_ORTHO:my+QUERY_R_ORTHO, + mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T + + start_zyx = np.array(left.shape) // 2 - np.array(right.shape) // 2 + pc_init_zyx = np.array([0, 0, tile_size_xyz[0] - SEARCH_OVERLAP + start_zyx[2]]) + pc_zyx = np.array(_estimate_relative_offset_zyx(left, right)) + + return pc_init_zyx + pc_zyx + + +def _estimate_v_offset_zyx(top_tile: ts.TensorStore, + bot_tile: ts.TensorStore, + ) -> tuple[list[float], float]: + tile_size_xyz = top_tile.shape + mz = tile_size_xyz[2] // 2 + mx = tile_size_xyz[0] // 2 + + top = top_tile[mx-SEARCH_R_ORTHO:mx+SEARCH_R_ORTHO, + tile_size_xyz[1]-SEARCH_OVERLAP:, + mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T + bot = bot_tile[mx-QUERY_R_ORTHO:mx+QUERY_R_ORTHO, + 0:QUERY_R_OVERLAP*2, + mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T + + start_zyx = np.array(top.shape) // 2 - np.array(bot.shape) // 2 + pc_init_zyx = np.array([0, tile_size_xyz[1] - SEARCH_OVERLAP + start_zyx[1], 0]) + pc_zyx = np.array(_estimate_relative_offset_zyx(top, bot)) + + return pc_init_zyx + pc_zyx + + +def compute_coarse_offsets(tile_layout: np.ndarray, + tile_volumes: list[ts.TensorStore] + ) -> tuple[np.ndarray, np.ndarray]: + layout_y, layout_x = tile_layout.shape + + # Output Containers, sofima uses cartesian convention + 
conn_x = np.full((3, 1, layout_y, layout_x), np.nan) + conn_y = np.full((3, 1, layout_y, layout_x), np.nan) + + # Row Pairs + for y in range(layout_y): + for x in range(layout_x - 1): # Stop one before the end + left_id = tile_layout[y, x] + right_id = tile_layout[y, x + 1] + left_tile = tile_volumes[left_id] + right_tile = tile_volumes[right_id] + + conn_x[:, 0, y, x] = _estimate_h_offset_zyx(left_tile, right_tile) + gc.collect() + + print(f'Left Id: {left_id}, Right Id: {right_id}') + print(f'Left: ({y}, {x}), Right: ({y}, {x + 1})', conn_x[:, 0, y, x]) + + # Column Pairs -- Reversed Loops + for x in range(layout_x): + for y in range(layout_y - 1): + top_id = tile_layout[y, x] + bot_id = tile_layout[y + 1, x] + top_tile = tile_volumes[top_id] + bot_tile = tile_volumes[bot_id] + + conn_y[:, 0, y, x] = _estimate_v_offset_zyx(top_tile, bot_tile) + gc.collect() + + print(f'Top Id: {top_id}, Bottom Id: {bot_id}') + print(f'Top: ({y}, {x}), Bot: ({y + 1}, {x})', conn_y[:, 0, y, x]) + + return conn_x, conn_y \ No newline at end of file diff --git a/processor/warp.py b/processor/warp.py index 7d2fab1..ff3283f 100644 --- a/processor/warp.py +++ b/processor/warp.py @@ -36,7 +36,7 @@ class StitchAndRender3dTiles(subvolume_processor.SubvolumeProcessor): """Renders a volume by stitching 3d tiles placed on a 2d grid.""" _tile_meshes = None - _tile_idx_to_xy = None + _mesh_index_to_xy = {} _tile_boxes = {} _inverted_meshes = {} @@ -44,28 +44,21 @@ class StitchAndRender3dTiles(subvolume_processor.SubvolumeProcessor): def __init__( self, - *, - tile_map: Sequence[Sequence[int]], - tile_mesh_path: str, - tile_pattern_path: str, + tile_layout: Sequence[Sequence[int]], + tile_mesh: str, + xy_to_mesh_index: dict[int, tuple], stride: ZYX, offset: XYZ = (0, 0, 0), margin: int = 0, work_size: XYZ = (128, 128, 128), order: int = 1, - parallelism: int = 16, - input_volinfo=None, + parallelism: int = 16 ): """Constructor. 
Args: tile_map: yx-shaped grid of tile IDs - tile_mesh_path: path to a npz file containing 'key_to_idx' and 'x' arrays, - as generated by `stitch_elastic.aggregate_arrays` and `mesh.solve_mesh`, - respectively - tile_pattern_path: volinfo path for the volumes containing individual - tiles; must contain '{tile_id}', which will be substituted with values - from `tile_map` + tile_idx_to_xy: index stride: ZYX stride of the mesh in pixels offset: XYZ global offset to apply to the rendered image margin: number of pixels away from the tile boundary to ignore during @@ -75,12 +68,8 @@ def __init__( work_size: see `warp.ndimage_warp` order: see `warp.ndimage_warp` parallelism: see `warp.ndimage_warp` - input_volinfo: not used """ - del input_volinfo - self._tile_map = np.array(tile_map) - self._tile_mesh_path = tile_mesh_path - self._tile_pattern_path = tile_pattern_path + self._tile_layout = tile_layout self._stride = stride self._offset = offset self._margin = margin @@ -88,10 +77,18 @@ def __init__( self._parallelism = parallelism self._work_size = work_size - self._key_to_idx = {} - for y, row in enumerate(tile_map): + StitchAndRender3dTiles._tile_meshes = tile_mesh + StitchAndRender3dTiles._mesh_index_to_xy = { + v:k for k, v in xy_to_mesh_index.items() + } + assert StitchAndRender3dTiles._tile_meshes.shape[1] == len( + StitchAndRender3dTiles._mesh_index_to_xy + ) + + self._xy_to_tile_id = {} + for y, row in enumerate(tile_layout): for x, tile_id in enumerate(row): - self._key_to_idx[(x, y)] = tile_id + self._xy_to_tile_id[(x, y)] = tile_id def _open_tile_volume(self, tile_id: int) -> Any: """Returns a ZYX-shaped ndarray-like object representing the tile data.""" @@ -109,7 +106,7 @@ def _collect_tile_boxes(self, tile_shape_zyx: ZYX): ) for i in range(StitchAndRender3dTiles._tile_meshes.shape[1]): - tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] + tx, ty = StitchAndRender3dTiles._mesh_index_to_xy[i] mesh = StitchAndRender3dTiles._tile_meshes[:, i, ...] tg_box = map_utils.outer_box(mesh, map_box, self._stride) @@ -141,9 +138,9 @@ def _get_dts(self, shape: ZYX, tx: int, ty: int) -> np.ndarray: mask = np.zeros(shape[1:], dtype=bool) if self._margin > 0: x0 = self._margin if tx > 0 else 0 - x1 = -self._margin if tx < self._tile_map.shape[-1] - 1 else -1 + x1 = -self._margin if tx < self._tile_layout.shape[-1] - 1 else -1 y0 = self._margin if ty > 0 else 0 - y1 = -self._margin if ty < self._tile_map.shape[-2] - 1 else -1 + y1 = -self._margin if ty < self._tile_layout.shape[-2] - 1 else -1 mask[y0:y1, x0:x1] = 1 else: mask[...] = 1 @@ -178,7 +175,7 @@ def _load_tile_images( logging.info('Processing source %r (%r)', i, out_box) coord_map = StitchAndRender3dTiles._tile_meshes[:, i, ...] - tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] + tx, ty = StitchAndRender3dTiles._mesh_index_to_xy[i] if i not in StitchAndRender3dTiles._inverted_meshes: # Add context to avoid rounding issues in map inversion. @@ -209,8 +206,8 @@ def _load_tile_images( local_rel_box = sub_box.translate(-out_box.start) local_warp_box = local_rel_box.translate(local_out_box.start) - # Part of the inverted mesh that is needed to render the current - # region of interest. + # Part of the inverted mesh that is needed to render + # the current region of interest. 
s = 1.0 / np.array(self._stride)[::-1] local_map_box = local_warp_box.scale(s).adjusted_by( start=(-2, -2, -2), end=(2, 2, 2) @@ -251,30 +248,14 @@ def process( box = subvol.bbox logging.info('Processing %r', box) - mesh_init = False - - if StitchAndRender3dTiles._tile_meshes is None: - data_path = self._tile_mesh_path - with file.Open(data_path, 'rb') as f: - data = np.load(f, allow_pickle=True) - StitchAndRender3dTiles._tile_idx_to_xy = { - v: k for k, v in data['key_to_idx'].item().items() - } - StitchAndRender3dTiles._tile_meshes = data['x'] - assert StitchAndRender3dTiles._tile_meshes.shape[1] == len( - StitchAndRender3dTiles._tile_idx_to_xy - ) - mesh_init = True - volstores = {} for i in range(StitchAndRender3dTiles._tile_meshes.shape[1]): - tile_id = self._key_to_idx[StitchAndRender3dTiles._tile_idx_to_xy[i]] + tile_id = self._xy_to_tile_id[StitchAndRender3dTiles._mesh_index_to_xy[i]] volstores[i] = self._open_tile_volume(tile_id) # Bounding boxes representing a single tile placed the origin. tile_shape_zyx = next(iter(volstores.values())).shape - if mesh_init: - self._collect_tile_boxes(tile_shape_zyx) + self._collect_tile_boxes(tile_shape_zyx) # For blending, accumulate (weighted) image data as floats. This will # be normalized and cast to the desired output type once the image is diff --git a/setup.cfg b/setup.cfg index f76369c..ddf3920 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,6 +28,7 @@ install_requires = opencv-python>=4.5.5.62 scipy>=1.2.3 scikit-image>=0.17.2 + tensorstore>=0.1.39 [options.packages.find] where = . diff --git a/stitch_elastic.py b/stitch_elastic.py index 977f55a..3881139 100644 --- a/stitch_elastic.py +++ b/stitch_elastic.py @@ -134,8 +134,8 @@ def compute_flow_map3d( curr_box = bounding_box.BoundingBox(start=(0, 0, 0), size=tile_shape) nbor_box = bounding_box.BoundingBox( start=( - tile_shape[0] * (1 - axis) + offset[0], - tile_shape[1] * axis + offset[1], + offset[0], + offset[1], offset[2], ), size=tile_shape, diff --git a/zarr_processor.py b/zarr_processor.py new file mode 100644 index 0000000..e11407a --- /dev/null +++ b/zarr_processor.py @@ -0,0 +1,440 @@ +"""Object Wrapper around SOFIMA on Zarr Datasets.""" + +import functools as ft +import jax +import jax.numpy as jnp +import numpy as np +import tensorstore as ts +import time + +from connectomics.common import bounding_box +from connectomics.common import box_generator +from connectomics.volume import subvolume +from sofima import coarse_registration, flow_utils, stitch_elastic, stitch_rigid, map_utils, mesh, zarr_utils +from sofima.processor import warp + + +# NOTE: +# - SOFIMA/ZarrStitcher follows following basis convention: +# o -- x +# | +# y +# Any reference to 'x' or 'y' adopt this basis. + +# - All displacements are defined in pixel space established +# by the downsample_exp/resolution of the input images. + + +class SyncAdapter: + """Makes it possible to use a TensorStore objects as a numpy array.""" + + def __init__(self, tstore): + self.tstore = tstore + + def __getitem__(self, ind): + print(ind) + return np.array(self.tstore[ind]) + + def __getattr__(self, attr): + return getattr(self.tstore, attr) + + @property + def shape(self): + return self.tstore.shape + + @property + def ndim(self): + return self.tstore.ndim + + +class ZarrFusion(warp.StitchAndRender3dTiles): + """ + Fusion renderer subclass + that implements data loading for Zarr datasets. 
+ """ + cache = {} + + def __init__(self, + zarr_params: zarr_utils.ZarrDataset, + tile_layout: np.ndarray, + fine_tile_mesh: np.ndarray, + fine_mesh_xy_to_index: dict[tuple[int, int], int], + stride_zyx: tuple[int, int, int], + offset_xyz: tuple[float, float, float], + parallelism=16) -> None: + super().__init__(tile_layout, + fine_tile_mesh, + fine_mesh_xy_to_index, + stride_zyx, + offset_xyz, + parallelism) + self.zarr_params = zarr_params + + + def _open_tile_volume(self, tile_id: int): + if tile_id in self.cache: + return self.cache[tile_id] + + tile_volumes, tile_size_xyz = zarr_utils.load_zarr_data(self.zarr_params) + tile = tile_volumes[tile_id] + self.cache[tile_id] = SyncAdapter(tile[0,0,:,:,:]) + return self.cache[tile_id] + + +class ZarrStitcher: + """ + Object wrapper around SOFIMA for operating on Zarr datasets. + """ + + def __init__(self, + input_zarr: zarr_utils.ZarrDataset, + tile_layout: np.ndarray) -> None: + """ + zarr_params: See ZarrDataset, params for input dataset + tile_layout: 2D array of tile ids defining relative tile placement. + Tile ids correspond to indices of ZarrDataset.tile_names. + """ + + self.input_zarr = input_zarr + + self.tile_volumes: list[ts.TensorStore] = [] # 5D tczyx homogenous shape + self.tile_volumes, self.tile_size_xyz = zarr_utils.load_zarr_data(input_zarr) + self.tile_layout = tile_layout + + self.tile_map: dict[tuple[int, int], ts.TensorStore] = {} + for y, row in enumerate(tile_layout): + for x, tile_id in enumerate(row): + self.tile_map[(x, y)] = self.tile_volumes[tile_id] + + + def run_coarse_registration(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Runs coarse registration. + Returns: + cx: tile_layout shape + Each entry represents displacement of current tile towards right neighbor. + cy: tile_layout shape + Each entry represents displacement of current tile towards bottom neighbor. + coarse_mesh: (3, 1, tile_layout) shape + Each entry net displacement of current tile. + """ + + # Custom data loading for coarse registration + _tile_volumes: list[ts.TensorStore] = [] + for vol in self.tile_volumes: + _tile_volumes.append(vol.T[:,:,:,0,0]) + + cx, cy = coarse_registration.compute_coarse_offsets(self.tile_layout, _tile_volumes) + coarse_mesh = stitch_rigid.optimize_coarse_mesh(cx, + cy, + mesh_fn=stitch_rigid.elastic_tile_mesh_3d) + return cx, cy, coarse_mesh + + + def run_fine_registration(self, + cx: np.ndarray, + cy: np.ndarray, + coarse_mesh: np.ndarray, + stride_zyx: tuple[int, int, int] + ) -> tuple[np.ndarray, dict[tuple[int, int], int]]: + """ + Runs fine registration. + Inputs: + cx: Coarse offsets in x direction, output of coarse registration. + cy: Coarse offsets in y direction, output of coarse registration. + coarse_mesh: Coarse offsets in combined array, output of coarse registration. + stride_zyx: Subdivision of each tile to create fine mesh. + + Outputs: + solved_fine_mesh: Fine mesh containing offsets of each subdivision. + Shape is (3, tile_index, stride_z, stride_y, stride_x). + fine_mesh_xy_to_index: Map of tile coordinates to custom mesh tile index. + stride_zyx: Same as input, by returned as important parameter. 
+ """ + + # Custom data loading for fine registration + _tile_map = {} + for key, tstore in self.tile_map.items(): + _tile_map[key] = SyncAdapter(tstore[0,:,:,:,:]) + + # Compute flow map + flow_x, offsets_x = stitch_elastic.compute_flow_map3d(_tile_map, + self.tile_size_xyz, + cx, axis=0, + stride=stride_zyx, + patch_size=(80, 80, 80)) + + flow_y, offsets_y = stitch_elastic.compute_flow_map3d(_tile_map, + self.tile_size_xyz, + cy, axis=1, + stride=stride_zyx, + patch_size=(80, 80, 80)) + + # Filter patch flows + kwargs = {"min_peak_ratio": 1.4, "min_peak_sharpness": 1.4, + "max_deviation": 5, "max_magnitude": 0, "dim": 3} + fine_x = {k: flow_utils.clean_flow(v, **kwargs) for k, v in flow_x.items()} + fine_y = {k: flow_utils.clean_flow(v, **kwargs) for k, v in flow_y.items()} + + kwargs = {"min_patch_size": 10, "max_gradient": -1, "max_deviation": -1} + fine_x = {k: flow_utils.reconcile_flows([v], **kwargs) for k, v in fine_x.items()} + fine_y = {k: flow_utils.reconcile_flows([v], **kwargs) for k, v in fine_y.items()} + + # Update mesh (convert coarse tile mesh into fine patch mesh) + data_x = (cx[:, 0, ...], fine_x, offsets_x) + data_y = (cy[:, 0, ...], fine_y, offsets_y) + fx, fy, fine_mesh, nbors, fine_mesh_xy_to_index = stitch_elastic.aggregate_arrays( + data_x, data_y, list(self.tile_map.keys()), + coarse_mesh[:, 0, ...], stride=stride_zyx, tile_shape=self.tile_size_xyz[::-1]) + + @jax.jit + def prev_fn(x): + target_fn = ft.partial(stitch_elastic.compute_target_mesh, + x=x, fx=fx, fy=fy, stride=stride_zyx) + x = jax.vmap(target_fn)(nbors) + return jnp.transpose(x, [1, 0, 2, 3, 4]) + + config = mesh.IntegrationConfig(dt=0.001, gamma=0., k0=0.01, k=0.1, stride=stride_zyx, + num_iters=1000, max_iters=20000, stop_v_max=0.001, + dt_max=100, prefer_orig_order=False, + start_cap=0.1, final_cap=10., remove_drift=True) + + solved_fine_mesh, ekin, t = mesh.relax_mesh(fine_mesh, None, config, + prev_fn=prev_fn, mesh_force=mesh.elastic_mesh_3d) + + return solved_fine_mesh, fine_mesh_xy_to_index, stride_zyx + + + def _run_fusion(self, + output_cloud_storage: zarr_utils.CloudStorage, + output_bucket: str, + output_path: str, + downsample_exp: int, + cx: np.ndarray, + cy: np.ndarray, + fine_mesh: np.ndarray, + fine_mesh_xy_to_index: dict[tuple[int, int], int], + stride_zyx: tuple[int, int, int], + parallelism: int = 16 + ) -> None: + """ + Runs fusion. + Inputs: + output_cloud_storage, output_bucket, output_path: + Output storage parameters + downsample_exp: + Desired output resolution, 0 for highest resolution. + fine_mesh, fine_mesh_xy_to_index, stride_zyx: + Fine mesh offsets and accompanying metadata, + output of coarse/fine registration. + parallelism: + Multithreading. + """ + + if output_cloud_storage == zarr_utils.CloudStorage.S3: + raise NotImplementedError( + 'TensorStore does not support s3 writes.' 
+ ) + + fusion_zarr = self.input_zarr + fusion_mesh = fine_mesh + fusion_stride_zyx = stride_zyx + fusion_tile_size_zyx = self.tile_size_xyz[::-1] + if downsample_exp != self.input_zarr.downsample_exp: + # Reload the data at target resolution + fusion_zarr = zarr_utils.ZarrDataset(self.input_zarr.cloud_storage, + self.input_zarr.bucket, + self.input_zarr.dataset_path, + self.input_zarr.tile_names, + downsample_exp) + + # Rescale fine mesh, stride + curr_exp = self.input_zarr.downsample_exp + target_exp = downsample_exp + scale_factor = 2**(curr_exp - target_exp) + fusion_mesh = fine_mesh * scale_factor + fusion_stride_zyx = tuple(np.array(stride_zyx) * scale_factor) + fusion_tile_size_zyx = tuple(np.array(self.tile_size_xyz)[::-1] * scale_factor) + print(f'{scale_factor=}') + + start = np.array([np.inf, np.inf, np.inf]) + map_box = bounding_box.BoundingBox( + start=(0, 0, 0), + size=fusion_mesh.shape[2:][::-1], + ) + fine_mesh_index_to_xy = { + v: k for k, v in fine_mesh_xy_to_index.items() + } + for i in range(0, fusion_mesh.shape[1]): + tx, ty = fine_mesh_index_to_xy[i] + mesh = fusion_mesh[:, i, ...] + tg_box = map_utils.outer_box(mesh, map_box, fusion_stride_zyx) + + out_box = bounding_box.BoundingBox( + start=( + tg_box.start[0] * fusion_stride_zyx[2] + tx * fusion_tile_size_zyx[2], + tg_box.start[1] * fusion_stride_zyx[1] + ty * fusion_tile_size_zyx[1], + tg_box.start[2] * fusion_stride_zyx[0], + ), + size=( + tg_box.size[0] * fusion_stride_zyx[2], + tg_box.size[1] * fusion_stride_zyx[1], + tg_box.size[2] * fusion_stride_zyx[0], + ) + ) + start = np.minimum(start, out_box.start) + print(f'{tg_box=}') + print(f'{out_box=}') + + crop_offset = -start + print(f'{crop_offset=}') + + # Fused shape + cx[np.isnan(cx)] = 0 + cy[np.isnan(cy)] = 0 + x_overlap = cx[2,0,0,0] / self.tile_size_xyz[0] + y_overlap = cy[1,0,0,0] / self.tile_size_xyz[1] + y_shape, x_shape = cx.shape[2], cx.shape[3] + + fused_x = fusion_tile_size_zyx[2] * (1 + ((x_shape - 1) * (1 - x_overlap))) + fused_y = fusion_tile_size_zyx[1] * (1 + ((y_shape - 1) * (1 - y_overlap))) + fused_z = fusion_tile_size_zyx[0] + fused_shape_5d = [1, 1, int(fused_z), int(fused_y), int(fused_x)] + print(f'{fused_shape_5d=}') + + # Perform fusion + ds_out = zarr_utils.write_zarr(output_bucket, fused_shape_5d, output_path) + renderer = ZarrFusion(zarr_params=fusion_zarr, + tile_layout=self.tile_layout, + fine_tile_mesh=fusion_mesh, + fine_mesh_xy_to_index=fine_mesh_xy_to_index, + stride_zyx=fusion_stride_zyx, + offset_xyz=crop_offset, + parallelism=parallelism) + + box = bounding_box.BoundingBox(start=(0,0,0), size=ds_out.shape[4:1:-1]) # Needs xyz + gen = box_generator.BoxGenerator(box, (512, 512, 512), (0, 0, 0), True) # These are xyz + renderer.set_effective_subvol_and_overlap((512, 512, 512), (0, 0, 0)) + for i, sub_box in enumerate(gen.boxes): + t_start = time.time() + + # Feed in an empty subvol, with dimensions of sub_box. 
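+            # sub_box.size is in XYZ order, so it is reversed to ZYX for the array
+            # shape and a leading channel axis is added, giving the CZYX layout
+            # that renderer.process() expects.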
+ inp_subvol = subvolume.Subvolume(np.zeros(sub_box.size[::-1], dtype=np.uint16)[None, ...], sub_box) + ret_subvol = renderer.process(inp_subvol) # czyx + + t_render = time.time() + + # ret_subvol is a 4D CZYX volume + slice = ret_subvol.bbox.to_slice3d() + slice = (0, 0, slice[0], slice[1], slice[2]) + ds_out[slice].write(ret_subvol.data[0, ...]).result() + + t_write = time.time() + + print('box {i}: {t1:0.2f} render {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render)) + + + def run_fusion_on_coarse_mesh(self, + output_cloud_storage: zarr_utils.CloudStorage, + output_bucket: str, + output_path: str, + downsample_exp: int, + cx: np.ndarray, + cy: np.ndarray, + coarse_mesh: np.ndarray, + stride_zyx: tuple[int, int, int] = (20, 20, 20), + parallelism: int = 16) -> None: + """ + Transforms coarse mesh into fine mesh before + passing along to ZarrStitcher._run_fusion(...) + """ + + # Fine Mesh Tile Index + fine_mesh_xy_to_index = {(tx, ty): i for i, (tx, ty) in enumerate(self.tile_map.keys())} + + # Fine Mesh + dim = len(stride_zyx) + mesh_shape = (np.array(self.tile_size_xyz[::-1]) // stride_zyx).tolist() + fine_mesh = np.zeros([dim, len(fine_mesh_xy_to_index)] + mesh_shape, dtype=np.float32) + for (tx, ty) in self.tile_map.keys(): + fine_mesh[:, fine_mesh_xy_to_index[tx, ty], ...] = coarse_mesh[:, 0, ty, tx].reshape( + (dim,) + (1,) * dim) + + self._run_fusion(output_cloud_storage, + output_bucket, + output_path, + downsample_exp, + cx, + cy, + fine_mesh, + fine_mesh_xy_to_index, + stride_zyx, + parallelism) + + + def run_fusion_on_fine_mesh(self, + output_cloud_storage: zarr_utils.CloudStorage, + output_bucket: str, + output_path: str, + downsample_exp: int, + cx: np.ndarray, + cy: np.ndarray, + fine_mesh: np.ndarray, + fine_mesh_xy_to_index: dict[tuple[int, int], int], + stride_zyx: tuple[int, int, int], + parallelism: int = 16 + ) -> None: + """ + Simply passes all input parameters to + private method ZarrStitcher._run_fusion(...) + """ + + self._run_fusion(output_cloud_storage, + output_bucket, + output_path, + downsample_exp, + cx, + cy, + fine_mesh, + fine_mesh_xy_to_index, + stride_zyx, + parallelism) + + +if __name__ == '__main__': + # Example set of Application Inputs + cloud_storage = zarr_utils.CloudStorage.S3 + bucket = 'aind-open-data' + dataset_path = 'diSPIM_647459_2022-12-07_00-00-00/diSPIM.zarr' + downsample_exp = 2 + tile_names = ['tile_X_0000_Y_0000_Z_0000_CH_0405_cam1.zarr', + 'tile_X_0001_Y_0000_Z_0000_CH_0405_cam1.zarr'] + tile_layout = np.array([[1], + [0]]) + input_zarr = zarr_utils.ZarrDataset(cloud_storage=cloud_storage, + bucket=bucket, + dataset_path=dataset_path, + tile_names=tile_names, + downsample_exp=downsample_exp) + + # Application Outputs + output_cloud_storage = zarr_utils.CloudStorage.GCS + output_bucket = 'YOUR-BUCKET-HERE' + output_path = 'YOU-OUTPUT-NAME.zarr' + + # Processing + zarr_stitcher = ZarrStitcher(input_zarr, tile_layout) + cx, cy, coarse_mesh = zarr_stitcher.run_coarse_registration() + fine_mesh, fine_mesh_xy_to_index, stride_zyx = zarr_stitcher.run_fine_registration(cx, + cy, + coarse_mesh, + stride_zyx=(20, 20, 20)) + zarr_stitcher._run_fusion(output_cloud_storage=output_cloud_storage, + output_bucket=output_bucket, + output_path=output_path, + downsample_exp=0, # For full resolution fusion. 
+ cx=cx, + cy=cy, + fine_mesh=fine_mesh, + fine_mesh_xy_to_index=fine_mesh_xy_to_index, + stride_zyx=stride_zyx) \ No newline at end of file diff --git a/zarr_utils.py b/zarr_utils.py new file mode 100644 index 0000000..ca6c2fc --- /dev/null +++ b/zarr_utils.py @@ -0,0 +1,114 @@ +from dataclasses import dataclass +from enum import Enum +import numpy as np +import tensorstore as ts + + +class CloudStorage(Enum): + """ + Documented Cloud Storage Options + """ + S3 = 1 + GCS = 2 + + +@dataclass +class ZarrDataset: + """ + Parameters for locating Zarr dataset living on the cloud. + """ + cloud_storage: CloudStorage + bucket: str + dataset_path: str + tile_names: list[str] + downsample_exp: int + + +def open_zarr_gcs(bucket: str, path: str): + return ts.open({ + 'driver': 'zarr', + 'kvstore': { + 'driver': 'gcs', + 'bucket': bucket, + }, + 'path': path, + }).result() + + +def open_zarr_s3(bucket: str, path: str): + return ts.open({ + 'driver': 'zarr', + 'kvstore': { + 'driver': 'http', + 'base_url': f'https://{bucket}.s3.us-west-2.amazonaws.com/{path}', + }, + }).result() + + +def load_zarr_data(params: ZarrDataset + ) -> tuple[list[ts.TensorStore], tuple[int, int, int]]: + """ + Reads Zarr dataset from input location + and returns list of equally-sized tensorstores + in matching order as ZarrDataset.tile_names and tile size. + """ + + def load_zarr(bucket: str, tile_location: str) -> ts.TensorStore: + if params.cloud_storage == CloudStorage.S3: + return open_zarr_s3(bucket, tile_location) + else: # cloud == 'gcs' + return open_zarr_gcs(bucket, tile_location) + tile_volumes = [] + min_x, min_y, min_z = np.inf, np.inf, np.inf + for t_name in params.tile_names: + tile_location = f"{params.dataset_path}/{t_name}/{params.downsample_exp}" + tile = load_zarr(params.bucket, tile_location) + tile_volumes.append(tile) + + _, _, tz, ty, tx = tile.shape + min_x, min_y, min_z = int(np.minimum(min_x, tx)), \ + int(np.minimum(min_y, ty)), \ + int(np.minimum(min_z, tz)) + tile_size_xyz = min_x, min_y, min_z + + # Standardize size of tile volumes + for i, tile_vol in enumerate(tile_volumes): + tile_volumes[i] = tile_vol[:, :, :min_z, :min_y, :min_x] + + return tile_volumes, tile_size_xyz + + +def write_zarr(bucket: str, shape: list, path: str): + """ + Shape must be 5D vector in tczyx order: + Ex: [1, 1, 3551, 576, 576] + """ + + return ts.open({ + 'driver': 'zarr', + 'dtype': 'uint16', + 'kvstore' : { + 'driver': 'gcs', + 'bucket': bucket, + }, + 'create': True, + 'delete_existing': True, + 'path': path, + 'metadata': { + 'chunks': [1, 1, 128, 256, 256], + 'compressor': { + 'blocksize': 0, + 'clevel': 1, + 'cname': 'zstd', + 'id': 'blosc', + 'shuffle': 1, + }, + 'dimension_separator': '/', + 'dtype': ' Date: Mon, 3 Jul 2023 16:49:06 -0700 Subject: [PATCH 2/7] revert core sofima changes, export feat into zarr/ --- processor/warp.py | 69 ++++--- stitch_elastic.py | 6 +- zarr/__init__.py | 14 ++ zarr_utils.py => zarr/zarr_io.py | 26 ++- .../zarr_register_and_fuse_3d.py | 175 +++++++++--------- 5 files changed, 165 insertions(+), 125 deletions(-) create mode 100644 zarr/__init__.py rename zarr_utils.py => zarr/zarr_io.py (72%) rename zarr_processor.py => zarr/zarr_register_and_fuse_3d.py (75%) diff --git a/processor/warp.py b/processor/warp.py index ff3283f..a1833d3 100644 --- a/processor/warp.py +++ b/processor/warp.py @@ -36,7 +36,7 @@ class StitchAndRender3dTiles(subvolume_processor.SubvolumeProcessor): """Renders a volume by stitching 3d tiles placed on a 2d grid.""" _tile_meshes = None - _mesh_index_to_xy = 
{} + _tile_idx_to_xy = None _tile_boxes = {} _inverted_meshes = {} @@ -44,21 +44,28 @@ class StitchAndRender3dTiles(subvolume_processor.SubvolumeProcessor): def __init__( self, - tile_layout: Sequence[Sequence[int]], - tile_mesh: str, - xy_to_mesh_index: dict[int, tuple], + *, + tile_map: Sequence[Sequence[int]], + tile_mesh_path: str, + tile_pattern_path: str, stride: ZYX, offset: XYZ = (0, 0, 0), margin: int = 0, work_size: XYZ = (128, 128, 128), order: int = 1, - parallelism: int = 16 + parallelism: int = 16, + input_volinfo=None, ): """Constructor. Args: tile_map: yx-shaped grid of tile IDs - tile_idx_to_xy: index + tile_mesh_path: path to a npz file containing 'key_to_idx' and 'x' arrays, + as generated by `stitch_elastic.aggregate_arrays` and `mesh.solve_mesh`, + respectively + tile_pattern_path: volinfo path for the volumes containing individual + tiles; must contain '{tile_id}', which will be substituted with values + from `tile_map` stride: ZYX stride of the mesh in pixels offset: XYZ global offset to apply to the rendered image margin: number of pixels away from the tile boundary to ignore during @@ -68,8 +75,12 @@ def __init__( work_size: see `warp.ndimage_warp` order: see `warp.ndimage_warp` parallelism: see `warp.ndimage_warp` + input_volinfo: not used """ - self._tile_layout = tile_layout + del input_volinfo + self._tile_map = np.array(tile_map) + self._tile_mesh_path = tile_mesh_path + self._tile_pattern_path = tile_pattern_path self._stride = stride self._offset = offset self._margin = margin @@ -77,18 +88,10 @@ def __init__( self._parallelism = parallelism self._work_size = work_size - StitchAndRender3dTiles._tile_meshes = tile_mesh - StitchAndRender3dTiles._mesh_index_to_xy = { - v:k for k, v in xy_to_mesh_index.items() - } - assert StitchAndRender3dTiles._tile_meshes.shape[1] == len( - StitchAndRender3dTiles._mesh_index_to_xy - ) - - self._xy_to_tile_id = {} - for y, row in enumerate(tile_layout): + self._key_to_idx = {} + for y, row in enumerate(tile_map): for x, tile_id in enumerate(row): - self._xy_to_tile_id[(x, y)] = tile_id + self._key_to_idx[(x, y)] = tile_id def _open_tile_volume(self, tile_id: int) -> Any: """Returns a ZYX-shaped ndarray-like object representing the tile data.""" @@ -106,7 +109,7 @@ def _collect_tile_boxes(self, tile_shape_zyx: ZYX): ) for i in range(StitchAndRender3dTiles._tile_meshes.shape[1]): - tx, ty = StitchAndRender3dTiles._mesh_index_to_xy[i] + tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] mesh = StitchAndRender3dTiles._tile_meshes[:, i, ...] tg_box = map_utils.outer_box(mesh, map_box, self._stride) @@ -138,9 +141,9 @@ def _get_dts(self, shape: ZYX, tx: int, ty: int) -> np.ndarray: mask = np.zeros(shape[1:], dtype=bool) if self._margin > 0: x0 = self._margin if tx > 0 else 0 - x1 = -self._margin if tx < self._tile_layout.shape[-1] - 1 else -1 + x1 = -self._margin if tx < self._tile_map.shape[-1] - 1 else -1 y0 = self._margin if ty > 0 else 0 - y1 = -self._margin if ty < self._tile_layout.shape[-2] - 1 else -1 + y1 = -self._margin if ty < self._tile_map.shape[-2] - 1 else -1 mask[y0:y1, x0:x1] = 1 else: mask[...] = 1 @@ -175,7 +178,7 @@ def _load_tile_images( logging.info('Processing source %r (%r)', i, out_box) coord_map = StitchAndRender3dTiles._tile_meshes[:, i, ...] - tx, ty = StitchAndRender3dTiles._mesh_index_to_xy[i] + tx, ty = StitchAndRender3dTiles._tile_idx_to_xy[i] if i not in StitchAndRender3dTiles._inverted_meshes: # Add context to avoid rounding issues in map inversion. 
@@ -248,14 +251,30 @@ def process( box = subvol.bbox logging.info('Processing %r', box) + mesh_init = False + + if StitchAndRender3dTiles._tile_meshes is None: + data_path = self._tile_mesh_path + with file.Open(data_path, 'rb') as f: + data = np.load(f, allow_pickle=True) + StitchAndRender3dTiles._tile_idx_to_xy = { + v: k for k, v in data['key_to_idx'].item().items() + } + StitchAndRender3dTiles._tile_meshes = data['x'] + assert StitchAndRender3dTiles._tile_meshes.shape[1] == len( + StitchAndRender3dTiles._tile_idx_to_xy + ) + mesh_init = True + volstores = {} for i in range(StitchAndRender3dTiles._tile_meshes.shape[1]): - tile_id = self._xy_to_tile_id[StitchAndRender3dTiles._mesh_index_to_xy[i]] + tile_id = self._key_to_idx[StitchAndRender3dTiles._tile_idx_to_xy[i]] volstores[i] = self._open_tile_volume(tile_id) # Bounding boxes representing a single tile placed the origin. tile_shape_zyx = next(iter(volstores.values())).shape - self._collect_tile_boxes(tile_shape_zyx) + if mesh_init: + self._collect_tile_boxes(tile_shape_zyx) # For blending, accumulate (weighted) image data as floats. This will # be normalized and cast to the desired output type once the image is @@ -313,4 +332,4 @@ def process( ret[norm > 0] /= norm[norm > 0] ret = ret.astype(self.output_type(subvol.data.dtype)) - return self.crop_box_and_data(box, ret[None, ...]) + return self.crop_box_and_data(box, ret[None, ...]) \ No newline at end of file diff --git a/stitch_elastic.py b/stitch_elastic.py index 3881139..a653e66 100644 --- a/stitch_elastic.py +++ b/stitch_elastic.py @@ -134,8 +134,8 @@ def compute_flow_map3d( curr_box = bounding_box.BoundingBox(start=(0, 0, 0), size=tile_shape) nbor_box = bounding_box.BoundingBox( start=( - offset[0], - offset[1], + tile_shape[0] * (1 - axis) + offset[0], + tile_shape[1] * axis + offset[1], offset[2], ), size=tile_shape, @@ -673,4 +673,4 @@ def compute_target_mesh( if dim == 2: return updated[:, : x.shape[-2], : x.shape[-1]] else: - return updated[:, : x.shape[-3], : x.shape[-2], : x.shape[-1]] + return updated[:, : x.shape[-3], : x.shape[-2], : x.shape[-1]] \ No newline at end of file diff --git a/zarr/__init__.py b/zarr/__init__.py new file mode 100644 index 0000000..7766812 --- /dev/null +++ b/zarr/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/zarr_utils.py b/zarr/zarr_io.py similarity index 72% rename from zarr_utils.py rename to zarr/zarr_io.py index ca6c2fc..6f3331c 100644 --- a/zarr_utils.py +++ b/zarr/zarr_io.py @@ -3,6 +3,7 @@ import numpy as np import tensorstore as ts +from sofima import stitch_elastic class CloudStorage(Enum): """ @@ -16,15 +17,27 @@ class CloudStorage(Enum): class ZarrDataset: """ Parameters for locating Zarr dataset living on the cloud. + Args: + cloud_storage: CloudStorage option + bucket: Name of bucket + dataset_path: Path to directory containing zarr files within bucket + tile_names: List of zarr tiles to include in dataset. 
+ Order of tile_names defines an index that + is expected to be used in tile_layout. + tile_layout: 2D array of indices that defines relative position of tiles. + downsample_exp: Level in image pyramid with each level + downsampling the original resolution by 2**downsmaple_exp. """ + cloud_storage: CloudStorage bucket: str dataset_path: str tile_names: list[str] + tile_layout: np.ndarray downsample_exp: int -def open_zarr_gcs(bucket: str, path: str): +def open_zarr_gcs(bucket: str, path: str) -> ts.TensorStore: return ts.open({ 'driver': 'zarr', 'kvstore': { @@ -35,7 +48,7 @@ def open_zarr_gcs(bucket: str, path: str): }).result() -def open_zarr_s3(bucket: str, path: str): +def open_zarr_s3(bucket: str, path: str) -> ts.TensorStore: return ts.open({ 'driver': 'zarr', 'kvstore': { @@ -46,11 +59,12 @@ def open_zarr_s3(bucket: str, path: str): def load_zarr_data(params: ZarrDataset - ) -> tuple[list[ts.TensorStore], tuple[int, int, int]]: + ) -> tuple[list[ts.TensorStore], stitch_elastic.ShapeXYZ]: """ Reads Zarr dataset from input location and returns list of equally-sized tensorstores in matching order as ZarrDataset.tile_names and tile size. + Tensorstores are cropped to tiles at origin to the smallest tile in the set. """ def load_zarr(bucket: str, tile_location: str) -> ts.TensorStore: @@ -80,8 +94,10 @@ def load_zarr(bucket: str, tile_location: str) -> ts.TensorStore: def write_zarr(bucket: str, shape: list, path: str): """ - Shape must be 5D vector in tczyx order: - Ex: [1, 1, 3551, 576, 576] + Args: + bucket: Name of gcs cloud storage bucket + shape: 5D vector in tczyx order, ex: [1, 1, 3551, 576, 576] + path: Output path inside bucket """ return ts.open({ diff --git a/zarr_processor.py b/zarr/zarr_register_and_fuse_3d.py similarity index 75% rename from zarr_processor.py rename to zarr/zarr_register_and_fuse_3d.py index e11407a..785a0b2 100644 --- a/zarr_processor.py +++ b/zarr/zarr_register_and_fuse_3d.py @@ -10,9 +10,10 @@ from connectomics.common import bounding_box from connectomics.common import box_generator from connectomics.volume import subvolume -from sofima import coarse_registration, flow_utils, stitch_elastic, stitch_rigid, map_utils, mesh, zarr_utils +from sofima import coarse_registration, flow_utils, stitch_elastic, stitch_rigid, map_utils, mesh from sofima.processor import warp +import zarr_io # NOTE: # - SOFIMA/ZarrStitcher follows following basis convention: @@ -32,7 +33,6 @@ def __init__(self, tstore): self.tstore = tstore def __getitem__(self, ind): - print(ind) return np.array(self.tstore[ind]) def __getattr__(self, attr): @@ -49,33 +49,29 @@ def ndim(self): class ZarrFusion(warp.StitchAndRender3dTiles): """ - Fusion renderer subclass - that implements data loading for Zarr datasets. + Fusion renderer loading tile data from Zarr. 
""" cache = {} def __init__(self, - zarr_params: zarr_utils.ZarrDataset, - tile_layout: np.ndarray, - fine_tile_mesh: np.ndarray, - fine_mesh_xy_to_index: dict[tuple[int, int], int], + zarr_params: zarr_io.ZarrDataset, + tile_mesh_path: str, stride_zyx: tuple[int, int, int], offset_xyz: tuple[float, float, float], parallelism=16) -> None: - super().__init__(tile_layout, - fine_tile_mesh, - fine_mesh_xy_to_index, - stride_zyx, - offset_xyz, - parallelism) + super().__init__(zarr_params.tile_layout, + tile_mesh_path, + "", + stride_zyx, + offset_xyz, + parallelism) self.zarr_params = zarr_params - def _open_tile_volume(self, tile_id: int): if tile_id in self.cache: return self.cache[tile_id] - tile_volumes, tile_size_xyz = zarr_utils.load_zarr_data(self.zarr_params) + tile_volumes, tile_size_xyz = zarr_io.load_zarr_data(self.zarr_params) tile = tile_volumes[tile_id] self.cache[tile_id] = SyncAdapter(tile[0,0,:,:,:]) return self.cache[tile_id] @@ -87,19 +83,16 @@ class ZarrStitcher: """ def __init__(self, - input_zarr: zarr_utils.ZarrDataset, - tile_layout: np.ndarray) -> None: + input_zarr: zarr_io.ZarrDataset) -> None: """ - zarr_params: See ZarrDataset, params for input dataset - tile_layout: 2D array of tile ids defining relative tile placement. - Tile ids correspond to indices of ZarrDataset.tile_names. + zarr_params: See ZarrDataset, params for input dataset """ self.input_zarr = input_zarr self.tile_volumes: list[ts.TensorStore] = [] # 5D tczyx homogenous shape - self.tile_volumes, self.tile_size_xyz = zarr_utils.load_zarr_data(input_zarr) - self.tile_layout = tile_layout + self.tile_volumes, self.tile_size_xyz = zarr_io.load_zarr_data(input_zarr) + self.tile_layout = input_zarr.tile_layout self.tile_map: dict[tuple[int, int], ts.TensorStore] = {} for y, row in enumerate(tile_layout): @@ -135,8 +128,9 @@ def run_fine_registration(self, cx: np.ndarray, cy: np.ndarray, coarse_mesh: np.ndarray, - stride_zyx: tuple[int, int, int] - ) -> tuple[np.ndarray, dict[tuple[int, int], int]]: + stride_zyx: tuple[int, int, int], + save_mesh_path: str = "solved_meshes.npy" + ) -> None: """ Runs fine registration. Inputs: @@ -145,7 +139,7 @@ def run_fine_registration(self, coarse_mesh: Coarse offsets in combined array, output of coarse registration. stride_zyx: Subdivision of each tile to create fine mesh. - Outputs: + Outputs (inside of output mesh path): solved_fine_mesh: Fine mesh containing offsets of each subdivision. Shape is (3, tile_index, stride_z, stride_y, stride_x). fine_mesh_xy_to_index: Map of tile coordinates to custom mesh tile index. @@ -156,14 +150,19 @@ def run_fine_registration(self, _tile_map = {} for key, tstore in self.tile_map.items(): _tile_map[key] = SyncAdapter(tstore[0,:,:,:,:]) + + # INPUT FORMATTING: + # For axis 0, subtract tile_size x from the offset[0] + # For axis 1, subtract tile_size y from the offset[1] + # Tile size is readded inside of stitch_elastic.compute_flow_map3d. 
+ cx[:, 0, :, :] = cx[:, 0, :, :] - np.array([self.tile_size_xyz[0], 0, 0]) + cy[:, 0, :, :] = cy[:, 0, :, :] - np.array([0, self.tile_size_xyz[1], 0]) - # Compute flow map flow_x, offsets_x = stitch_elastic.compute_flow_map3d(_tile_map, self.tile_size_xyz, cx, axis=0, stride=stride_zyx, patch_size=(80, 80, 80)) - flow_y, offsets_y = stitch_elastic.compute_flow_map3d(_tile_map, self.tile_size_xyz, cy, axis=1, @@ -202,19 +201,21 @@ def prev_fn(x): solved_fine_mesh, ekin, t = mesh.relax_mesh(fine_mesh, None, config, prev_fn=prev_fn, mesh_force=mesh.elastic_mesh_3d) - return solved_fine_mesh, fine_mesh_xy_to_index, stride_zyx - - - def _run_fusion(self, - output_cloud_storage: zarr_utils.CloudStorage, + # Save the mesh/mesh index map + np.savez_compressed(save_mesh_path, + x=solved_fine_mesh, + key_to_idx=fine_mesh_xy_to_index, + stride_zyx=stride_zyx) + + + def run_fusion(self, + output_cloud_storage: zarr_io.CloudStorage, output_bucket: str, output_path: str, downsample_exp: int, cx: np.ndarray, cy: np.ndarray, - fine_mesh: np.ndarray, - fine_mesh_xy_to_index: dict[tuple[int, int], int], - stride_zyx: tuple[int, int, int], + tile_mesh_path: str, parallelism: int = 16 ) -> None: """ @@ -231,7 +232,12 @@ def _run_fusion(self, Multithreading. """ - if output_cloud_storage == zarr_utils.CloudStorage.S3: + data = np.load(tile_mesh_path) + fine_mesh = data['x'] + fine_mesh_xy_to_index = data['key_to_idx'] + stride_zyx = data['stride_zyx'] + + if output_cloud_storage == zarr_io.CloudStorage.S3: raise NotImplementedError( 'TensorStore does not support s3 writes.' ) @@ -242,7 +248,7 @@ def _run_fusion(self, fusion_tile_size_zyx = self.tile_size_xyz[::-1] if downsample_exp != self.input_zarr.downsample_exp: # Reload the data at target resolution - fusion_zarr = zarr_utils.ZarrDataset(self.input_zarr.cloud_storage, + fusion_zarr = zarr_io.ZarrDataset(self.input_zarr.cloud_storage, self.input_zarr.bucket, self.input_zarr.dataset_path, self.input_zarr.tile_names, @@ -302,12 +308,19 @@ def _run_fusion(self, fused_shape_5d = [1, 1, int(fused_z), int(fused_y), int(fused_x)] print(f'{fused_shape_5d=}') + # INPUT FORMATTING: + # Save rescaled mesh back into .npz volume + # as this is the expected input of warp.StitchAndRender3dTiles.process + rescaled_mesh_path = 'rescaled_fusion_mesh.npz' + np.savez_compressed(rescaled_mesh_path, + x=fusion_mesh, + key_to_idx=fine_mesh_xy_to_index, + stride_zyx=fusion_stride_zyx) + # Perform fusion - ds_out = zarr_utils.write_zarr(output_bucket, fused_shape_5d, output_path) + ds_out = zarr_io.write_zarr(output_bucket, fused_shape_5d, output_path) renderer = ZarrFusion(zarr_params=fusion_zarr, - tile_layout=self.tile_layout, - fine_tile_mesh=fusion_mesh, - fine_mesh_xy_to_index=fine_mesh_xy_to_index, + tile_mesh_path=rescaled_mesh_path, stride_zyx=fusion_stride_zyx, offset_xyz=crop_offset, parallelism=parallelism) @@ -334,25 +347,27 @@ def _run_fusion(self, print('box {i}: {t1:0.2f} render {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render)) + # TODO fix this too def run_fusion_on_coarse_mesh(self, - output_cloud_storage: zarr_utils.CloudStorage, + output_cloud_storage: zarr_io.CloudStorage, output_bucket: str, output_path: str, downsample_exp: int, cx: np.ndarray, cy: np.ndarray, coarse_mesh: np.ndarray, - stride_zyx: tuple[int, int, int] = (20, 20, 20), + stride_zyx: tuple[int, int, int] = (20, 20, 20), + save_mesh_path: str = "solved_meshes.npy", parallelism: int = 16) -> None: """ Transforms coarse mesh into fine mesh before passing along to 
ZarrStitcher._run_fusion(...) """ - # Fine Mesh Tile Index + # Create Fine Mesh Tile Index fine_mesh_xy_to_index = {(tx, ty): i for i, (tx, ty) in enumerate(self.tile_map.keys())} - # Fine Mesh + # Convert Coarse Mesh into Fine Mesh dim = len(stride_zyx) mesh_shape = (np.array(self.tile_size_xyz[::-1]) // stride_zyx).tolist() fine_mesh = np.zeros([dim, len(fine_mesh_xy_to_index)] + mesh_shape, dtype=np.float32) @@ -360,50 +375,25 @@ def run_fusion_on_coarse_mesh(self, fine_mesh[:, fine_mesh_xy_to_index[tx, ty], ...] = coarse_mesh[:, 0, ty, tx].reshape( (dim,) + (1,) * dim) - self._run_fusion(output_cloud_storage, + # Save the mesh/mesh index map + np.savez_compressed(save_mesh_path, + x=fine_mesh, + key_to_idx=fine_mesh_xy_to_index, + stride_zyx=stride_zyx) + + self.run_fusion(output_cloud_storage, output_bucket, output_path, downsample_exp, cx, cy, - fine_mesh, - fine_mesh_xy_to_index, - stride_zyx, - parallelism) - - - def run_fusion_on_fine_mesh(self, - output_cloud_storage: zarr_utils.CloudStorage, - output_bucket: str, - output_path: str, - downsample_exp: int, - cx: np.ndarray, - cy: np.ndarray, - fine_mesh: np.ndarray, - fine_mesh_xy_to_index: dict[tuple[int, int], int], - stride_zyx: tuple[int, int, int], - parallelism: int = 16 - ) -> None: - """ - Simply passes all input parameters to - private method ZarrStitcher._run_fusion(...) - """ - - self._run_fusion(output_cloud_storage, - output_bucket, - output_path, - downsample_exp, - cx, - cy, - fine_mesh, - fine_mesh_xy_to_index, - stride_zyx, + save_mesh_path, parallelism) if __name__ == '__main__': # Example set of Application Inputs - cloud_storage = zarr_utils.CloudStorage.S3 + cloud_storage = zarr_io.CloudStorage.S3 bucket = 'aind-open-data' dataset_path = 'diSPIM_647459_2022-12-07_00-00-00/diSPIM.zarr' downsample_exp = 2 @@ -411,30 +401,31 @@ def run_fusion_on_fine_mesh(self, 'tile_X_0001_Y_0000_Z_0000_CH_0405_cam1.zarr'] tile_layout = np.array([[1], [0]]) - input_zarr = zarr_utils.ZarrDataset(cloud_storage=cloud_storage, + input_zarr = zarr_io.ZarrDataset(cloud_storage=cloud_storage, bucket=bucket, dataset_path=dataset_path, tile_names=tile_names, + tile_layout=tile_layout, downsample_exp=downsample_exp) # Application Outputs - output_cloud_storage = zarr_utils.CloudStorage.GCS + output_cloud_storage = zarr_io.CloudStorage.GCS output_bucket = 'YOUR-BUCKET-HERE' - output_path = 'YOU-OUTPUT-NAME.zarr' + output_path = 'YOUR-OUTPUT-NAME.zarr' # Processing - zarr_stitcher = ZarrStitcher(input_zarr, tile_layout) + save_mesh_path = 'solved_mesh.npy' + zarr_stitcher = ZarrStitcher(input_zarr) cx, cy, coarse_mesh = zarr_stitcher.run_coarse_registration() - fine_mesh, fine_mesh_xy_to_index, stride_zyx = zarr_stitcher.run_fine_registration(cx, - cy, - coarse_mesh, - stride_zyx=(20, 20, 20)) - zarr_stitcher._run_fusion(output_cloud_storage=output_cloud_storage, + zarr_stitcher.run_fine_registration(cx, + cy, + coarse_mesh, + stride_zyx=(20, 20, 20), + save_mesh_path=save_mesh_path) + zarr_stitcher.run_fusion(output_cloud_storage=output_cloud_storage, output_bucket=output_bucket, output_path=output_path, downsample_exp=0, # For full resolution fusion. 
cx=cx, cy=cy, - fine_mesh=fine_mesh, - fine_mesh_xy_to_index=fine_mesh_xy_to_index, - stride_zyx=stride_zyx) \ No newline at end of file + tile_mesh_path=save_mesh_path) \ No newline at end of file From 9fe84c74afd686b2b3df618a915ddae2d91e91da Mon Sep 17 00:00:00 2001 From: jwong-nd Date: Thu, 6 Jul 2023 18:47:33 +0000 Subject: [PATCH 3/7] env setup, temporary files --- .gitignore | 2 ++ sofima_env.sh | 5 +++ tmp.py | 7 +++++ zarr/zarr_register_and_fuse_3d.py | 52 ++++++++++++++++++++++--------- 4 files changed, 51 insertions(+), 15 deletions(-) create mode 100755 sofima_env.sh create mode 100644 tmp.py diff --git a/.gitignore b/.gitignore index aa58525..f9baae0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ __pycache__ dist sofima.egg-info +_version.py +*.npz \ No newline at end of file diff --git a/sofima_env.sh b/sofima_env.sh new file mode 100755 index 0000000..8fc0d7f --- /dev/null +++ b/sofima_env.sh @@ -0,0 +1,5 @@ +#!/bin/bash +conda create --name py311 -c conda-forge python=3.11 -y +conda run -n py311 pip install git+https://github.com/google-research/sofima +conda run -n py311 pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +conda run -n py311 pip install tensorstore \ No newline at end of file diff --git a/tmp.py b/tmp.py new file mode 100644 index 0000000..fccf1ba --- /dev/null +++ b/tmp.py @@ -0,0 +1,7 @@ +from sofima.zarr import zarr_io + +output_bucket = 'sofima-test-bucket' +fused_shape_5d = [1, 1, 3543, 867, 576] +output_path = 'tmp.zarr' + +ds_out = zarr_io.write_zarr(output_bucket, fused_shape_5d, output_path) \ No newline at end of file diff --git a/zarr/zarr_register_and_fuse_3d.py b/zarr/zarr_register_and_fuse_3d.py index 785a0b2..0ab7609 100644 --- a/zarr/zarr_register_and_fuse_3d.py +++ b/zarr/zarr_register_and_fuse_3d.py @@ -33,6 +33,7 @@ def __init__(self, tstore): self.tstore = tstore def __getitem__(self, ind): + print(ind) # FIXME: remove later return np.array(self.tstore[ind]) def __getattr__(self, attr): @@ -99,7 +100,6 @@ def __init__(self, for x, tile_id in enumerate(row): self.tile_map[(x, y)] = self.tile_volumes[tile_id] - def run_coarse_registration(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Runs coarse registration. @@ -155,8 +155,20 @@ def run_fine_registration(self, # For axis 0, subtract tile_size x from the offset[0] # For axis 1, subtract tile_size y from the offset[1] # Tile size is readded inside of stitch_elastic.compute_flow_map3d. - cx[:, 0, :, :] = cx[:, 0, :, :] - np.array([self.tile_size_xyz[0], 0, 0]) - cy[:, 0, :, :] = cy[:, 0, :, :] - np.array([0, self.tile_size_xyz[1], 0]) + cx[np.isnan(cx)] = 0 + cy[np.isnan(cy)] = 0 + # cx[:, :, 1:, 1:] = cx[:, :, 1:, 1:] - np.array([self.tile_size_xyz[0], 0, 0]) # cx[b, ...] is in xyz order + # cy[:, :, 1:, 1:] = cy[:, :, 1:, 1:] - np.array([0, self.tile_size_xyz[1], 0]) + + for y in range(0, cx.shape[-2:][0]): + for x in range(0, cx.shape[-2:][1] - 1): + cx[0, 0, y, x] = cx[0, 0, y, x] - self.tile_size_xyz[0] + cx[1, 0, y, x] = cx[1, 0, y, x] + + for y in range(0, cy.shape[-2:][0] - 1): + for x in range(0, cy.shape[-2:][1]): + cy[0, 0, y, x] = cy[0, 0, y, x] + cy[1, 0, y, x] = cy[1, 0, y, x] - self.tile_size_xyz[1] flow_x, offsets_x = stitch_elastic.compute_flow_map3d(_tile_map, self.tile_size_xyz, @@ -232,10 +244,10 @@ def run_fusion(self, Multithreading. 
""" - data = np.load(tile_mesh_path) + data = np.load(tile_mesh_path, allow_pickle=True) fine_mesh = data['x'] - fine_mesh_xy_to_index = data['key_to_idx'] - stride_zyx = data['stride_zyx'] + fine_mesh_xy_to_index = data['key_to_idx'].item() # extract the dictionary + stride_zyx = tuple(data['stride_zyx']) if output_cloud_storage == zarr_io.CloudStorage.S3: raise NotImplementedError( @@ -410,22 +422,32 @@ def run_fusion_on_coarse_mesh(self, # Application Outputs output_cloud_storage = zarr_io.CloudStorage.GCS - output_bucket = 'YOUR-BUCKET-HERE' - output_path = 'YOUR-OUTPUT-NAME.zarr' + # output_bucket = 'YOUR-BUCKET-HERE' + # output_path = 'YOUR-OUTPUT-NAME.zarr' + output_bucket = 'sofima-test-bucket-2' + # output_path = 'fused_level_2_refactor.zarr' + output_path = 'tmp.zarr' + + # What test runs do I need? + # Low res 2 defintely -- main path + # Low res 1 -- main path + # Low res 2 off path + # Processing - save_mesh_path = 'solved_mesh.npy' + save_mesh_path = 'solved_mesh_refactor.npz' zarr_stitcher = ZarrStitcher(input_zarr) cx, cy, coarse_mesh = zarr_stitcher.run_coarse_registration() - zarr_stitcher.run_fine_registration(cx, - cy, - coarse_mesh, - stride_zyx=(20, 20, 20), - save_mesh_path=save_mesh_path) + # zarr_stitcher.run_fine_registration(cx, + # cy, + # coarse_mesh, + # stride_zyx=(20, 20, 20), + # save_mesh_path=save_mesh_path) + zarr_stitcher.run_fusion(output_cloud_storage=output_cloud_storage, output_bucket=output_bucket, output_path=output_path, - downsample_exp=0, # For full resolution fusion. + downsample_exp=2, # For full resolution fusion. cx=cx, cy=cy, tile_mesh_path=save_mesh_path) \ No newline at end of file From 09fad0ba379920b5ea7704f09d6c645d4ebfc266 Mon Sep 17 00:00:00 2001 From: jwong-nd Date: Fri, 7 Jul 2023 18:01:42 +0000 Subject: [PATCH 4/7] tested changes, removed tmp comments/files --- tmp.py | 7 --- zarr/zarr_register_and_fuse_3d.py | 84 ++++++++++++++++++------------- 2 files changed, 50 insertions(+), 41 deletions(-) delete mode 100644 tmp.py diff --git a/tmp.py b/tmp.py deleted file mode 100644 index fccf1ba..0000000 --- a/tmp.py +++ /dev/null @@ -1,7 +0,0 @@ -from sofima.zarr import zarr_io - -output_bucket = 'sofima-test-bucket' -fused_shape_5d = [1, 1, 3543, 867, 576] -output_path = 'tmp.zarr' - -ds_out = zarr_io.write_zarr(output_bucket, fused_shape_5d, output_path) \ No newline at end of file diff --git a/zarr/zarr_register_and_fuse_3d.py b/zarr/zarr_register_and_fuse_3d.py index 0ab7609..716935a 100644 --- a/zarr/zarr_register_and_fuse_3d.py +++ b/zarr/zarr_register_and_fuse_3d.py @@ -33,7 +33,6 @@ def __init__(self, tstore): self.tstore = tstore def __getitem__(self, ind): - print(ind) # FIXME: remove later return np.array(self.tstore[ind]) def __getattr__(self, attr): @@ -60,12 +59,12 @@ def __init__(self, stride_zyx: tuple[int, int, int], offset_xyz: tuple[float, float, float], parallelism=16) -> None: - super().__init__(zarr_params.tile_layout, - tile_mesh_path, - "", - stride_zyx, - offset_xyz, - parallelism) + super().__init__(tile_map=zarr_params.tile_layout, + tile_mesh_path=tile_mesh_path, + tile_pattern_path="", + stride=stride_zyx, + offset=offset_xyz, + parallelism=parallelism) self.zarr_params = zarr_params def _open_tile_volume(self, tile_id: int): @@ -236,10 +235,11 @@ def run_fusion(self, output_cloud_storage, output_bucket, output_path: Output storage parameters downsample_exp: - Desired output resolution, 0 for highest resolution. 
- fine_mesh, fine_mesh_xy_to_index, stride_zyx: - Fine mesh offsets and accompanying metadata, - output of coarse/fine registration. + Desired output resolution level, 0 for highest resolution. + cx, cy: + Output of coarse registration + tile_mesh_path: + Output of elastic registration parallelism: Multithreading. """ @@ -264,6 +264,7 @@ def run_fusion(self, self.input_zarr.bucket, self.input_zarr.dataset_path, self.input_zarr.tile_names, + self.input_zarr.tile_layout, downsample_exp) # Rescale fine mesh, stride @@ -359,7 +360,6 @@ def run_fusion(self, print('box {i}: {t1:0.2f} render {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render)) - # TODO fix this too def run_fusion_on_coarse_mesh(self, output_cloud_storage: zarr_io.CloudStorage, output_bucket: str, @@ -373,17 +373,31 @@ def run_fusion_on_coarse_mesh(self, parallelism: int = 16) -> None: """ Transforms coarse mesh into fine mesh before - passing along to ZarrStitcher._run_fusion(...) + passing along to ZarrStitcher.run_fusion(...) + + Inputs: + output_cloud_storage, output_bucket, output_path: + Output storage parameters + downsample_exp: + Desired output resolution level, 0 for highest resolution. + cx, cy, coarse_mesh: + Output of coarse registration + stride_zyx: + Grid size of elastic/fine mesh + save_mesh_path: + Output path to save elastic mesh. + parallelism: + Fusion multithreading. """ # Create Fine Mesh Tile Index - fine_mesh_xy_to_index = {(tx, ty): i for i, (tx, ty) in enumerate(self.tile_map.keys())} + fine_mesh_xy_to_index = {(tx, ty): i for i, (tx, ty) in enumerate(list(self.tile_map.keys()))} # Convert Coarse Mesh into Fine Mesh dim = len(stride_zyx) mesh_shape = (np.array(self.tile_size_xyz[::-1]) // stride_zyx).tolist() fine_mesh = np.zeros([dim, len(fine_mesh_xy_to_index)] + mesh_shape, dtype=np.float32) - for (tx, ty) in self.tile_map.keys(): + for (tx, ty) in list(self.tile_map.keys()): fine_mesh[:, fine_mesh_xy_to_index[tx, ty], ...] = coarse_mesh[:, 0, ty, tx].reshape( (dim,) + (1,) * dim) @@ -422,32 +436,34 @@ def run_fusion_on_coarse_mesh(self, # Application Outputs output_cloud_storage = zarr_io.CloudStorage.GCS - # output_bucket = 'YOUR-BUCKET-HERE' - # output_path = 'YOUR-OUTPUT-NAME.zarr' - output_bucket = 'sofima-test-bucket-2' - # output_path = 'fused_level_2_refactor.zarr' - output_path = 'tmp.zarr' - - # What test runs do I need? - # Low res 2 defintely -- main path - # Low res 1 -- main path - # Low res 2 off path - + output_bucket = 'YOUR-BUCKET-HERE' + output_path = 'YOUR-OUTPUT-NAME.zarr' + # output_bucket = 'sofima-test-bucket' + # output_path = 'fused_level_2_refactor_skip.zarr' # Processing save_mesh_path = 'solved_mesh_refactor.npz' zarr_stitcher = ZarrStitcher(input_zarr) cx, cy, coarse_mesh = zarr_stitcher.run_coarse_registration() - # zarr_stitcher.run_fine_registration(cx, - # cy, - # coarse_mesh, - # stride_zyx=(20, 20, 20), - # save_mesh_path=save_mesh_path) - + zarr_stitcher.run_fine_registration(cx, + cy, + coarse_mesh, + stride_zyx=(20, 20, 20), + save_mesh_path=save_mesh_path) zarr_stitcher.run_fusion(output_cloud_storage=output_cloud_storage, output_bucket=output_bucket, output_path=output_path, - downsample_exp=2, # For full resolution fusion. + downsample_exp=1, # 0 for full resolution fusion. 
cx=cx, cy=cy, - tile_mesh_path=save_mesh_path) \ No newline at end of file + tile_mesh_path=save_mesh_path) + + # zarr_stitcher.run_fusion_on_coarse_mesh(output_cloud_storage=output_cloud_storage, + # output_bucket=output_bucket, + # output_path=output_path, + # downsample_exp=2, + # cx=cx, + # cy=cy, + # coarse_mesh=coarse_mesh, + # stride_zyx=(20, 20, 20), + # save_mesh_path='solved_mesh_refactor_skip.npz') \ No newline at end of file From cabcaf0410c08584476c0da22617335bf9bc1f95 Mon Sep 17 00:00:00 2001 From: jwong-nd Date: Fri, 7 Jul 2023 18:06:02 +0000 Subject: [PATCH 5/7] add zarr package to setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index ddf3920..f4927ee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ classifiers = [options] package_dir = sofima = . -packages = sofima, sofima.processor +packages = sofima, sofima.processor, sofima.zarr python_requires = >=3.9 install_requires = connectomics From 4c03e41eb5301a4c273990e6411ec846dcf4d5c8 Mon Sep 17 00:00:00 2001 From: jwong-nd Date: Wed, 12 Jul 2023 23:04:47 +0000 Subject: [PATCH 6/7] remove relative import --- zarr/zarr_register_and_fuse_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zarr/zarr_register_and_fuse_3d.py b/zarr/zarr_register_and_fuse_3d.py index 716935a..4158ae7 100644 --- a/zarr/zarr_register_and_fuse_3d.py +++ b/zarr/zarr_register_and_fuse_3d.py @@ -13,7 +13,7 @@ from sofima import coarse_registration, flow_utils, stitch_elastic, stitch_rigid, map_utils, mesh from sofima.processor import warp -import zarr_io +from sofima.zarr import zarr_io # NOTE: # - SOFIMA/ZarrStitcher follows following basis convention: From 96f7eda39d3bc4393ef926c86f2772186f3f4195 Mon Sep 17 00:00:00 2001 From: jwong-nd Date: Wed, 12 Jul 2023 23:44:11 +0000 Subject: [PATCH 7/7] resolve field error --- zarr/zarr_register_and_fuse_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zarr/zarr_register_and_fuse_3d.py b/zarr/zarr_register_and_fuse_3d.py index 4158ae7..033b200 100644 --- a/zarr/zarr_register_and_fuse_3d.py +++ b/zarr/zarr_register_and_fuse_3d.py @@ -95,7 +95,7 @@ def __init__(self, self.tile_layout = input_zarr.tile_layout self.tile_map: dict[tuple[int, int], ts.TensorStore] = {} - for y, row in enumerate(tile_layout): + for y, row in enumerate(self.tile_layout): for x, tile_id in enumerate(row): self.tile_map[(x, y)] = self.tile_volumes[tile_id]
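A minimal sketch of reloading the solved-mesh archive written by run_fine_registration (file name taken from the example __main__ above; the keys mirror the loading code in run_fusion):

import numpy as np

data = np.load('solved_mesh_refactor.npz', allow_pickle=True)
solved_fine_mesh = data['x']                        # offsets, shape (3, num_tiles, z, y, x)
fine_mesh_xy_to_index = data['key_to_idx'].item()   # {(tx, ty): mesh tile index}
stride_zyx = tuple(data['stride_zyx'])              # fine-mesh subdivision in zyx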