diff --git a/examples/linecollection_event.ipynb b/examples/linecollection_event.ipynb new file mode 100644 index 000000000..d5aaabacf --- /dev/null +++ b/examples/linecollection_event.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "0c32716f-320e-4021-ad60-1c142fe6fd56", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6d72f23c-0f3a-4b2c-806d-3b239237c725", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from fastplotlib.graphics import ImageGraphic, LineCollection\n", + "from fastplotlib import GridPlot\n", + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a8514bb3-eef5-4fd1-bcbf-9c50173a9a3c", + "metadata": {}, + "outputs": [], + "source": [ + "def auto_scale(p):\n", + " p.camera.maintain_aspect = False\n", + " width, height, depth = np.ptp(p.scene.get_world_bounding_box(), axis=0)\n", + " p.camera.width = width\n", + " p.camera.height = height\n", + "\n", + " p.controller.distance = 0\n", + " \n", + " p.controller.zoom(0.8 / p.controller.zoom_value)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "2fb13990-63fc-4fc6-b5c1-93f8ec4c1572", + "metadata": {}, + "outputs": [], + "source": [ + "contours = pickle.load(open(\"/home/kushal/caiman_data/contours.pickle\", \"rb\"))[0]\n", + "temporal = pickle.load(open(\"/home/kushal/caiman_data/temporal.pickle\", \"rb\"))\n", + "temporal += temporal.min()\n", + "\n", + "# make it a stack of traces\n", + "y_zero = 0\n", + "sep = 10\n", + "for i in range(1, temporal.shape[0]):\n", + " y_zero = temporal[i - 1].max()\n", + " temporal[i] += y_zero + sep\n", + "\n", + "# random colors\n", + "colors = np.random.rand(len(contours), 4).astype(np.float32)\n", + "colors[:, -1] = 1" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "654da73f-d20c-4a0f-bd99-13c1a52f5f5a", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "98ad6155b7c34241bd705d0f40bce8c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "RFBOutputContext()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
initial snapshot
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8906e4b8e78c465ca05f088f105de2fc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "JupyterWgpuCanvas()" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# img and contour plot\n", + "plot = GridPlot(shape=(1, 2))\n", + "\n", + "data = np.ones(shape=(175, 175))\n", + "\n", + "line_collection = LineCollection(data=contours, z_position=[[1]] * len(contours), colors=colors.tolist())\n", + "plot[0, 0].add_graphic(line_collection)\n", + "\n", + "img = ImageGraphic(data=data)\n", + "plot[0, 0].add_graphic(img)\n", + "\n", + "\n", + "temporal_coll = LineCollection(data=temporal, colors=colors.tolist())\n", + "plot[0, 1].add_graphic(temporal_coll)\n", + "\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "0836f4fc-fb3b-44c6-8515-ab8d63dff52b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "line_collection._world_object.parent" + ] + }, + { + "cell_type": "markdown", + "id": "ae5fe95b-88be-48c7-a4a6-51d2818fbff0", + "metadata": {}, + "source": [ + "# you need to run this to make the stacked lineplot visible, it's easier in the latest master with camera auto-scaling" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8597de09-94aa-44cd-b480-acc1758a198c", + "metadata": {}, + "outputs": [], + "source": [ + "plot[0, 1].controller.distance = 0\n", + "auto_scale(plot[0, 1])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e321cf0d-52f7-4da2-983a-ff10653093bb", + "metadata": {}, + "outputs": [], + "source": [ + "white = list()\n", + "for contour in line_collection:\n", + " white.append(np.ones(shape=contour.colors.shape))" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "01296977-1664-40ae-86fb-ed515fa96f4a", + "metadata": {}, + "outputs": [], + "source": [ + "white_temporal = np.ones((len(contours), 4)).astype(np.float32)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4430e805-54db-4218-967a-30290ced8ca9", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import *" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c02210db-5347-4e97-a551-f4f362d3910a", + "metadata": {}, + "outputs": [], + "source": [ + "def indices_mapper(target: Any, indices: np.array) -> int:\n", + " # calculate coms of line collection \n", + " \n", + " coms = list()\n", + "\n", + " for contour in target.data:\n", + " coors = contour.data[~np.isnan(contour.data).any(axis=1)]\n", + " com = coors.mean(axis=0)\n", + " coms.append(com)\n", + "\n", + " # euclidean distance to find closest index of com \n", + " indices = np.append(indices, [0])\n", + " \n", + " ix = np.linalg.norm((coms - indices), axis=1).argsort()[0] \n", + " \n", + " #return that index to set feature \n", + " return ix" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0fa4033f-323b-421c-84ad-da34a0ac177c", + "metadata": {}, + "outputs": [], + "source": [ + "# until we create an event \"color-changed\" (and for other graphic features)\n", + "# later we can just use the \"color-changed\" event from contour to change the lineplot or heatmap etc.\n", + "def indices_mapper_temporal(target, indices):\n", + " # global since we don't have something like \"color changed\"\n", + " # as an event which we can used for stakced line plots\n", + " global contours\n", + " coms = list()\n", + "\n", + " for contour in contours:\n", + " coors = contour[~np.isnan(contour.data).any(axis=1)]\n", + " com = coors.mean(axis=0)\n", + " coms.append(com)\n", + " \n", + " ix = np.linalg.norm((np.array(coms) - np.array(indices)), axis=1).argsort()[0]\n", + " print(ix)\n", + " \n", + " #return that index to set feature \n", + " return ix" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0d3a3665-c9e5-42e4-9d61-c02f1f401ee2", + "metadata": {}, + "outputs": [], + "source": [ + "img.link(event_type=\"click\", target=line_collection, feature=\"colors\", new_data=white, indices_mapper=indices_mapper)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "14fdbae1-31b7-4b58-a7e7-a50589f0ff0d", + "metadata": {}, + "outputs": [], + "source": [ + "img.link(event_type=\"click\", target=temporal_coll, feature=\"colors\", new_data=white_temporal, indices_mapper=indices_mapper_temporal)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/lineplot.ipynb b/examples/lineplot.ipynb index 7561efe88..d00346daf 100644 --- a/examples/lineplot.ipynb +++ b/examples/lineplot.ipynb @@ -178,7 +178,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.5" + "version": "3.9.2" } }, "nbformat": 4, diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py index cad6de8c7..6294637aa 100644 --- a/fastplotlib/graphics/__init__.py +++ b/fastplotlib/graphics/__init__.py @@ -12,6 +12,6 @@ "LineGraphic", "HistogramGraphic", "HeatmapGraphic", - "LineCollection", - "TextGraphic" + "TextGraphic", + "LineCollection" ] diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py index a1a2633b9..e6204c5a1 100644 --- a/fastplotlib/graphics/_base.py +++ b/fastplotlib/graphics/_base.py @@ -5,6 +5,8 @@ from ..utils import get_colors from .features import GraphicFeature, DataFeature, ColorFeature, PresentFeature +from abc import ABC, abstractmethod +from dataclasses import dataclass class Graphic: def __init__( @@ -46,6 +48,7 @@ def __init__( self.colors = None self.name = name + self.registered_callbacks = dict() if n_colors is None: n_colors = self.data.feature_data.shape[0] @@ -104,3 +107,74 @@ def __repr__(self): return f"'{self.name}' fastplotlib.{self.__class__.__name__} @ {hex(id(self))}" else: return f"fastplotlib.{self.__class__.__name__} @ {hex(id(self))}" + + +class Interaction(ABC): + @property + def indices(self) -> Any: + return self.indices + + @indices.setter + def indices(self, indices: Any): + self.indices = indices + + @property + @abstractmethod + def features(self) -> List[str]: + pass + + @abstractmethod + def _set_feature(self, feature: str, new_data: Any, indices: Any): + pass + + @abstractmethod + def _reset_feature(self, feature: str, old_data: Any): + pass + + def link(self, event_type: str, target: Any, feature: str, new_data: Any, indices_mapper: callable = None): + valid_events = ["click"] + if event_type in valid_events: + self.world_object.add_event_handler(self.event_handler, event_type) + else: + raise ValueError("event not possible") + + if isinstance(target.data, List): + old_data = list() + for line in target.data: + old_data.append(getattr(line, feature).copy()) + else: + old_data = getattr(target, feature).copy() + + if event_type in self.registered_callbacks.keys(): + self.registered_callbacks[event_type].append( + CallbackData(target=target, feature=feature, new_data=new_data, old_data=old_data, indices_mapper=indices_mapper)) + else: + self.registered_callbacks[event_type] = list() + self.registered_callbacks[event_type].append( + CallbackData(target=target, feature=feature, new_data=new_data, old_data=old_data, indices_mapper=indices_mapper)) + + def event_handler(self, event): + if event.type == "click": + # storing click information for each click in self.indices + #self.indices(np.array(event.pick_info["index"])) + click_info = np.array(event.pick_info["index"]) + if event.type in self.registered_callbacks.keys(): + for target_info in self.registered_callbacks[event.type]: + # need to map the indices to the target using indices_mapper + if target_info.indices_mapper is not None: + indices = target_info.indices_mapper(target=target_info.target, indices=click_info) + else: + indices = None + # reset feature of target using stored old data + target_info.target._reset_feature(feature=target_info.feature, old_data=target_info.old_data) + # set feature of target at indice using new data + target_info.target._set_feature(feature=target_info.feature, new_data=target_info.new_data[indices], indices=indices) + +@dataclass +class CallbackData: + """Class for keeping track of the info necessary for interactivity after event occurs.""" + target: Any + feature: str + new_data: Any + old_data: Any + indices_mapper: callable = None diff --git a/fastplotlib/graphics/heatmap.py b/fastplotlib/graphics/heatmap.py index 2c33564db..103a0fc2e 100644 --- a/fastplotlib/graphics/heatmap.py +++ b/fastplotlib/graphics/heatmap.py @@ -53,24 +53,18 @@ def __init__( ): """ Create a Heatmap Graphic - Parameters ---------- data: array-like, must be 2-dimensional | array-like, usually numpy.ndarray, must support ``memoryview()`` | Tensorflow Tensors also work _I think_, but not thoroughly tested - vmin: int, optional minimum value for color scaling, calculated from data if not provided - vmax: int, optional maximum value for color scaling, calculated from data if not provided - cmap: str, optional colormap to use to display the image data, default is ``"plasma"`` - selection_options - args: additional arguments passed to Graphic kwargs: @@ -140,4 +134,4 @@ def add_highlight(self, event): self.world_object.add(self.selection_graphic) self._highlights.append(self.selection_graphic) - return rval + return rval \ No newline at end of file diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py index 77c531c8a..0f884b3a9 100644 --- a/fastplotlib/graphics/image.py +++ b/fastplotlib/graphics/image.py @@ -3,11 +3,11 @@ import numpy as np import pygfx -from ._base import Graphic +from ._base import Graphic, Interaction from ..utils import quick_min_max, get_cmap_texture -class ImageGraphic(Graphic): +class ImageGraphic(Graphic, Interaction): def __init__( self, data: Any, @@ -72,6 +72,16 @@ def __init__( pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=get_cmap_texture(cmap)) ) + @property + def features(self) -> List[str]: + return ["cmap", "data"] + + def _set_feature(self, feature: str, new_data: Any, indices: Any): + pass + + def _reset_feature(self, feature: str, old_data: Any): + pass + @property def clim(self) -> Tuple[float, float]: return self.world_object.material.clim diff --git a/fastplotlib/graphics/line.py b/fastplotlib/graphics/line.py index edf99e43c..ea13d4abc 100644 --- a/fastplotlib/graphics/line.py +++ b/fastplotlib/graphics/line.py @@ -1,8 +1,9 @@ from typing import * import numpy as np import pygfx +from typing import * -from ._base import Graphic +from ._base import Graphic, CallbackData, Interaction class LineGraphic(Graphic): @@ -57,5 +58,19 @@ def __init__( geometry=pygfx.Geometry(positions=self.data.feature_data, colors=self.colors.feature_data), material=material(thickness=size, vertex_colors=True) ) - + self.world_object.position.z = z_position + + def _set_feature(self, feature: str, new_data: Any, indices: Any = None): + if feature in self.features: + update_func = getattr(self, f"update_{feature}") + update_func(new_data) + else: + raise ValueError("name arg is not a valid feature") + + def _reset_feature(self, feature: str, old_data: Any): + if feature in self.features: + update_func = getattr(self, f"update_{feature}") + update_func(old_data) + else: + raise ValueError("name arg is not a valid feature") diff --git a/fastplotlib/graphics/linecollection.py b/fastplotlib/graphics/linecollection.py index ec4b1e4dd..7f8adf57b 100644 --- a/fastplotlib/graphics/linecollection.py +++ b/fastplotlib/graphics/linecollection.py @@ -1,13 +1,23 @@ import numpy as np import pygfx -from typing import Union -from .line import LineGraphic +from typing import Union, List + +from fastplotlib.graphics.line import LineGraphic from typing import * +from fastplotlib.graphics._base import Interaction +from abc import ABC, abstractmethod + +class LineCollection: + def __init__(self, data: List[np.ndarray], + z_position: Union[List[float], float] = None, + size: Union[float, List[float]] = 2.0, + colors: Union[List[np.ndarray], np.ndarray] = None, + cmap: Union[List[str], str] = None, + *args, + **kwargs): -class LineCollection(): - def __init__(self, data: List[np.ndarray], z_position: Union[List[float], float] = None, size: Union[float, List[float]] = 2.0, colors: Union[List[np.ndarray], np.ndarray] = None, - cmap: Union[List[str], str] = None, *args, **kwargs): + self.name = None if not isinstance(z_position, float) and z_position is not None: if not len(data) == len(z_position): @@ -22,7 +32,8 @@ def __init__(self, data: List[np.ndarray], z_position: Union[List[float], float] if not len(data) == len(cmap): raise ValueError("args must be the same length") - self.collection = list() + self.data = list() + self._world_object = pygfx.Group() for i, d in enumerate(data): if isinstance(z_position, list): @@ -45,10 +56,37 @@ def __init__(self, data: List[np.ndarray], z_position: Union[List[float], float] else: _cmap = cmap - self.collection.append(LineGraphic(d, _z, _size, _colors, _cmap)) + lg = LineGraphic(d, _z, _size, _colors, _cmap) + self.data.append(lg) + self._world_object.add(lg.world_object) + + # TODO: make a base class for Collection graphics and put this as a base method + @property + def world_object(self) -> pygfx.WorldObject: + return self._world_object + + @property + def features(self) -> List[str]: + return ["colors", "data"] + + def _set_feature(self, feature: str, new_data: Any, indices: Any): + if feature in self.features: + update_func = getattr(self.data[indices], f"update_{feature}") + # if indices is a single indices or list of indices + self.data[indices].update_colors(new_data) + else: + raise ValueError("name arg is not a valid feature") + + def _reset_feature(self, feature: str, old_data: Any): + if feature in self.features: + #update_func = getattr(self, f"update_{feature}") + for i, line in enumerate(self.data): + line.update_colors(old_data[i]) + else: + raise ValueError("name arg is not a valid feature") def __getitem__(self, item): - return self.collection[item] + return self.data[item]