diff --git a/dash_3d_viewer/slicer.py b/dash_3d_viewer/slicer.py index d367ed4..1938e02 100644 --- a/dash_3d_viewer/slicer.py +++ b/dash_3d_viewer/slicer.py @@ -1,5 +1,5 @@ import numpy as np -from plotly.graph_objects import Figure +from plotly.graph_objects import Figure, Image from dash import Dash from dash.dependencies import Input, Output, State from dash_core_components import Graph, Slider, Store @@ -29,44 +29,39 @@ def __init__(self, app, volume, axis=0, id=None): self._id = id # Get the slice size (width, height), and max index - arr_shape = list(volume.shape) - arr_shape.pop(self._axis) - slice_size = list(reversed(arr_shape)) + # arr_shape = list(volume.shape) + # arr_shape.pop(self._axis) + # slice_size = list(reversed(arr_shape)) self._max_index = self._volume.shape[self._axis] - 1 + # Prep low-res slices + thumbnails = [ + img_array_to_uri(self._slice(i), (32, 32)) + for i in range(self._max_index + 1) + ] + + # Create a placeholder trace + # todo: can add "%{z[0]}", but that would be the scaled value ... + trace = Image(source="", hovertemplate="(%{x}, %{y})") # Create the figure object - fig = Figure() + fig = Figure(data=[trace]) fig.update_layout( template=None, margin=dict(l=0, r=0, b=0, t=0, pad=4), ) fig.update_xaxes( + # range=(0, slice_size[0]), showgrid=False, - range=(0, slice_size[0]), showticklabels=False, zeroline=False, ) fig.update_yaxes( + # range=(slice_size[1], 0), # todo: allow flipping x or y showgrid=False, scaleanchor="x", - range=(slice_size[1], 0), # todo: allow flipping x or y showticklabels=False, zeroline=False, ) - # Add an empty layout image that we can populate from JS. - fig.add_layout_image( - dict( - source="", - xref="x", - yref="y", - x=0, - y=0, - sizex=slice_size[0], - sizey=slice_size[1], - sizing="contain", - layer="below", - ) - ) # Wrap the figure in a graph # todo: or should the user provide this? self.graph = Graph( @@ -88,6 +83,7 @@ def __init__(self, app, volume, axis=0, id=None): Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2), Store(id=self._subid("_requested-slice-index"), data=0), Store(id=self._subid("_slice-data"), data=""), + Store(id=self._subid("_slice-data-lowres"), data=thumbnails), ] self._create_server_callbacks(app) @@ -101,7 +97,8 @@ def _slice(self, index): """Sample a slice from the volume.""" indices = [slice(None), slice(None), slice(None)] indices[self._axis] = index - return self._volume[tuple(indices)] + im = self._volume[tuple(indices)] + return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8) def _create_server_callbacks(self, app): """Create the callbacks that run server-side.""" @@ -112,7 +109,6 @@ def _create_server_callbacks(self, app): ) def upload_requested_slice(slice_index): slice = self._slice(slice_index) - slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8) return [slice_index, img_array_to_uri(slice)] def _create_client_callbacks(self, app): @@ -158,7 +154,7 @@ def _create_client_callbacks(self, app): app.clientside_callback( """ - function handle_incoming_slice(index, index_and_data, ori_figure) { + function handle_incoming_slice(index, index_and_data, ori_figure, lowres) { let new_index = index_and_data[0]; let new_data = index_and_data[1]; // Store data in cache @@ -167,17 +163,18 @@ def _create_client_callbacks(self, app): slice_cache[new_index] = new_data; // Get the data we need *now* let data = slice_cache[index]; + //slice_cache[new_index] = undefined; // todo: disabled cache for now! // Maybe we do not need an update if (!data) { - return window.dash_clientside.no_update; + data = lowres[index]; } - if (data == ori_figure.layout.images[0].source) { + if (data == ori_figure.data[0].source) { return window.dash_clientside.no_update; } // Otherwise, perform update console.log("updating figure"); let figure = {...ori_figure}; - figure.layout.images[0].source = data; + figure.data[0].source = data; return figure; } """.replace( @@ -188,5 +185,8 @@ def _create_client_callbacks(self, app): Input(self._subid("slice-index"), "data"), Input(self._subid("_slice-data"), "data"), ], - [State(self._subid("graph"), "figure")], + [ + State(self._subid("graph"), "figure"), + State(self._subid("_slice-data-lowres"), "data"), + ], ) diff --git a/dash_3d_viewer/utils.py b/dash_3d_viewer/utils.py index 61846e1..68ab52c 100644 --- a/dash_3d_viewer/utils.py +++ b/dash_3d_viewer/utils.py @@ -1,19 +1,25 @@ +import io import random +import base64 import PIL.Image import skimage -from plotly.utils import ImageUriValidator def gen_random_id(n=6): return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n)) -def img_array_to_uri(img_array): +def img_array_to_uri(img_array, new_size=None): img_array = skimage.util.img_as_ubyte(img_array) # todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency) # from plotly.express._imshow import _array_to_b64str # return _array_to_b64str(img_array) img_pil = PIL.Image.fromarray(img_array) - uri = ImageUriValidator.pil_image_to_uri(img_pil) - return uri + if new_size: + img_pil.thumbnail(new_size) + # The below was taken from plotly.utils.ImageUriValidator.pil_image_to_uri() + f = io.BytesIO() + img_pil.save(f, format="PNG") + base64_str = base64.b64encode(f.getvalue()).decode() + return "data:image/png;base64," + base64_str