diff --git a/docs/api.rst b/docs/api.rst
index dba583af..ae0d78d2 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -8,3 +8,5 @@ plots though the napari user interface.
 .. automodapi:: napari_matplotlib
 
 .. automodapi:: napari_matplotlib.base
+
+.. automodapi:: napari_matplotlib.features
diff --git a/src/napari_matplotlib/base.py b/src/napari_matplotlib/base.py
index 8c717d6a..5687895e 100644
--- a/src/napari_matplotlib/base.py
+++ b/src/napari_matplotlib/base.py
@@ -281,7 +281,9 @@ def __init__(
         napari_viewer: napari.viewer.Viewer,
         parent: Optional[QWidget] = None,
     ):
-        super().__init__(napari_viewer=napari_viewer, parent=parent)
+        NapariMPLWidget.__init__(
+            self, napari_viewer=napari_viewer, parent=parent
+        )
         self.add_single_axes()
 
     def clear(self) -> None:
diff --git a/src/napari_matplotlib/features.py b/src/napari_matplotlib/features.py
new file mode 100644
index 00000000..3e1eb9ba
--- /dev/null
+++ b/src/napari_matplotlib/features.py
@@ -0,0 +1,153 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import napari
+import napari.layers
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout
+
+from napari_matplotlib.base import NapariMPLWidget
+from napari_matplotlib.util import Interval
+
+__all__ = ["FeaturesMixin"]
+
+
+class FeaturesMixin(NapariMPLWidget):
+    """
+    Mixin to help with widgets that plot data from a features table stored
+    in a single napari layer.
+
+    This provides:
+
+    - Setup for one or two combo boxes to select features to be plotted.
+    - An ``on_update_layers()`` callback that updates the combo box options
+      when the napari layer selection changes.
+    """
+
+    n_layers_input = Interval(1, 1)
+    # All layers that have a .features attributes
+    input_layer_types = (
+        napari.layers.Labels,
+        napari.layers.Points,
+        napari.layers.Shapes,
+        napari.layers.Tracks,
+        napari.layers.Vectors,
+    )
+
+    def __init__(self, *, ndim: int) -> None:
+        """
+        Parameters
+        ----------
+        ndim : int
+            Number of dimensions that are plotted by the widget.
+            Must be 1 or 2.
+        """
+        assert ndim in [1, 2]
+        self.dims = ["x", "y"][:ndim]
+        # Set up selection boxes
+        self.layout().addLayout(QVBoxLayout())
+
+        self._selectors: Dict[str, QComboBox] = {}
+        for dim in self.dims:
+            self._selectors[dim] = QComboBox()
+            # Re-draw when combo boxes are updated
+            self._selectors[dim].currentTextChanged.connect(self._draw)
+
+            self.layout().addWidget(QLabel(f"{dim}-axis:"))
+            self.layout().addWidget(self._selectors[dim])
+
+    def get_key(self, dim: str) -> Optional[str]:
+        """
+        Get key for a given dimension.
+
+        Parameters
+        ----------
+        dim : str
+            "x" or "y"
+        """
+        if self._selectors[dim].count() == 0:
+            return None
+        else:
+            return self._selectors[dim].currentText()
+
+    def set_key(self, dim: str, value: str) -> None:
+        """
+        Set key for a given dimension.
+
+        Parameters
+        ----------
+        dim : str
+            "x" or "y"
+        value : str
+            Value to set.
+        """
+        assert value in self._get_valid_axis_keys(), (
+            "value must be on of the columns "
+            "of the feature table on the currently seleted layer"
+        )
+        self._selectors[dim].setCurrentText(value)
+        self._draw()
+
+    def _get_valid_axis_keys(self) -> List[str]:
+        """
+        Get the valid axis keys from the features table column names.
+
+        Returns
+        -------
+        axis_keys : List[str]
+            The valid axis keys in the FeatureTable. If the table is empty
+            or there isn't a table, returns an empty list.
+        """
+        if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")):
+            return []
+        else:
+            return self.layers[0].features.keys()
+
+    def _ready_to_plot(self) -> bool:
+        """
+        Return True if selected layer has a feature table we can plot with,
+        and the columns to plot have been selected.
+        """
+        if not hasattr(self.layers[0], "features"):
+            return False
+
+        feature_table = self.layers[0].features
+        valid_keys = self._get_valid_axis_keys()
+        return (
+            feature_table is not None
+            and len(feature_table) > 0
+            and all([self.get_key(dim) in valid_keys for dim in self.dims])
+        )
+
+    def _get_data_names(
+        self,
+    ) -> Tuple[List[npt.NDArray[Any]], List[str]]:
+        """
+        Get the plot data from the ``features`` attribute of the first
+        selected layer.
+
+        Returns
+        -------
+        data : List[np.ndarray]
+            List contains X and Y columns from the FeatureTable. Returns
+            an empty array if nothing to plot.
+        names : List[str]
+            Names for each axis.
+        """
+        feature_table: pd.DataFrame = self.layers[0].features
+
+        names = [str(self.get_key(dim)) for dim in self.dims]
+        data = [np.array(feature_table[key]) for key in names]
+        return data, names
+
+    def on_update_layers(self) -> None:
+        """
+        Called when the layer selection changes by ``self.update_layers()``.
+        """
+        # Clear combobox
+        for dim in self.dims:
+            while self._selectors[dim].count() > 0:
+                self._selectors[dim].removeItem(0)
+            # Add keys for newly selected layer
+            self._selectors[dim].addItems(self._get_valid_axis_keys())
diff --git a/src/napari_matplotlib/scatter.py b/src/napari_matplotlib/scatter.py
index 334f941c..4fa45798 100644
--- a/src/napari_matplotlib/scatter.py
+++ b/src/napari_matplotlib/scatter.py
@@ -1,10 +1,11 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Tuple
 
 import napari
 import numpy.typing as npt
