diff --git a/examples/linecollection_event.ipynb b/examples/linecollection_event.ipynb
new file mode 100644
index 000000000..d5aaabacf
--- /dev/null
+++ b/examples/linecollection_event.ipynb
@@ -0,0 +1,302 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "0c32716f-320e-4021-ad60-1c142fe6fd56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "6d72f23c-0f3a-4b2c-806d-3b239237c725",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from fastplotlib.graphics import ImageGraphic, LineCollection\n",
+ "from fastplotlib import GridPlot\n",
+ "import pickle"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "a8514bb3-eef5-4fd1-bcbf-9c50173a9a3c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def auto_scale(p):\n",
+ " p.camera.maintain_aspect = False\n",
+ " width, height, depth = np.ptp(p.scene.get_world_bounding_box(), axis=0)\n",
+ " p.camera.width = width\n",
+ " p.camera.height = height\n",
+ "\n",
+ " p.controller.distance = 0\n",
+ " \n",
+ " p.controller.zoom(0.8 / p.controller.zoom_value)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "2fb13990-63fc-4fc6-b5c1-93f8ec4c1572",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "contours = pickle.load(open(\"/home/kushal/caiman_data/contours.pickle\", \"rb\"))[0]\n",
+ "temporal = pickle.load(open(\"/home/kushal/caiman_data/temporal.pickle\", \"rb\"))\n",
+ "temporal += temporal.min()\n",
+ "\n",
+ "# make it a stack of traces\n",
+ "y_zero = 0\n",
+ "sep = 10\n",
+ "for i in range(1, temporal.shape[0]):\n",
+ " y_zero = temporal[i - 1].max()\n",
+ " temporal[i] += y_zero + sep\n",
+ "\n",
+ "# random colors\n",
+ "colors = np.random.rand(len(contours), 4).astype(np.float32)\n",
+ "colors[:, -1] = 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "654da73f-d20c-4a0f-bd99-13c1a52f5f5a",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "98ad6155b7c34241bd705d0f40bce8c0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "RFBOutputContext()"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "

initial snapshot
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "8906e4b8e78c465ca05f088f105de2fc",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "JupyterWgpuCanvas()"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# img and contour plot\n",
+ "plot = GridPlot(shape=(1, 2))\n",
+ "\n",
+ "data = np.ones(shape=(175, 175))\n",
+ "\n",
+ "line_collection = LineCollection(data=contours, z_position=[[1]] * len(contours), colors=colors.tolist())\n",
+ "plot[0, 0].add_graphic(line_collection)\n",
+ "\n",
+ "img = ImageGraphic(data=data)\n",
+ "plot[0, 0].add_graphic(img)\n",
+ "\n",
+ "\n",
+ "temporal_coll = LineCollection(data=temporal, colors=colors.tolist())\n",
+ "plot[0, 1].add_graphic(temporal_coll)\n",
+ "\n",
+ "plot.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "0836f4fc-fb3b-44c6-8515-ab8d63dff52b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "line_collection._world_object.parent"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ae5fe95b-88be-48c7-a4a6-51d2818fbff0",
+ "metadata": {},
+ "source": [
+ "# you need to run this to make the stacked lineplot visible, it's easier in the latest master with camera auto-scaling"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "8597de09-94aa-44cd-b480-acc1758a198c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plot[0, 1].controller.distance = 0\n",
+ "auto_scale(plot[0, 1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "e321cf0d-52f7-4da2-983a-ff10653093bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "white = list()\n",
+ "for contour in line_collection:\n",
+ " white.append(np.ones(shape=contour.colors.shape))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "01296977-1664-40ae-86fb-ed515fa96f4a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "white_temporal = np.ones((len(contours), 4)).astype(np.float32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "4430e805-54db-4218-967a-30290ced8ca9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from typing import *"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "c02210db-5347-4e97-a551-f4f362d3910a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def indices_mapper(target: Any, indices: np.array) -> int:\n",
+ " # calculate coms of line collection \n",
+ " \n",
+ " coms = list()\n",
+ "\n",
+ " for contour in target.data:\n",
+ " coors = contour.data[~np.isnan(contour.data).any(axis=1)]\n",
+ " com = coors.mean(axis=0)\n",
+ " coms.append(com)\n",
+ "\n",
+ " # euclidean distance to find closest index of com \n",
+ " indices = np.append(indices, [0])\n",
+ " \n",
+ " ix = np.linalg.norm((coms - indices), axis=1).argsort()[0] \n",
+ " \n",
+ " #return that index to set feature \n",
+ " return ix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "0fa4033f-323b-421c-84ad-da34a0ac177c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# until we create an event \"color-changed\" (and for other graphic features)\n",
+ "# later we can just use the \"color-changed\" event from contour to change the lineplot or heatmap etc.\n",
+ "def indices_mapper_temporal(target, indices):\n",
+ " # global since we don't have something like \"color changed\"\n",
+ " # as an event which we can used for stakced line plots\n",
+ " global contours\n",
+ " coms = list()\n",
+ "\n",
+ " for contour in contours:\n",
+ " coors = contour[~np.isnan(contour.data).any(axis=1)]\n",
+ " com = coors.mean(axis=0)\n",
+ " coms.append(com)\n",
+ " \n",
+ " ix = np.linalg.norm((np.array(coms) - np.array(indices)), axis=1).argsort()[0]\n",
+ " print(ix)\n",
+ " \n",
+ " #return that index to set feature \n",
+ " return ix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "0d3a3665-c9e5-42e4-9d61-c02f1f401ee2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "img.link(event_type=\"click\", target=line_collection, feature=\"colors\", new_data=white, indices_mapper=indices_mapper)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "14fdbae1-31b7-4b58-a7e7-a50589f0ff0d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "img.link(event_type=\"click\", target=temporal_coll, feature=\"colors\", new_data=white_temporal, indices_mapper=indices_mapper_temporal)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/lineplot.ipynb b/examples/lineplot.ipynb
index 7561efe88..d00346daf 100644
--- a/examples/lineplot.ipynb
+++ b/examples/lineplot.ipynb
@@ -178,7 +178,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.5"
+ "version": "3.9.2"
}
},
"nbformat": 4,
diff --git a/fastplotlib/graphics/__init__.py b/fastplotlib/graphics/__init__.py
index cad6de8c7..6294637aa 100644
--- a/fastplotlib/graphics/__init__.py
+++ b/fastplotlib/graphics/__init__.py
@@ -12,6 +12,6 @@
"LineGraphic",
"HistogramGraphic",
"HeatmapGraphic",
- "LineCollection",
- "TextGraphic"
+ "TextGraphic",
+ "LineCollection"
]
diff --git a/fastplotlib/graphics/_base.py b/fastplotlib/graphics/_base.py
index a1a2633b9..e6204c5a1 100644
--- a/fastplotlib/graphics/_base.py
+++ b/fastplotlib/graphics/_base.py
@@ -5,6 +5,8 @@
from ..utils import get_colors
from .features import GraphicFeature, DataFeature, ColorFeature, PresentFeature
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
class Graphic:
def __init__(
@@ -46,6 +48,7 @@ def __init__(
self.colors = None
self.name = name
+ self.registered_callbacks = dict()
if n_colors is None:
n_colors = self.data.feature_data.shape[0]
@@ -104,3 +107,74 @@ def __repr__(self):
return f"'{self.name}' fastplotlib.{self.__class__.__name__} @ {hex(id(self))}"
else:
return f"fastplotlib.{self.__class__.__name__} @ {hex(id(self))}"
+
+
+class Interaction(ABC):
+ @property
+ def indices(self) -> Any:
+ return self.indices
+
+ @indices.setter
+ def indices(self, indices: Any):
+ self.indices = indices
+
+ @property
+ @abstractmethod
+ def features(self) -> List[str]:
+ pass
+
+ @abstractmethod
+ def _set_feature(self, feature: str, new_data: Any, indices: Any):
+ pass
+
+ @abstractmethod
+ def _reset_feature(self, feature: str, old_data: Any):
+ pass
+
+ def link(self, event_type: str, target: Any, feature: str, new_data: Any, indices_mapper: callable = None):
+ valid_events = ["click"]
+ if event_type in valid_events:
+ self.world_object.add_event_handler(self.event_handler, event_type)
+ else:
+ raise ValueError("event not possible")
+
+ if isinstance(target.data, List):
+ old_data = list()
+ for line in target.data:
+ old_data.append(getattr(line, feature).copy())
+ else:
+ old_data = getattr(target, feature).copy()
+
+ if event_type in self.registered_callbacks.keys():
+ self.registered_callbacks[event_type].append(
+ CallbackData(target=target, feature=feature, new_data=new_data, old_data=old_data, indices_mapper=indices_mapper))
+ else:
+ self.registered_callbacks[event_type] = list()
+ self.registered_callbacks[event_type].append(
+ CallbackData(target=target, feature=feature, new_data=new_data, old_data=old_data, indices_mapper=indices_mapper))
+
+ def event_handler(self, event):
+ if event.type == "click":
+ # storing click information for each click in self.indices
+ #self.indices(np.array(event.pick_info["index"]))
+ click_info = np.array(event.pick_info["index"])
+ if event.type in self.registered_callbacks.keys():
+ for target_info in self.registered_callbacks[event.type]:
+ # need to map the indices to the target using indices_mapper
+ if target_info.indices_mapper is not None:
+ indices = target_info.indices_mapper(target=target_info.target, indices=click_info)
+ else:
+ indices = None
+ # reset feature of target using stored old data
+ target_info.target._reset_feature(feature=target_info.feature, old_data=target_info.old_data)
+ # set feature of target at indice using new data
+ target_info.target._set_feature(feature=target_info.feature, new_data=target_info.new_data[indices], indices=indices)
+
+@dataclass
+class CallbackData:
+ """Class for keeping track of the info necessary for interactivity after event occurs."""
+ target: Any
+ feature: str
+ new_data: Any
+ old_data: Any
+ indices_mapper: callable = None
diff --git a/fastplotlib/graphics/heatmap.py b/fastplotlib/graphics/heatmap.py
index 2c33564db..103a0fc2e 100644
--- a/fastplotlib/graphics/heatmap.py
+++ b/fastplotlib/graphics/heatmap.py
@@ -53,24 +53,18 @@ def __init__(
):
"""
Create a Heatmap Graphic
-
Parameters
----------
data: array-like, must be 2-dimensional
| array-like, usually numpy.ndarray, must support ``memoryview()``
| Tensorflow Tensors also work _I think_, but not thoroughly tested
-
vmin: int, optional
minimum value for color scaling, calculated from data if not provided
-
vmax: int, optional
maximum value for color scaling, calculated from data if not provided
-
cmap: str, optional
colormap to use to display the image data, default is ``"plasma"``
-
selection_options
-
args:
additional arguments passed to Graphic
kwargs:
@@ -140,4 +134,4 @@ def add_highlight(self, event):
self.world_object.add(self.selection_graphic)
self._highlights.append(self.selection_graphic)
- return rval
+ return rval
\ No newline at end of file
diff --git a/fastplotlib/graphics/image.py b/fastplotlib/graphics/image.py
index 77c531c8a..0f884b3a9 100644
--- a/fastplotlib/graphics/image.py
+++ b/fastplotlib/graphics/image.py
@@ -3,11 +3,11 @@
import numpy as np
import pygfx
-from ._base import Graphic
+from ._base import Graphic, Interaction
from ..utils import quick_min_max, get_cmap_texture
-class ImageGraphic(Graphic):
+class ImageGraphic(Graphic, Interaction):
def __init__(
self,
data: Any,
@@ -72,6 +72,16 @@ def __init__(
pygfx.ImageBasicMaterial(clim=(vmin, vmax), map=get_cmap_texture(cmap))
)
+ @property
+ def features(self) -> List[str]:
+ return ["cmap", "data"]
+
+ def _set_feature(self, feature: str, new_data: Any, indices: Any):
+ pass
+
+ def _reset_feature(self, feature: str, old_data: Any):
+ pass
+
@property
def clim(self) -> Tuple[float, float]:
return self.world_object.material.clim
diff --git a/fastplotlib/graphics/line.py b/fastplotlib/graphics/line.py
index edf99e43c..ea13d4abc 100644
--- a/fastplotlib/graphics/line.py
+++ b/fastplotlib/graphics/line.py
@@ -1,8 +1,9 @@
from typing import *
import numpy as np
import pygfx
+from typing import *
-from ._base import Graphic
+from ._base import Graphic, CallbackData, Interaction
class LineGraphic(Graphic):
@@ -57,5 +58,19 @@ def __init__(
geometry=pygfx.Geometry(positions=self.data.feature_data, colors=self.colors.feature_data),
material=material(thickness=size, vertex_colors=True)
)
-
+
self.world_object.position.z = z_position
+
+ def _set_feature(self, feature: str, new_data: Any, indices: Any = None):
+ if feature in self.features:
+ update_func = getattr(self, f"update_{feature}")
+ update_func(new_data)
+ else:
+ raise ValueError("name arg is not a valid feature")
+
+ def _reset_feature(self, feature: str, old_data: Any):
+ if feature in self.features:
+ update_func = getattr(self, f"update_{feature}")
+ update_func(old_data)
+ else:
+ raise ValueError("name arg is not a valid feature")
diff --git a/fastplotlib/graphics/linecollection.py b/fastplotlib/graphics/linecollection.py
index ec4b1e4dd..7f8adf57b 100644
--- a/fastplotlib/graphics/linecollection.py
+++ b/fastplotlib/graphics/linecollection.py
@@ -1,13 +1,23 @@
import numpy as np
import pygfx
-from typing import Union
-from .line import LineGraphic
+from typing import Union, List
+
+from fastplotlib.graphics.line import LineGraphic
from typing import *
+from fastplotlib.graphics._base import Interaction
+from abc import ABC, abstractmethod
+
+class LineCollection:
+ def __init__(self, data: List[np.ndarray],
+ z_position: Union[List[float], float] = None,
+ size: Union[float, List[float]] = 2.0,
+ colors: Union[List[np.ndarray], np.ndarray] = None,
+ cmap: Union[List[str], str] = None,
+ *args,
+ **kwargs):
-class LineCollection():
- def __init__(self, data: List[np.ndarray], z_position: Union[List[float], float] = None, size: Union[float, List[float]] = 2.0, colors: Union[List[np.ndarray], np.ndarray] = None,
- cmap: Union[List[str], str] = None, *args, **kwargs):
+ self.name = None
if not isinstance(z_position, float) and z_position is not None:
if not len(data) == len(z_position):
@@ -22,7 +32,8 @@ def __init__(self, data: List[np.ndarray], z_position: Union[List[float], float]
if not len(data) == len(cmap):
raise ValueError("args must be the same length")
- self.collection = list()
+ self.data = list()
+ self._world_object = pygfx.Group()
for i, d in enumerate(data):
if isinstance(z_position, list):
@@ -45,10 +56,37 @@ def __init__(self, data: List[np.ndarray], z_position: Union[List[float], float]
else:
_cmap = cmap
- self.collection.append(LineGraphic(d, _z, _size, _colors, _cmap))
+ lg = LineGraphic(d, _z, _size, _colors, _cmap)
+ self.data.append(lg)
+ self._world_object.add(lg.world_object)
+
+ # TODO: make a base class for Collection graphics and put this as a base method
+ @property
+ def world_object(self) -> pygfx.WorldObject:
+ return self._world_object
+
+ @property
+ def features(self) -> List[str]:
+ return ["colors", "data"]
+
+ def _set_feature(self, feature: str, new_data: Any, indices: Any):
+ if feature in self.features:
+ update_func = getattr(self.data[indices], f"update_{feature}")
+ # if indices is a single indices or list of indices
+ self.data[indices].update_colors(new_data)
+ else:
+ raise ValueError("name arg is not a valid feature")
+
+ def _reset_feature(self, feature: str, old_data: Any):
+ if feature in self.features:
+ #update_func = getattr(self, f"update_{feature}")
+ for i, line in enumerate(self.data):
+ line.update_colors(old_data[i])
+ else:
+ raise ValueError("name arg is not a valid feature")
def __getitem__(self, item):
- return self.collection[item]
+ return self.data[item]