Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for Zarr stitching/fusion #41

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__pycache__
dist
sofima.egg-info
_version.py
*.npz
120 changes: 120 additions & 0 deletions coarse_registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import functools as ft
import gc
import jax
import numpy as np
import tensorstore as ts

from sofima import flow_field


# Patch geometry (in voxels) for coarse pairwise registration: a query patch
# cut from the moving tile is scanned against a fixed search region cut from
# the neighboring tile.  "R" values are half-widths around the tile midline.
QUERY_R_ORTHO = 100  # Half-width of the query patch along the two orthogonal axes.
QUERY_OVERLAP_OFFSET = 0  # Overlap = 'starting line' in neighboring tile
QUERY_R_OVERLAP = 100  # Half-width of the query patch along the overlap axis.

SEARCH_OVERLAP = 300  # Boundary - overlap = 'starting line' in search tile
SEARCH_R_ORTHO = 100  # Half-width of the search region along the orthogonal axes.


@jax.jit  # Plain jax.jit; ft.partial(jax.jit) was an unnecessary wrapper.
def _estimate_relative_offset_zyx(base,
                                  kernel
                                  ) -> list[float]:
    """Estimate the relative (z, y, x) offset of `kernel` within `base`.

    Runs a masked cross-correlation of the two volumes and takes the
    strongest correlation peak as the displacement estimate.

    Args:
      base: 3D search volume (the fixed side of the correlation).
      kernel: 3D query volume scanned against `base`.

    Returns:
      [z, y, x] offset of the strongest correlation peak relative to the
      correlation image center.
    """
    # Masked cross-correlation via FFT: left = base, right = kernel.
    xc = flow_field.masked_xcorr(base, kernel, use_jax=True, dim=3)
    xc = xc.astype(np.float32)
    xc = xc[None, ...]  # _batched_peaks expects a leading batch dimension.

    # Find the strongest peak, passing in the correlation image center.
    r = flow_field._batched_peaks(
        xc,
        ((xc.shape[1] + 1) // 2, (xc.shape[2] + 1) // 2, xc.shape[3] // 2),
        min_distance=2,
        threshold_rel=0.5)

    # r[0] holds the peak record; its first three entries are the offset in
    # xyz order — reverse to zyx for the caller.
    relative_offset_xyz = r[0][0:3]
    return [relative_offset_xyz[2], relative_offset_xyz[1], relative_offset_xyz[0]]


def _estimate_h_offset_zyx(left_tile: ts.TensorStore,
                           right_tile: ts.TensorStore
                           ) -> np.ndarray:
    """Estimate the coarse (z, y, x) offset between horizontally adjacent tiles.

    Args:
      left_tile: Tile on the left of the pair (provides the fixed search space).
      right_tile: Tile on the right (provides the query patch).

    Returns:
      3-element (z, y, x) ndarray: nominal overlap translation plus the
      cross-correlation refinement.  (The original annotation claimed
      tuple[list[float], float], which did not match the returned value.)
    """
    tile_size_xyz = left_tile.shape
    mz = tile_size_xyz[2] // 2
    my = tile_size_xyz[1] // 2

    # Search space, fixed: strip along the right edge of the left tile.
    left = left_tile[tile_size_xyz[0]-SEARCH_OVERLAP:,
                     my-SEARCH_R_ORTHO:my+SEARCH_R_ORTHO,
                     mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T

    # Query patch, scanned against the search space: strip along the left
    # edge of the right tile.
    right = right_tile[QUERY_OVERLAP_OFFSET:QUERY_OVERLAP_OFFSET + QUERY_R_OVERLAP*2,
                       my-QUERY_R_ORTHO:my+QUERY_R_ORTHO,
                       mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T

    # Initial guess: query centered inside the search space, shifted by the
    # nominal tile overlap along x (last axis after the transpose to zyx).
    start_zyx = np.array(left.shape) // 2 - np.array(right.shape) // 2
    pc_init_zyx = np.array([0, 0, tile_size_xyz[0] - SEARCH_OVERLAP + start_zyx[2]])
    pc_zyx = np.array(_estimate_relative_offset_zyx(left, right))

    return pc_init_zyx + pc_zyx


def _estimate_v_offset_zyx(top_tile: ts.TensorStore,
                           bot_tile: ts.TensorStore,
                           ) -> np.ndarray:
    """Estimate the coarse (z, y, x) offset between vertically adjacent tiles.

    Args:
      top_tile: Tile on top of the pair (provides the fixed search space).
      bot_tile: Tile below (provides the query patch).

    Returns:
      3-element (z, y, x) ndarray: nominal overlap translation plus the
      cross-correlation refinement.  (The original annotation claimed
      tuple[list[float], float], which did not match the returned value.)
    """
    tile_size_xyz = top_tile.shape
    mz = tile_size_xyz[2] // 2
    mx = tile_size_xyz[0] // 2

    # Search space, fixed: strip along the bottom edge of the top tile.
    top = top_tile[mx-SEARCH_R_ORTHO:mx+SEARCH_R_ORTHO,
                   tile_size_xyz[1]-SEARCH_OVERLAP:,
                   mz-SEARCH_R_ORTHO:mz+SEARCH_R_ORTHO].read().result().T
    # Query patch, scanned against the search space: strip along the top
    # edge of the bottom tile.
    bot = bot_tile[mx-QUERY_R_ORTHO:mx+QUERY_R_ORTHO,
                   0:QUERY_R_OVERLAP*2,
                   mz-QUERY_R_ORTHO:mz+QUERY_R_ORTHO].read().result().T

    # Initial guess: query centered inside the search space, shifted by the
    # nominal tile overlap along y.
    start_zyx = np.array(top.shape) // 2 - np.array(bot.shape) // 2
    pc_init_zyx = np.array([0, tile_size_xyz[1] - SEARCH_OVERLAP + start_zyx[1], 0])
    pc_zyx = np.array(_estimate_relative_offset_zyx(top, bot))

    return pc_init_zyx + pc_zyx


def compute_coarse_offsets(tile_layout: np.ndarray,
                           tile_volumes: list[ts.TensorStore]
                           ) -> tuple[np.ndarray, np.ndarray]:
    """Compute coarse offsets for every adjacent tile pair in the layout.

    Args:
      tile_layout: 2D array of tile ids giving each tile's grid position.
      tile_volumes: Tensorstores indexed by the ids in `tile_layout`.

    Returns:
      (conn_x, conn_y): arrays of shape (3, 1, layout_y, layout_x) holding
      the per-pair offsets (NaN where no neighbor exists); sofima uses the
      cartesian convention for these containers.
    """
    layout_y, layout_x = tile_layout.shape

    # Output containers, initialized to NaN for missing/edge pairs.
    conn_x = np.full((3, 1, layout_y, layout_x), np.nan)
    conn_y = np.full((3, 1, layout_y, layout_x), np.nan)

    # Horizontally adjacent pairs within each row.
    for y in range(layout_y):
        for x in range(layout_x - 1):
            left_id, right_id = tile_layout[y, x], tile_layout[y, x + 1]
            conn_x[:, 0, y, x] = _estimate_h_offset_zyx(
                tile_volumes[left_id], tile_volumes[right_id])
            gc.collect()  # Free per-pair buffers before the next read.

            print(f'Left Id: {left_id}, Right Id: {right_id}')
            print(f'Left: ({y}, {x}), Right: ({y}, {x + 1})', conn_x[:, 0, y, x])

    # Vertically adjacent pairs within each column (loop order reversed).
    for x in range(layout_x):
        for y in range(layout_y - 1):
            top_id, bot_id = tile_layout[y, x], tile_layout[y + 1, x]
            conn_y[:, 0, y, x] = _estimate_v_offset_zyx(
                tile_volumes[top_id], tile_volumes[bot_id])
            gc.collect()

            print(f'Top Id: {top_id}, Bottom Id: {bot_id}')
            print(f'Top: ({y}, {x}), Bot: ({y + 1}, {x})', conn_y[:, 0, y, x])

    return conn_x, conn_y
6 changes: 3 additions & 3 deletions processor/warp.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please completely revert the changes to this file?

Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def _load_tile_images(
local_rel_box = sub_box.translate(-out_box.start)
local_warp_box = local_rel_box.translate(local_out_box.start)

# Part of the inverted mesh that is needed to render the current
# region of interest.
# Part of the inverted mesh that is needed to render
# the current region of interest.
s = 1.0 / np.array(self._stride)[::-1]
local_map_box = local_warp_box.scale(s).adjusted_by(
start=(-2, -2, -2), end=(2, 2, 2)
Expand Down Expand Up @@ -332,4 +332,4 @@ def process(
ret[norm > 0] /= norm[norm > 0]
ret = ret.astype(self.output_type(subvol.data.dtype))

return self.crop_box_and_data(box, ret[None, ...])
return self.crop_box_and_data(box, ret[None, ...])
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ classifiers =
[options]
package_dir =
sofima = .
packages = sofima, sofima.processor
packages = sofima, sofima.processor, sofima.zarr
python_requires = >=3.9
install_requires =
connectomics
Expand All @@ -28,6 +28,7 @@ install_requires =
opencv-python>=4.5.5.62
scipy>=1.2.3
scikit-image>=0.17.2
tensorstore>=0.1.39

[options.packages.find]
where = .
5 changes: 5 additions & 0 deletions sofima_env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash
# Create a conda environment ("py311") with everything needed to run sofima.
conda create --name py311 -c conda-forge python=3.11 -y
# Install sofima from the main branch on GitHub.
conda run -n py311 pip install git+https://github.com/google-research/sofima
# CUDA 11 jax wheels come from the official jax release index, not PyPI.
conda run -n py311 pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Tensorstore is used for reading/writing Zarr volumes.
conda run -n py311 pip install tensorstore
2 changes: 1 addition & 1 deletion stitch_elastic.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please completely revert the changes to this file?

Original file line number Diff line number Diff line change
Expand Up @@ -673,4 +673,4 @@ def compute_target_mesh(
if dim == 2:
return updated[:, : x.shape[-2], : x.shape[-1]]
else:
return updated[:, : x.shape[-3], : x.shape[-2], : x.shape[-1]]
return updated[:, : x.shape[-3], : x.shape[-2], : x.shape[-1]]
14 changes: 14 additions & 0 deletions zarr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# coding=utf-8
# Copyright 2022 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
130 changes: 130 additions & 0 deletions zarr/zarr_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from dataclasses import dataclass
from enum import Enum
import numpy as np
import tensorstore as ts

from sofima import stitch_elastic

class CloudStorage(Enum):
    """Supported cloud storage backends for Zarr datasets."""

    S3 = 1
    GCS = 2


@dataclass
class ZarrDataset:
    """
    Parameters for locating a Zarr dataset living on the cloud.

    Args:
      cloud_storage: CloudStorage option (S3 or GCS).
      bucket: Name of the bucket.
      dataset_path: Path to the directory containing zarr files within the bucket.
      tile_names: List of zarr tiles to include in the dataset. The order of
        tile_names defines the index that is expected to be used in tile_layout.
      tile_layout: 2D array of indices that defines the relative position of tiles.
      downsample_exp: Level in the image pyramid, with each level downsampling
        the original resolution by 2**downsample_exp.
    """

    cloud_storage: CloudStorage
    bucket: str
    dataset_path: str
    tile_names: list[str]
    tile_layout: np.ndarray
    downsample_exp: int


def open_zarr_gcs(bucket: str, path: str) -> ts.TensorStore:
    """Open a Zarr array stored in a GCS bucket via tensorstore."""
    spec = {
        'driver': 'zarr',
        'kvstore': {
            'driver': 'gcs',
            'bucket': bucket,
        },
        'path': path,
    }
    return ts.open(spec).result()


def open_zarr_s3(bucket: str, path: str,
                 region: str = 'us-west-2') -> ts.TensorStore:
    """Open a Zarr array stored in a public S3 bucket via tensorstore.

    Args:
      bucket: Name of the S3 bucket.
      path: Path to the zarr array inside the bucket.
      region: AWS region of the bucket. Defaults to 'us-west-2', which was
        previously hard-coded; pass another region for buckets hosted elsewhere.

    Returns:
      An open ts.TensorStore backed by the HTTP kvstore driver.
    """
    return ts.open({
        'driver': 'zarr',
        'kvstore': {
            'driver': 'http',
            'base_url': f'https://{bucket}.s3.{region}.amazonaws.com/{path}',
        },
    }).result()


def load_zarr_data(params: ZarrDataset
                   ) -> tuple[list[ts.TensorStore], stitch_elastic.ShapeXYZ]:
    """
    Reads the Zarr dataset described by `params` and returns a list of
    equally-sized tensorstores, ordered like ZarrDataset.tile_names, plus
    the common (x, y, z) tile size. Each tensorstore is cropped at the
    origin to the dimensions of the smallest tile in the set.
    """

    def _open(bucket: str, tile_location: str) -> ts.TensorStore:
        # S3 and GCS require different tensorstore kvstore drivers.
        if params.cloud_storage == CloudStorage.S3:
            return open_zarr_s3(bucket, tile_location)
        return open_zarr_gcs(bucket, tile_location)

    tile_volumes: list[ts.TensorStore] = []
    sizes_zyx: list[tuple[int, int, int]] = []
    for t_name in params.tile_names:
        tile_location = f"{params.dataset_path}/{t_name}/{params.downsample_exp}"
        vol = _open(params.bucket, tile_location)
        tile_volumes.append(vol)
        # Tiles are 5D (t, c, z, y, x); record the spatial extents only.
        _, _, tz, ty, tx = vol.shape
        sizes_zyx.append((tz, ty, tx))

    min_z = min(int(s[0]) for s in sizes_zyx)
    min_y = min(int(s[1]) for s in sizes_zyx)
    min_x = min(int(s[2]) for s in sizes_zyx)
    tile_size_xyz = min_x, min_y, min_z

    # Standardize all volumes to the smallest tile's size.
    tile_volumes = [v[:, :, :min_z, :min_y, :min_x] for v in tile_volumes]

    return tile_volumes, tile_size_xyz


def write_zarr(bucket: str, shape: list, path: str,
               dtype: str = 'uint16',
               chunks: tuple = (1, 1, 128, 256, 256)):
    """Create (or overwrite) a Zarr array in a GCS bucket for writing.

    Args:
      bucket: Name of gcs cloud storage bucket.
      shape: 5D vector in tczyx order, ex: [1, 1, 3551, 576, 576].
      path: Output path inside bucket.
      dtype: Numpy-style dtype name for the array. Defaults to 'uint16',
        which was previously hard-coded.
      chunks: 5D chunk shape in tczyx order; defaults to the previously
        hard-coded (1, 1, 128, 256, 256).

    Returns:
      An open, writable ts.TensorStore (any existing array at `path` is
      deleted first via 'delete_existing').
    """

    return ts.open({
        'driver': 'zarr',
        'dtype': dtype,
        'kvstore': {
            'driver': 'gcs',
            'bucket': bucket,
        },
        'create': True,
        'delete_existing': True,
        'path': path,
        'metadata': {
            'chunks': list(chunks),
            'compressor': {
                'blocksize': 0,
                'clevel': 1,
                'cname': 'zstd',
                'id': 'blosc',
                'shuffle': 1,
            },
            'dimension_separator': '/',
            # Derive the zarr metadata dtype string (e.g. '<u2' for uint16)
            # from `dtype` so the two dtype fields can never disagree.
            'dtype': np.dtype(dtype).str,
            'fill_value': 0,
            'filters': None,
            'order': 'C',
            'shape': shape,
            'zarr_format': 2
        }
    }).result()
Loading