-from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget
+from qtpy.QtWidgets import QWidget
 
 from .base import SingleAxesWidget
+from .features import FeaturesMixin
 from .util import Interval
 
 __all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"]
@@ -85,144 +86,27 @@ def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
         return x, y, x_axis_name, y_axis_name
 
 
-class FeaturesScatterWidget(ScatterBaseWidget):
+class FeaturesScatterWidget(ScatterBaseWidget, FeaturesMixin):
     """
     Widget to scatter data stored in two layer feature attributes.
     """
 
-    n_layers_input = Interval(1, 1)
-    # All layers that have a .features attributes
-    input_layer_types = (
-        napari.layers.Labels,
-        napari.layers.Points,
-        napari.layers.Shapes,
-        napari.layers.Tracks,
-        napari.layers.Vectors,
-    )
-
     def __init__(
         self,
         napari_viewer: napari.viewer.Viewer,
         parent: Optional[QWidget] = None,
     ):
-        super().__init__(napari_viewer, parent=parent)
-
-        self.layout().addLayout(QVBoxLayout())
-
-        self._selectors: Dict[str, QComboBox] = {}
-        for dim in ["x", "y"]:
-            self._selectors[dim] = QComboBox()
-            # Re-draw when combo boxes are updated
-            self._selectors[dim].currentTextChanged.connect(self._draw)
-
-            self.layout().addWidget(QLabel(f"{dim}-axis:"))
-            self.layout().addWidget(self._selectors[dim])
-
+        ScatterBaseWidget.__init__(self, napari_viewer, parent=parent)
+        FeaturesMixin.__init__(self, ndim=2)
         self._update_layers(None)
 
