From a7e26238ebf5b460bc7f438ab5ec80c53a2ad415 Mon Sep 17 00:00:00 2001 From: Luke Campagnola Date: Mon, 13 Feb 2023 16:56:58 -0800 Subject: [PATCH 1/3] Add training data collector --- parallax/dialogs.py | 64 ++++++++++++++++++++++++++++++-- parallax/main_window.py | 14 ++++++- parallax/training_data.py | 78 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 parallax/training_data.py diff --git a/parallax/dialogs.py b/parallax/dialogs.py index 176fcba8..97fe678c 100755 --- a/parallax/dialogs.py +++ b/parallax/dialogs.py @@ -1,12 +1,10 @@ -from PyQt5.QtWidgets import QPushButton, QLabel, QRadioButton, QSpinBox +from PyQt5.QtWidgets import QPushButton, QLabel, QSpinBox from PyQt5.QtWidgets import QGridLayout from PyQt5.QtWidgets import QDialog, QLineEdit, QDialogButtonBox from PyQt5.QtCore import Qt from PyQt5.QtGui import QDoubleValidator - +import pyqtgraph as pg import numpy as np -import time -import datetime from .toggle_switch import ToggleSwitch from .helper import FONT_BOLD @@ -343,3 +341,61 @@ def get_params(self): y = float(self.yedit.text()) z = float(self.zedit.text()) return x,y,z + + +class TrainingDataDialog(QDialog): + + def __init__(self, model): + QDialog.__init__(self) + self.model = model + + self.setWindowTitle('Training Data Generator') + + self.stage_label = QLabel('Select a Stage:') + self.stage_label.setAlignment(Qt.AlignCenter) + self.stage_label.setFont(FONT_BOLD) + + self.stage_dropdown = StageDropdown(self.model) + self.stage_dropdown.activated.connect(self.update_status) + + self.img_count_label = QLabel('Image Count:') + self.img_count_label.setAlignment(Qt.AlignCenter) + self.img_count_box = QSpinBox() + self.img_count_box.setMinimum(1) + self.img_count_box.setValue(100) + + self.extent_label = QLabel('Extent:') + self.extent_label.setAlignment(Qt.AlignCenter) + self.extent_spin = pg.SpinBox(value=4e-3, suffix='m', siPrefix=True, bounds=[0.1e-3, 20e-3], dec=True, step=0.5, minStep=1e-6, compactHeight=False) + + self.go_button = QPushButton('Start Data Collection') + self.go_button.setEnabled(False) + self.go_button.clicked.connect(self.go) + + layout = QGridLayout() + layout.addWidget(self.stage_label, 0,0, 1,1) + layout.addWidget(self.stage_dropdown, 0,1, 1,1) + layout.addWidget(self.img_count_label, 1,0, 1,1) + layout.addWidget(self.img_count_box, 1,1, 1,1) + layout.addWidget(self.extent_label, 2,0, 1,1) + layout.addWidget(self.extent_spin, 2,1, 1,1) + layout.addWidget(self.go_button, 4,0, 1,2) + self.setLayout(layout) + + self.setMinimumWidth(300) + + def get_stage(self): + return self.stage_dropdown.current_stage() + + def get_img_count(self): + return self.img_count_box.value() + + def get_extent(self): + return self.extent_spin.value() * 1e6 + + def go(self): + self.accept() + + def update_status(self): + if self.stage_dropdown.is_selected(): + self.go_button.setEnabled(True) diff --git a/parallax/main_window.py b/parallax/main_window.py index 8ea747ba..3ee0b85b 100644 --- a/parallax/main_window.py +++ b/parallax/main_window.py @@ -1,4 +1,4 @@ -from PyQt5.QtWidgets import QApplication, QWidget, QVBoxLayout, QHBoxLayout, QMainWindow, QAction +from PyQt5.QtWidgets import QApplication, QWidget, QHBoxLayout, QGridLayout, QMainWindow, QAction, QSplitter from PyQt5.QtCore import Qt, QTimer from PyQt5.QtGui import QIcon import pyqtgraph.console @@ -10,6 +10,7 @@ from .dialogs import AboutDialog from .rigid_body_transform_tool import RigidBodyTransformTool from .stage_manager import StageManager +from .training_data import TrainingDataCollector class MainWindow(QMainWindow): @@ -18,6 +19,11 @@ def __init__(self, model): QMainWindow.__init__(self) self.model = model + # allow main window to be accessed globally + model.main_window = self + + self.data_collector = None + self.widget = MainWidget(model) self.setCentralWidget(self.widget) @@ -37,6 +43,8 @@ def __init__(self, model): self.rbt_action.triggered.connect(self.launch_rbt) self.console_action = QAction("Python Console") self.console_action.triggered.connect(self.show_console) + self.training_data_action = QAction("Collect Training Data") + self.training_data_action.triggered.connect(self.collect_training_data) self.about_action = QAction("About") self.about_action.triggered.connect(self.launch_about) @@ -56,6 +64,7 @@ def __init__(self, model): self.tools_menu = self.menuBar().addMenu("Tools") self.tools_menu.addAction(self.rbt_action) self.tools_menu.addAction(self.console_action) + self.tools_menu.addAction(self.training_data_action) self.help_menu = self.menuBar().addMenu("Help") self.help_menu.addAction(self.about_action) @@ -106,6 +115,9 @@ def refresh_focus_controllers(self): for screen in self.screens(): screen.update_focus_control_menu() + def collect_training_data(self): + self.data_collector = TrainingDataCollector(self.model) + self.data_collector.start() class MainWidget(QWidget): diff --git a/parallax/training_data.py b/parallax/training_data.py new file mode 100644 index 00000000..7da78430 --- /dev/null +++ b/parallax/training_data.py @@ -0,0 +1,78 @@ +import threading, pickle, os +import numpy as np +from PyQt5 import QtWidgets, QtCore +from .dialogs import TrainingDataDialog + + +class TrainingDataCollector(QtCore.QObject): + def __init__(self, model): + QtCore.QObject.__init__(self) + self.model = model + + def start(self): + dlg = TrainingDataDialog(self.model) + dlg.exec_() + if dlg.result() != dlg.Accepted: + return + + self.stage = dlg.get_stage() + self.img_count = dlg.get_img_count() + self.extent = dlg.get_extent() + self.path = QtWidgets.QFileDialog.getExistingDirectory(parent=None, caption="Select Storage Directory") + if self.path == '': + return + + self.start_pos = self.stage.get_position() + self.stage_cal = self.model.get_calibration(self.stage) + + self.thread = threading.Thread(target=self.thread_run, daemon=True) + self.thread.start() + + def thread_run(self): + meta_file = os.path.join(self.path, 'meta.pkl') + if os.path.exists(meta_file): + # todo: just append + raise Exception("Already data in this folder!") + trials = [] + meta = { + 'calibration': self.stage_cal, + 'stage': self.stage.get_name(), + 'trials': trials, + } + + # move electrode out of fov for background images + pos = self.start_pos.coordinates.copy() + pos[2] += 10000 + self.stage.move_to_target_3d(*pos, block=True) + imgs = self.save_images('background') + meta['background'] = imgs + + for i in range(self.img_count): + + # first image in random location + rnd = np.random.uniform(-self.extent/2, self.extent/2, size=3) + pos1 = self.start_pos.coordinates + rnd + self.stage.move_to_target_3d(*pos1, block=True) + images1 = self.save_images(f'{i:04d}-a') + + # take a second image slightly shifted + pos2 = pos1.copy() + pos2[2] += 10 + self.stage.move_to_target_3d(*pos2, block=True) + images2 = self.save_images(f'{i:04d}-b') + + trials.append([ + {'pos': pos1, 'images': images1}, + {'pos': pos2, 'images': images2}, + ]) + + with open(meta_file, 'wb') as fh: + pickle.dump(meta, fh) + + def save_images(self, name): + images = [] + for camera in self.model.cameras: + filename = f'{name}-{camera.name()}.png' + camera.save_last_image(os.path.join(self.path, filename)) + images.append({'camera': camera.name(), 'image': filename}) + return images \ No newline at end of file From 56ca49a61089676ca971a42e08cffaaa47fb252f Mon Sep 17 00:00:00 2001 From: Luke Campagnola Date: Mon, 13 Feb 2023 17:00:48 -0800 Subject: [PATCH 2/3] minor fixes --- parallax/main_window.py | 2 +- parallax/training_data.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parallax/main_window.py b/parallax/main_window.py index 3ee0b85b..9e1ed0f5 100644 --- a/parallax/main_window.py +++ b/parallax/main_window.py @@ -1,4 +1,4 @@ -from PyQt5.QtWidgets import QApplication, QWidget, QHBoxLayout, QGridLayout, QMainWindow, QAction, QSplitter +from PyQt5.QtWidgets import QApplication, QWidget, QHBoxLayout, QVBoxLayout, QGridLayout, QMainWindow, QAction, QSplitter from PyQt5.QtCore import Qt, QTimer from PyQt5.QtGui import QIcon import pyqtgraph.console diff --git a/parallax/training_data.py b/parallax/training_data.py index 7da78430..df39e347 100644 --- a/parallax/training_data.py +++ b/parallax/training_data.py @@ -23,7 +23,7 @@ def start(self): return self.start_pos = self.stage.get_position() - self.stage_cal = self.model.get_calibration(self.stage) + self.stage_cal = list(self.model.calibrations.values())[0] self.thread = threading.Thread(target=self.thread_run, daemon=True) self.thread.start() From df1f28fbf3faf9d711b32b9c8786e2b5db9e8fa3 Mon Sep 17 00:00:00 2001 From: Luke Campagnola Date: Mon, 13 Feb 2023 20:23:19 -0800 Subject: [PATCH 3/3] added script for annotating training images --- tools/annotate_training_data.py | 128 ++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 tools/annotate_training_data.py diff --git a/tools/annotate_training_data.py b/tools/annotate_training_data.py new file mode 100644 index 00000000..f748a93f --- /dev/null +++ b/tools/annotate_training_data.py @@ -0,0 +1,128 @@ +import pyqtgraph as pg +import json + + +class MainWindow(pg.GraphicsView): + def __init__(self, meta_file, img_files): + super().__init__() + self.img_files = img_files + + self.view = pg.ViewBox() + self.view.invertY() + self.setCentralItem(self.view) + + self.img_item = pg.QtWidgets.QGraphicsPixmapItem() + self.view.addItem(self.img_item) + + self.line_item = pg.QtWidgets.QGraphicsLineItem() + self.line_item.setPen(pg.mkPen('r')) + self.circle_item = pg.QtWidgets.QGraphicsEllipseItem() + self.circle_item.setPen(pg.mkPen('r')) + self.view.addItem(self.line_item) + self.view.addItem(self.circle_item) + + self.next_click = 0 + self.attached_pt = None + self.loaded_file = None + + self.meta_file = meta_file + if os.path.exists(meta_file): + self.meta = json.load(open(meta_file, 'r')) + else: + self.meta = {} + + self.load_image(0) + + def keyPressEvent(self, ev): + if ev.key() == pg.QtCore.Qt.Key_Left: + self.load_image(self.current_index - 1) + elif ev.key() == pg.QtCore.Qt.Key_Right: + self.load_image(self.current_index + 1) + else: + print(ev.key()) + + def mousePressEvent(self, ev): + # print('press', ev) + if ev.button() == pg.QtCore.Qt.LeftButton: + self.attached_pt = self.next_click + self.update_pos(ev.pos()) + ev.accept() + return + # return super().mousePressEvent(ev) + + def mouseReleaseEvent(self, ev): + # print('release', ev) + self.attached_pt = None + self.next_click = (self.next_click + 1) % 2 + + def mouseMoveEvent(self, ev): + # print('move', ev) + self.update_pos(ev.pos()) + ev.accept() + + def update_pos(self, pos): + pos = self.view.mapDeviceToView(pos) + if self.attached_pt == 0: + self.set_pts(pos, None) + elif self.attached_pt == 1: + self.set_pts(None, pos) + else: + return + self.update_meta() + + def set_pts(self, pt1, pt2): + line = self.line_item.line() + if pt1 is not None: + line.setP1(pt1) + self.circle_item.setRect(pt1.x()-10, pt1.y()-10, 20, 20) + self.circle_item.setVisible(True) + if pt2 is not None: + line.setP2(pt2) + self.line_item.setVisible(True) + self.line_item.setLine(line) + + def hide_line(self): + self.line_item.setVisible(False) + self.circle_item.setVisible(False) + + def update_meta(self): + line = self.line_item.line() + self.meta[self.loaded_file] = { + 'pt1': (line.x1(), line.y1()), + 'pt2': (line.x2(), line.y2()), + } + json.dump(self.meta, open(self.meta_file, 'w')) + + def load_image(self, index): + filename = self.img_files[index] + pxm = pg.QtGui.QPixmap() + pxm.load(filename) + self.img_item.setPixmap(pxm) + self.img_item.pxm = pxm + self.view.autoRange(padding=0) + self.current_index = index + self.setWindowTitle(filename) + self.loaded_file = filename + + meta = self.meta.get(filename, {}) + pt1 = meta.get('pt1', None) + pt2 = meta.get('pt2', None) + if None in (pt1, pt2): + self.hide_line() + else: + self.set_pts(pg.QtCore.QPointF(*pt1), pg.QtCore.QPointF(*pt2)) + + +if __name__ == '__main__': + import os, sys + + app = pg.mkQApp() + + meta_file = sys.argv[1] + img_files = sys.argv[2:] + win = MainWindow(meta_file, img_files) + win.resize(1000, 800) + win.show() + + if sys.flags.interactive == 0: + app.exec_() \ No newline at end of file