diff --git a/.gitignore b/.gitignore index aa58525..f9baae0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ __pycache__ dist sofima.egg-info +_version.py +*.npz \ No newline at end of file diff --git a/coarse_registration.py b/coarse_registration.py new file mode 100644 index 0000000..5a015bf --- /dev/null +++ b/coarse_registration.py @@ -0,0 +1,120 @@ +import functools as ft +import gc +import jax +import numpy as np +import tensorstore as ts + +from sofima import flow_field + + +QUERY_R_ORTHO = 100 +QUERY_OVERLAP_OFFSET = 0 # Overlap = 'starting line' in neighboring tile +QUERY_R_OVERLAP = 100 + +SEARCH_OVERLAP = 300 # Boundary - overlap = 'starting line' in search tile +SEARCH_R_ORTHO = 100 + + +@ft.partial(jax.jit) +def _estimate_relative_offset_zyx(base, + kernel + ) -> list[float, float, float]: + # Calculate FFT: left = base, right = kernel + xc = flow_field.masked_xcorr(base, kernel, use_jax=True, dim=3) + xc = xc.astype(np.float32) + xc = xc[None, ...] + + # Find strongest peak in FFT, pass in FFT image center + r = flow_field._batched_peaks(xc, + ((xc.shape[1] + 1) // 2, (xc.shape[2] + 1) // 2, xc.shape[3] // 2), + min_distance=2, + threshold_rel=0.5) + + # r returns a list, relative offset is here + relative_offset_xyz = r[0][0:3] + return [relative_offset_xyz[2], relative_offset_xyz[1], relative_offset_xyz[0]] + + +def _estimate_h_offset_zyx(left_tile: ts.TensorStore, + right_tile: ts.TensorStore + ) -> tuple[list[float], float]: + tile_size_xyz = left_tile.shape + mz = tile_size_xyz[2] // 2 + my = tile_size_xyz[1] // 2 + + # Search Space, fixed + left = left_tile[tile_size_xyz[0]-SEARCH_OVERLAP:, + my-SEARCH_R_ORTHO:my+SEARCH_R_ORTHO, + mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T + + # Query Patch, scanned against search space + right = right_tile[QUERY_OVERLAP_OFFSET:QUERY_OVERLAP_OFFSET + QUERY_R_OVERLAP*2, + my-QUERY_R_ORTHO:my+QUERY_R_ORTHO, + mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T + + start_zyx = np.array(left.shape) // 2 - np.array(right.shape) // 2 + pc_init_zyx = np.array([0, 0, tile_size_xyz[0] - SEARCH_OVERLAP + start_zyx[2]]) + pc_zyx = np.array(_estimate_relative_offset_zyx(left, right)) + + return pc_init_zyx + pc_zyx + + +def _estimate_v_offset_zyx(top_tile: ts.TensorStore, + bot_tile: ts.TensorStore, + ) -> tuple[list[float], float]: + tile_size_xyz = top_tile.shape + mz = tile_size_xyz[2] // 2 + mx = tile_size_xyz[0] // 2 + + top = top_tile[mx-SEARCH_R_ORTHO:mx+SEARCH_R_ORTHO, + tile_size_xyz[1]-SEARCH_OVERLAP:, + mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T + bot = bot_tile[mx-QUERY_R_ORTHO:mx+QUERY_R_ORTHO, + 0:QUERY_R_OVERLAP*2, + mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T + + start_zyx = np.array(top.shape) // 2 - np.array(bot.shape) // 2 + pc_init_zyx = np.array([0, tile_size_xyz[1] - SEARCH_OVERLAP + start_zyx[1], 0]) + pc_zyx = np.array(_estimate_relative_offset_zyx(top, bot)) + + return pc_init_zyx + pc_zyx + + +def compute_coarse_offsets(tile_layout: np.ndarray, + tile_volumes: list[ts.TensorStore] + ) -> tuple[np.ndarray, np.ndarray]: + layout_y, layout_x = tile_layout.shape + + # Output Containers, sofima uses cartesian convention + conn_x = np.full((3, 1, layout_y, layout_x), np.nan) + conn_y = np.full((3, 1, layout_y, layout_x), np.nan) + + # Row Pairs + for y in range(layout_y): + for x in range(layout_x - 1): # Stop one before the end + left_id = tile_layout[y, x] + right_id = tile_layout[y, x + 1] + left_tile = tile_volumes[left_id] + right_tile = tile_volumes[right_id] + + conn_x[:, 0, y, x] = _estimate_h_offset_zyx(left_tile, right_tile) + gc.collect() + + print(f'Left Id: {left_id}, Right Id: {right_id}') + print(f'Left: ({y}, {x}), Right: ({y}, {x + 1})', conn_x[:, 0, y, x]) + + # Column Pairs -- Reversed Loops + for x in range(layout_x): + for y in range(layout_y - 1): + top_id = tile_layout[y, x] + bot_id = tile_layout[y + 1, x] + top_tile = tile_volumes[top_id] + bot_tile = tile_volumes[bot_id] + + conn_y[:, 0, y, x] = _estimate_v_offset_zyx(top_tile, bot_tile) + gc.collect() + + print(f'Top Id: {top_id}, Bottom Id: {bot_id}') + print(f'Top: ({y}, {x}), Bot: ({y + 1}, {x})', conn_y[:, 0, y, x]) + + return conn_x, conn_y \ No newline at end of file diff --git a/processor/warp.py b/processor/warp.py index 7d2fab1..a1833d3 100644 --- a/processor/warp.py +++ b/processor/warp.py @@ -209,8 +209,8 @@ def _load_tile_images( local_rel_box = sub_box.translate(-out_box.start) local_warp_box = local_rel_box.translate(local_out_box.start) - # Part of the inverted mesh that is needed to render the current - # region of interest. + # Part of the inverted mesh that is needed to render + # the current region of interest. s = 1.0 / np.array(self._stride)[::-1] local_map_box = local_warp_box.scale(s).adjusted_by( start=(-2, -2, -2), end=(2, 2, 2) @@ -332,4 +332,4 @@ def process( ret[norm > 0] /= norm[norm > 0] ret = ret.astype(self.output_type(subvol.data.dtype)) - return self.crop_box_and_data(box, ret[None, ...]) + return self.crop_box_and_data(box, ret[None, ...]) \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index f76369c..f4927ee 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,7 +16,7 @@ classifiers = [options] package_dir = sofima = . -packages = sofima, sofima.processor +packages = sofima, sofima.processor, sofima.zarr python_requires = >=3.9 install_requires = connectomics @@ -28,6 +28,7 @@ install_requires = opencv-python>=4.5.5.62 scipy>=1.2.3 scikit-image>=0.17.2 + tensorstore>=0.1.39 [options.packages.find] where = . diff --git a/sofima_env.sh b/sofima_env.sh new file mode 100755 index 0000000..8fc0d7f --- /dev/null +++ b/sofima_env.sh @@ -0,0 +1,5 @@ +#!/bin/bash +conda create --name py311 -c conda-forge python=3.11 -y +conda run -n py311 pip install git+https://github.com/google-research/sofima +conda run -n py311 pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +conda run -n py311 pip install tensorstore \ No newline at end of file diff --git a/stitch_elastic.py b/stitch_elastic.py index 977f55a..a653e66 100644 --- a/stitch_elastic.py +++ b/stitch_elastic.py @@ -673,4 +673,4 @@ def compute_target_mesh( if dim == 2: return updated[:, : x.shape[-2], : x.shape[-1]] else: - return updated[:, : x.shape[-3], : x.shape[-2], : x.shape[-1]] + return updated[:, : x.shape[-3], : x.shape[-2], : x.shape[-1]] \ No newline at end of file diff --git a/zarr/__init__.py b/zarr/__init__.py new file mode 100644 index 0000000..7766812 --- /dev/null +++ b/zarr/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/zarr/zarr_io.py b/zarr/zarr_io.py new file mode 100644 index 0000000..6f3331c --- /dev/null +++ b/zarr/zarr_io.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass +from enum import Enum +import numpy as np +import tensorstore as ts + +from sofima import stitch_elastic + +class CloudStorage(Enum): + """ + Documented Cloud Storage Options + """ + S3 = 1 + GCS = 2 + + +@dataclass +class ZarrDataset: + """ + Parameters for locating Zarr dataset living on the cloud. + Args: + cloud_storage: CloudStorage option + bucket: Name of bucket + dataset_path: Path to directory containing zarr files within bucket + tile_names: List of zarr tiles to include in dataset. + Order of tile_names defines an index that + is expected to be used in tile_layout. + tile_layout: 2D array of indices that defines relative position of tiles. + downsample_exp: Level in image pyramid with each level + downsampling the original resolution by 2**downsmaple_exp. + """ + + cloud_storage: CloudStorage + bucket: str + dataset_path: str + tile_names: list[str] + tile_layout: np.ndarray + downsample_exp: int + + +def open_zarr_gcs(bucket: str, path: str) -> ts.TensorStore: + return ts.open({ + 'driver': 'zarr', + 'kvstore': { + 'driver': 'gcs', + 'bucket': bucket, + }, + 'path': path, + }).result() + + +def open_zarr_s3(bucket: str, path: str) -> ts.TensorStore: + return ts.open({ + 'driver': 'zarr', + 'kvstore': { + 'driver': 'http', + 'base_url': f'https://{bucket}.s3.us-west-2.amazonaws.com/{path}', + }, + }).result() + + +def load_zarr_data(params: ZarrDataset + ) -> tuple[list[ts.TensorStore], stitch_elastic.ShapeXYZ]: + """ + Reads Zarr dataset from input location + and returns list of equally-sized tensorstores + in matching order as ZarrDataset.tile_names and tile size. + Tensorstores are cropped to tiles at origin to the smallest tile in the set. + """ + + def load_zarr(bucket: str, tile_location: str) -> ts.TensorStore: + if params.cloud_storage == CloudStorage.S3: + return open_zarr_s3(bucket, tile_location) + else: # cloud == 'gcs' + return open_zarr_gcs(bucket, tile_location) + tile_volumes = [] + min_x, min_y, min_z = np.inf, np.inf, np.inf + for t_name in params.tile_names: + tile_location = f"{params.dataset_path}/{t_name}/{params.downsample_exp}" + tile = load_zarr(params.bucket, tile_location) + tile_volumes.append(tile) + + _, _, tz, ty, tx = tile.shape + min_x, min_y, min_z = int(np.minimum(min_x, tx)), \ + int(np.minimum(min_y, ty)), \ + int(np.minimum(min_z, tz)) + tile_size_xyz = min_x, min_y, min_z + + # Standardize size of tile volumes + for i, tile_vol in enumerate(tile_volumes): + tile_volumes[i] = tile_vol[:, :, :min_z, :min_y, :min_x] + + return tile_volumes, tile_size_xyz + + +def write_zarr(bucket: str, shape: list, path: str): + """ + Args: + bucket: Name of gcs cloud storage bucket + shape: 5D vector in tczyx order, ex: [1, 1, 3551, 576, 576] + path: Output path inside bucket + """ + + return ts.open({ + 'driver': 'zarr', + 'dtype': 'uint16', + 'kvstore' : { + 'driver': 'gcs', + 'bucket': bucket, + }, + 'create': True, + 'delete_existing': True, + 'path': path, + 'metadata': { + 'chunks': [1, 1, 128, 256, 256], + 'compressor': { + 'blocksize': 0, + 'clevel': 1, + 'cname': 'zstd', + 'id': 'blosc', + 'shuffle': 1, + }, + 'dimension_separator': '/', + 'dtype': ' None: + super().__init__(tile_map=zarr_params.tile_layout, + tile_mesh_path=tile_mesh_path, + tile_pattern_path="", + stride=stride_zyx, + offset=offset_xyz, + parallelism=parallelism) + self.zarr_params = zarr_params + + def _open_tile_volume(self, tile_id: int): + if tile_id in self.cache: + return self.cache[tile_id] + + tile_volumes, tile_size_xyz = zarr_io.load_zarr_data(self.zarr_params) + tile = tile_volumes[tile_id] + self.cache[tile_id] = SyncAdapter(tile[0,0,:,:,:]) + return self.cache[tile_id] + + +class ZarrStitcher: + """ + Object wrapper around SOFIMA for operating on Zarr datasets. + """ + + def __init__(self, + input_zarr: zarr_io.ZarrDataset) -> None: + """ + zarr_params: See ZarrDataset, params for input dataset + """ + + self.input_zarr = input_zarr + + self.tile_volumes: list[ts.TensorStore] = [] # 5D tczyx homogenous shape + self.tile_volumes, self.tile_size_xyz = zarr_io.load_zarr_data(input_zarr) + self.tile_layout = input_zarr.tile_layout + + self.tile_map: dict[tuple[int, int], ts.TensorStore] = {} + for y, row in enumerate(self.tile_layout): + for x, tile_id in enumerate(row): + self.tile_map[(x, y)] = self.tile_volumes[tile_id] + + def run_coarse_registration(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Runs coarse registration. + Returns: + cx: tile_layout shape + Each entry represents displacement of current tile towards right neighbor. + cy: tile_layout shape + Each entry represents displacement of current tile towards bottom neighbor. + coarse_mesh: (3, 1, tile_layout) shape + Each entry net displacement of current tile. + """ + + # Custom data loading for coarse registration + _tile_volumes: list[ts.TensorStore] = [] + for vol in self.tile_volumes: + _tile_volumes.append(vol.T[:,:,:,0,0]) + + cx, cy = coarse_registration.compute_coarse_offsets(self.tile_layout, _tile_volumes) + coarse_mesh = stitch_rigid.optimize_coarse_mesh(cx, + cy, + mesh_fn=stitch_rigid.elastic_tile_mesh_3d) + return cx, cy, coarse_mesh + + + def run_fine_registration(self, + cx: np.ndarray, + cy: np.ndarray, + coarse_mesh: np.ndarray, + stride_zyx: tuple[int, int, int], + save_mesh_path: str = "solved_meshes.npy" + ) -> None: + """ + Runs fine registration. + Inputs: + cx: Coarse offsets in x direction, output of coarse registration. + cy: Coarse offsets in y direction, output of coarse registration. + coarse_mesh: Coarse offsets in combined array, output of coarse registration. + stride_zyx: Subdivision of each tile to create fine mesh. + + Outputs (inside of output mesh path): + solved_fine_mesh: Fine mesh containing offsets of each subdivision. + Shape is (3, tile_index, stride_z, stride_y, stride_x). + fine_mesh_xy_to_index: Map of tile coordinates to custom mesh tile index. + stride_zyx: Same as input, by returned as important parameter. + """ + + # Custom data loading for fine registration + _tile_map = {} + for key, tstore in self.tile_map.items(): + _tile_map[key] = SyncAdapter(tstore[0,:,:,:,:]) + + # INPUT FORMATTING: + # For axis 0, subtract tile_size x from the offset[0] + # For axis 1, subtract tile_size y from the offset[1] + # Tile size is readded inside of stitch_elastic.compute_flow_map3d. + cx[np.isnan(cx)] = 0 + cy[np.isnan(cy)] = 0 + # cx[:, :, 1:, 1:] = cx[:, :, 1:, 1:] - np.array([self.tile_size_xyz[0], 0, 0]) # cx[b, ...] is in xyz order + # cy[:, :, 1:, 1:] = cy[:, :, 1:, 1:] - np.array([0, self.tile_size_xyz[1], 0]) + + for y in range(0, cx.shape[-2:][0]): + for x in range(0, cx.shape[-2:][1] - 1): + cx[0, 0, y, x] = cx[0, 0, y, x] - self.tile_size_xyz[0] + cx[1, 0, y, x] = cx[1, 0, y, x] + + for y in range(0, cy.shape[-2:][0] - 1): + for x in range(0, cy.shape[-2:][1]): + cy[0, 0, y, x] = cy[0, 0, y, x] + cy[1, 0, y, x] = cy[1, 0, y, x] - self.tile_size_xyz[1] + + flow_x, offsets_x = stitch_elastic.compute_flow_map3d(_tile_map, + self.tile_size_xyz, + cx, axis=0, + stride=stride_zyx, + patch_size=(80, 80, 80)) + flow_y, offsets_y = stitch_elastic.compute_flow_map3d(_tile_map, + self.tile_size_xyz, + cy, axis=1, + stride=stride_zyx, + patch_size=(80, 80, 80)) + + # Filter patch flows + kwargs = {"min_peak_ratio": 1.4, "min_peak_sharpness": 1.4, + "max_deviation": 5, "max_magnitude": 0, "dim": 3} + fine_x = {k: flow_utils.clean_flow(v, **kwargs) for k, v in flow_x.items()} + fine_y = {k: flow_utils.clean_flow(v, **kwargs) for k, v in flow_y.items()} + + kwargs = {"min_patch_size": 10, "max_gradient": -1, "max_deviation": -1} + fine_x = {k: flow_utils.reconcile_flows([v], **kwargs) for k, v in fine_x.items()} + fine_y = {k: flow_utils.reconcile_flows([v], **kwargs) for k, v in fine_y.items()} + + # Update mesh (convert coarse tile mesh into fine patch mesh) + data_x = (cx[:, 0, ...], fine_x, offsets_x) + data_y = (cy[:, 0, ...], fine_y, offsets_y) + fx, fy, fine_mesh, nbors, fine_mesh_xy_to_index = stitch_elastic.aggregate_arrays( + data_x, data_y, list(self.tile_map.keys()), + coarse_mesh[:, 0, ...], stride=stride_zyx, tile_shape=self.tile_size_xyz[::-1]) + + @jax.jit + def prev_fn(x): + target_fn = ft.partial(stitch_elastic.compute_target_mesh, + x=x, fx=fx, fy=fy, stride=stride_zyx) + x = jax.vmap(target_fn)(nbors) + return jnp.transpose(x, [1, 0, 2, 3, 4]) + + config = mesh.IntegrationConfig(dt=0.001, gamma=0., k0=0.01, k=0.1, stride=stride_zyx, + num_iters=1000, max_iters=20000, stop_v_max=0.001, + dt_max=100, prefer_orig_order=False, + start_cap=0.1, final_cap=10., remove_drift=True) + + solved_fine_mesh, ekin, t = mesh.relax_mesh(fine_mesh, None, config, + prev_fn=prev_fn, mesh_force=mesh.elastic_mesh_3d) + + # Save the mesh/mesh index map + np.savez_compressed(save_mesh_path, + x=solved_fine_mesh, + key_to_idx=fine_mesh_xy_to_index, + stride_zyx=stride_zyx) + + + def run_fusion(self, + output_cloud_storage: zarr_io.CloudStorage, + output_bucket: str, + output_path: str, + downsample_exp: int, + cx: np.ndarray, + cy: np.ndarray, + tile_mesh_path: str, + parallelism: int = 16 + ) -> None: + """ + Runs fusion. + Inputs: + output_cloud_storage, output_bucket, output_path: + Output storage parameters + downsample_exp: + Desired output resolution level, 0 for highest resolution. + cx, cy: + Output of coarse registration + tile_mesh_path: + Output of elastic registration + parallelism: + Multithreading. + """ + + data = np.load(tile_mesh_path, allow_pickle=True) + fine_mesh = data['x'] + fine_mesh_xy_to_index = data['key_to_idx'].item() # extract the dictionary + stride_zyx = tuple(data['stride_zyx']) + + if output_cloud_storage == zarr_io.CloudStorage.S3: + raise NotImplementedError( + 'TensorStore does not support s3 writes.' + ) + + fusion_zarr = self.input_zarr + fusion_mesh = fine_mesh + fusion_stride_zyx = stride_zyx + fusion_tile_size_zyx = self.tile_size_xyz[::-1] + if downsample_exp != self.input_zarr.downsample_exp: + # Reload the data at target resolution + fusion_zarr = zarr_io.ZarrDataset(self.input_zarr.cloud_storage, + self.input_zarr.bucket, + self.input_zarr.dataset_path, + self.input_zarr.tile_names, + self.input_zarr.tile_layout, + downsample_exp) + + # Rescale fine mesh, stride + curr_exp = self.input_zarr.downsample_exp + target_exp = downsample_exp + scale_factor = 2**(curr_exp - target_exp) + fusion_mesh = fine_mesh * scale_factor + fusion_stride_zyx = tuple(np.array(stride_zyx) * scale_factor) + fusion_tile_size_zyx = tuple(np.array(self.tile_size_xyz)[::-1] * scale_factor) + print(f'{scale_factor=}') + + start = np.array([np.inf, np.inf, np.inf]) + map_box = bounding_box.BoundingBox( + start=(0, 0, 0), + size=fusion_mesh.shape[2:][::-1], + ) + fine_mesh_index_to_xy = { + v: k for k, v in fine_mesh_xy_to_index.items() + } + for i in range(0, fusion_mesh.shape[1]): + tx, ty = fine_mesh_index_to_xy[i] + mesh = fusion_mesh[:, i, ...] + tg_box = map_utils.outer_box(mesh, map_box, fusion_stride_zyx) + + out_box = bounding_box.BoundingBox( + start=( + tg_box.start[0] * fusion_stride_zyx[2] + tx * fusion_tile_size_zyx[2], + tg_box.start[1] * fusion_stride_zyx[1] + ty * fusion_tile_size_zyx[1], + tg_box.start[2] * fusion_stride_zyx[0], + ), + size=( + tg_box.size[0] * fusion_stride_zyx[2], + tg_box.size[1] * fusion_stride_zyx[1], + tg_box.size[2] * fusion_stride_zyx[0], + ) + ) + start = np.minimum(start, out_box.start) + print(f'{tg_box=}') + print(f'{out_box=}') + + crop_offset = -start + print(f'{crop_offset=}') + + # Fused shape + cx[np.isnan(cx)] = 0 + cy[np.isnan(cy)] = 0 + x_overlap = cx[2,0,0,0] / self.tile_size_xyz[0] + y_overlap = cy[1,0,0,0] / self.tile_size_xyz[1] + y_shape, x_shape = cx.shape[2], cx.shape[3] + + fused_x = fusion_tile_size_zyx[2] * (1 + ((x_shape - 1) * (1 - x_overlap))) + fused_y = fusion_tile_size_zyx[1] * (1 + ((y_shape - 1) * (1 - y_overlap))) + fused_z = fusion_tile_size_zyx[0] + fused_shape_5d = [1, 1, int(fused_z), int(fused_y), int(fused_x)] + print(f'{fused_shape_5d=}') + + # INPUT FORMATTING: + # Save rescaled mesh back into .npz volume + # as this is the expected input of warp.StitchAndRender3dTiles.process + rescaled_mesh_path = 'rescaled_fusion_mesh.npz' + np.savez_compressed(rescaled_mesh_path, + x=fusion_mesh, + key_to_idx=fine_mesh_xy_to_index, + stride_zyx=fusion_stride_zyx) + + # Perform fusion + ds_out = zarr_io.write_zarr(output_bucket, fused_shape_5d, output_path) + renderer = ZarrFusion(zarr_params=fusion_zarr, + tile_mesh_path=rescaled_mesh_path, + stride_zyx=fusion_stride_zyx, + offset_xyz=crop_offset, + parallelism=parallelism) + + box = bounding_box.BoundingBox(start=(0,0,0), size=ds_out.shape[4:1:-1]) # Needs xyz + gen = box_generator.BoxGenerator(box, (512, 512, 512), (0, 0, 0), True) # These are xyz + renderer.set_effective_subvol_and_overlap((512, 512, 512), (0, 0, 0)) + for i, sub_box in enumerate(gen.boxes): + t_start = time.time() + + # Feed in an empty subvol, with dimensions of sub_box. + inp_subvol = subvolume.Subvolume(np.zeros(sub_box.size[::-1], dtype=np.uint16)[None, ...], sub_box) + ret_subvol = renderer.process(inp_subvol) # czyx + + t_render = time.time() + + # ret_subvol is a 4D CZYX volume + slice = ret_subvol.bbox.to_slice3d() + slice = (0, 0, slice[0], slice[1], slice[2]) + ds_out[slice].write(ret_subvol.data[0, ...]).result() + + t_write = time.time() + + print('box {i}: {t1:0.2f} render {t2:0.2f} write'.format(i=i, t1=t_render - t_start, t2=t_write - t_render)) + + + def run_fusion_on_coarse_mesh(self, + output_cloud_storage: zarr_io.CloudStorage, + output_bucket: str, + output_path: str, + downsample_exp: int, + cx: np.ndarray, + cy: np.ndarray, + coarse_mesh: np.ndarray, + stride_zyx: tuple[int, int, int] = (20, 20, 20), + save_mesh_path: str = "solved_meshes.npy", + parallelism: int = 16) -> None: + """ + Transforms coarse mesh into fine mesh before + passing along to ZarrStitcher.run_fusion(...) + + Inputs: + output_cloud_storage, output_bucket, output_path: + Output storage parameters + downsample_exp: + Desired output resolution level, 0 for highest resolution. + cx, cy, coarse_mesh: + Output of coarse registration + stride_zyx: + Grid size of elastic/fine mesh + save_mesh_path: + Output path to save elastic mesh. + parallelism: + Fusion multithreading. + """ + + # Create Fine Mesh Tile Index + fine_mesh_xy_to_index = {(tx, ty): i for i, (tx, ty) in enumerate(list(self.tile_map.keys()))} + + # Convert Coarse Mesh into Fine Mesh + dim = len(stride_zyx) + mesh_shape = (np.array(self.tile_size_xyz[::-1]) // stride_zyx).tolist() + fine_mesh = np.zeros([dim, len(fine_mesh_xy_to_index)] + mesh_shape, dtype=np.float32) + for (tx, ty) in list(self.tile_map.keys()): + fine_mesh[:, fine_mesh_xy_to_index[tx, ty], ...] = coarse_mesh[:, 0, ty, tx].reshape( + (dim,) + (1,) * dim) + + # Save the mesh/mesh index map + np.savez_compressed(save_mesh_path, + x=fine_mesh, + key_to_idx=fine_mesh_xy_to_index, + stride_zyx=stride_zyx) + + self.run_fusion(output_cloud_storage, + output_bucket, + output_path, + downsample_exp, + cx, + cy, + save_mesh_path, + parallelism) + + +if __name__ == '__main__': + # Example set of Application Inputs + cloud_storage = zarr_io.CloudStorage.S3 + bucket = 'aind-open-data' + dataset_path = 'diSPIM_647459_2022-12-07_00-00-00/diSPIM.zarr' + downsample_exp = 2 + tile_names = ['tile_X_0000_Y_0000_Z_0000_CH_0405_cam1.zarr', + 'tile_X_0001_Y_0000_Z_0000_CH_0405_cam1.zarr'] + tile_layout = np.array([[1], + [0]]) + input_zarr = zarr_io.ZarrDataset(cloud_storage=cloud_storage, + bucket=bucket, + dataset_path=dataset_path, + tile_names=tile_names, + tile_layout=tile_layout, + downsample_exp=downsample_exp) + + # Application Outputs + output_cloud_storage = zarr_io.CloudStorage.GCS + output_bucket = 'YOUR-BUCKET-HERE' + output_path = 'YOUR-OUTPUT-NAME.zarr' + # output_bucket = 'sofima-test-bucket' + # output_path = 'fused_level_2_refactor_skip.zarr' + + # Processing + save_mesh_path = 'solved_mesh_refactor.npz' + zarr_stitcher = ZarrStitcher(input_zarr) + cx, cy, coarse_mesh = zarr_stitcher.run_coarse_registration() + zarr_stitcher.run_fine_registration(cx, + cy, + coarse_mesh, + stride_zyx=(20, 20, 20), + save_mesh_path=save_mesh_path) + zarr_stitcher.run_fusion(output_cloud_storage=output_cloud_storage, + output_bucket=output_bucket, + output_path=output_path, + downsample_exp=1, # 0 for full resolution fusion. + cx=cx, + cy=cy, + tile_mesh_path=save_mesh_path) + + # zarr_stitcher.run_fusion_on_coarse_mesh(output_cloud_storage=output_cloud_storage, + # output_bucket=output_bucket, + # output_path=output_path, + # downsample_exp=2, + # cx=cx, + # cy=cy, + # coarse_mesh=coarse_mesh, + # stride_zyx=(20, 20, 20), + # save_mesh_path='solved_mesh_refactor_skip.npz') \ No newline at end of file