-    @property
-    def x_axis_key(self) -> Union[str, None]:
-        """
-        Key for the x-axis data.
-        """
-        if self._selectors["x"].count() == 0:
-            return None
-        else:
-            return self._selectors["x"].currentText()
-
-    @x_axis_key.setter
-    def x_axis_key(self, key: str) -> None:
-        self._selectors["x"].setCurrentText(key)
-        self._draw()
-
-    @property
-    def y_axis_key(self) -> Union[str, None]:
-        """
-        Key for the y-axis data.
-        """
-        if self._selectors["y"].count() == 0:
-            return None
-        else:
-            return self._selectors["y"].currentText()
-
-    @y_axis_key.setter
-    def y_axis_key(self, key: str) -> None:
-        self._selectors["y"].setCurrentText(key)
-        self._draw()
-
-    def _get_valid_axis_keys(self) -> List[str]:
-        """
-        Get the valid axis keys from the layer FeatureTable.
-
-        Returns
-        -------
-        axis_keys : List[str]
-            The valid axis keys in the FeatureTable. If the table is empty
-            or there isn't a table, returns an empty list.
-        """
-        if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")):
-            return []
-        else:
-            return self.layers[0].features.keys()
-
-    def _ready_to_scatter(self) -> bool:
-        """
-        Return True if selected layer has a feature table we can scatter with,
-        and the two columns to be scatterd have been selected.
-        """
-        if not hasattr(self.layers[0], "features"):
-            return False
-
-        feature_table = self.layers[0].features
-        valid_keys = self._get_valid_axis_keys()
-        return (
-            feature_table is not None
-            and len(feature_table) > 0
-            and self.x_axis_key in valid_keys
-            and self.y_axis_key in valid_keys
-        )
-
     def draw(self) -> None:
         """
         Scatter two features from the currently selected layer.
         """
-        if self._ready_to_scatter():
+        if self._ready_to_plot():
             super().draw()
 
     def _get_data(self) -> Tuple[npt.NDArray[Any], npt.NDArray[Any], str, str]:
-        """
-        Get the plot data from the ``features`` attribute of the first
-        selected layer.
-
-        Returns
-        -------
-        data : List[np.ndarray]
-            List contains X and Y columns from the FeatureTable. Returns
-            an empty array if nothing to plot.
-        x_axis_name : str
-            The title to display on the x axis. Returns
-            an empty string if nothing to plot.
-        y_axis_name: str
-            The title to display on the y axis. Returns
-            an empty string if nothing to plot.
-        """
-        feature_table = self.layers[0].features
-
-        x = feature_table[self.x_axis_key]
-        y = feature_table[self.y_axis_key]
-
-        x_axis_name = str(self.x_axis_key)
-        y_axis_name = str(self.y_axis_key)
-
-        return x, y, x_axis_name, y_axis_name
-
-    def on_update_layers(self) -> None:
-        """
-        Called when the layer selection changes by ``self.update_layers()``.
-        """
-        # Clear combobox
-        for dim in ["x", "y"]:
-            while self._selectors[dim].count() > 0:
-                self._selectors[dim].removeItem(0)
-            # Add keys for newly selected layer
-            self._selectors[dim].addItems(self._get_valid_axis_keys())
+        data, names = self._get_data_names()
+        return data[0], data[1], names[0], names[1]
diff --git a/src/napari_matplotlib/tests/scatter/test_scatter_features.py b/src/napari_matplotlib/tests/scatter/test_scatter_features.py
index c211a064..0b3f7638 100644
--- a/src/napari_matplotlib/tests/scatter/test_scatter_features.py
+++ b/src/napari_matplotlib/tests/scatter/test_scatter_features.py
@@ -25,8 +25,8 @@ def test_features_scatter_widget_2D(
 
     # Select points data and chosen features
     viewer.layers.selection.add(viewer.layers[0])  # images need to be selected
-    widget.x_axis_key = "feature_0"
-    widget.y_axis_key = "feature_1"
+    widget.set_key("x", "feature_0")
+    widget.set_key("y", "feature_1")
 
     fig = widget.figure
 
@@ -64,9 +64,9 @@ def test_features_scatter_get_data(make_napari_viewer):
     viewer.layers.selection = [labels_layer]
 
     x_column = "feature_0"
-    scatter_widget.x_axis_key = x_column
     y_column = "feature_2"
-    scatter_widget.y_axis_key = y_column
+    scatter_widget.set_key("x", x_column)
+    scatter_widget.set_key("y", y_column)
 
     x, y, x_axis_name, y_axis_name = scatter_widget._get_data()
     np.testing.assert_allclose(x, feature_table[x_column])