# Source code for superintendent

"""Interactive machine learning supervision."""
import time
import warnings
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Union

import codetiming
import ipywidgets as widgets
import numpy as np
import sklearn.model_selection
from sklearn.base import BaseEstimator

from . import acquisition_functions
from ._compatibility import ignore_widget_on_submit_warning
from .db_queue import DatabaseQueue


# Release version of this package.
__version__ = "0.6.0"
# Public API: the only name exported by a star-import of this module.
__all__ = ["Superintendent"]


class Superintendent(widgets.VBox):
    """
    Data point labelling.

    This is a base class for data point labelling.
    """

    def __init__(
        self,
        *,
        features: Optional[Any] = None,
        labels: Optional[Any] = None,
        database_url: str = "sqlite:///:memory:",
        queue: Optional[DatabaseQueue] = None,
        labelling_widget: Optional[widgets.Widget] = None,
        model: Optional[BaseEstimator] = None,
        eval_method: Optional[Callable] = None,
        acquisition_function: Optional[Union[Callable, str]] = None,
        shuffle_prop: float = 0.1,
        model_preprocess: Optional[Callable] = None,
        worker_id: Union[bool, str] = False,
        **kwargs,
    ):
        """
        Make a class that allows you to label data points.

        Parameters
        ----------
        features : np.ndarray, pd.DataFrame, sequence
            This should be either a numpy array, a pandas dataframe, or any
            other sequence object (e.g. a list). You can also add data later.
        labels : np.array, pd.Series, sequence
            The labels for your data, if you have some already.
        database_url : str
            Connection string used to create a DatabaseQueue when no `queue`
            is passed. The default is an in-memory SQLite database, which
            triggers a warning recommending a persistent database instead.
        queue : DatabaseQueue, optional
            The queue that stores data points and labels. If None, a
            DatabaseQueue is created from `database_url`.
        labelling_widget : widgets.Widget
            The input widget (required). It needs to follow the interface of
            superintendent.controls.base.SubmissionWidgetMixin: it must
            provide `on_submit` and `display`; if it also provides `on_undo`,
            undo support is wired up.
        model : sklearn.base.BaseEstimator, optional
            An sklearn-interface compliant model (that implements `fit`,
            `predict`, `predict_proba` and `score`). If given, a "Retrain"
            button and a score display are added.
        eval_method : callable, optional
            A function that accepts three arguments - model, x, and y - and
            returns the score(s) of the model (these are averaged). If None,
            sklearn.model_selection.cross_val_score with cv=3 is used.
        acquisition_function : callable or str, optional
            A function that re-orders data points during active learning. It
            receives the model's class probabilities for the unlabelled data
            and a `shuffle_prop` keyword. A string is looked up by name in
            superintendent.acquisition_functions.functions.
        shuffle_prop : float
            The proportion of data points that is shuffled when re-ordering
            during active learning. This is to avoid biasing too much towards
            the model predictions.
        model_preprocess : callable, optional
            A function that accepts x and y data and returns x and y data. y
            can be None (in which case it should return x, None) as this
            function is used on the un-labelled data too.
        worker_id : bool or str
            If True, the user is asked for their ID before labelling starts -
            this can be helpful when working in a distributed fashion. If a
            string, this is used as the worker ID. If False, no ID is set
            here and the queue's default is kept (documented upstream as a
            generated UUID).
        **kwargs
            Passed on to widgets.VBox.

        Raises
        ------
        ValueError
            If no labelling_widget is provided.
        KeyError
            If acquisition_function is a string that is not a known name.
        """
        super().__init__(**kwargs)

        if labelling_widget is None:
            raise ValueError("No input widget was provided.")
        self.labelling_widget = labelling_widget
        self.labelling_widget.on_submit(self._apply_annotation)
        # Undo support is optional for labelling widgets.
        if hasattr(self.labelling_widget, "on_undo"):
            self.labelling_widget.on_undo(self._undo)

        # Compare against None explicitly: a user-supplied queue must never
        # be swapped for a fresh database just because it evaluates falsy.
        self.queue = (
            queue
            if queue is not None
            else DatabaseQueue(connection_string=database_url)
        )
        if self.queue.url == "sqlite:///:memory:":
            warnings.warn(
                "You are using an in-memory SQLite database. Even when "
                "labelling locally, it is recommended to use a persistent DB. "
                "You can try passing sqlite:///test.db."
            )

        if features is not None:
            self.queue.enqueue(features, labels)

        self.progressbar = widgets.FloatProgress(max=1, description="Progress:")
        # One timer per hold message; used to decide whether to show it.
        self.timers: Dict[str, codetiming.Timer] = defaultdict(
            lambda: codetiming.Timer(logger=None)
        )

        self.model = model
        self.eval_method = eval_method
        self.acquisition_function = acquisition_function
        if isinstance(acquisition_function, str):
            self.acquisition_function = acquisition_functions.functions[
                acquisition_function
            ]
        self.shuffle_prop = shuffle_prop
        self.model_preprocess = model_preprocess

        # If there is a model, we need the interface components for it.
        # Use an explicit None check: some estimators define __len__ (e.g.
        # sklearn ensembles), which makes truthiness unreliable or raising.
        if self.model is not None:
            self.retrain_button = widgets.Button(
                description="Retrain",
                disabled=False,
                button_style="",
                tooltip=(
                    "Click here to retrain the model and rank unlabelled data "
                    "points based on its prediction."
                ),
                icon="refresh",
            )
            self.retrain_button.on_click(self.retrain)
            self.model_performance = widgets.HTML(value="")
        else:
            self.retrain_button = widgets.Box()
            self.model_performance = widgets.Box()

        self.top_bar = widgets.HBox(
            [
                widgets.HBox(
                    [self.progressbar],
                    layout=widgets.Layout(
                        width="50%", justify_content="space-between"
                    ),
                ),
                widgets.HBox(
                    [self.retrain_button, self.model_performance],
                    layout=widgets.Layout(width="50%"),
                ),
            ]
        )
        self.children = [self.top_bar, self.labelling_widget]

        # The annotation loop only exists once annotation has begun; it may
        # be deferred until the user has entered their worker ID.
        self._annotation_loop = None

        if isinstance(worker_id, str):
            self.queue.worker_id = worker_id
            # Start labelling right away. Without this, no item would ever be
            # displayed and submitting would fail.
            self._begin_annotation()
        elif worker_id:
            self._get_worker_id()
        else:
            self._begin_annotation()

    # Workflow functionality ---------

    def _get_worker_id(self):
        """Replace the UI with a prompt asking the user for their ID."""
        worker_id_field = widgets.Text(placeholder="Please enter your name or user ID.")
        self.children = [
            widgets.HTML(value="<h2>Please enter your name or user ID:</h2>"),
            widgets.Box(
                children=[worker_id_field],
                layout=widgets.Layout(
                    justify_content="center",
                    padding="5% 0",
                    display="flex",
                    width="100%",
                    min_height="150px",
                ),
            ),
        ]
        with ignore_widget_on_submit_warning():
            worker_id_field.on_submit(self._set_worker_id)

    def _set_worker_id(self, worker_id_field):
        """Store a non-empty worker ID on the queue and start labelling."""
        if len(worker_id_field.value) > 0:
            self.queue.worker_id = worker_id_field.value
            self._begin_annotation()

    def _begin_annotation(self):
        """Set correct UI elements, then kick off the loop."""
        self._annotation_loop = self._annotation_iterator()
        next(self._annotation_loop)  # kick off the loop

    # data labelling functionality

    def _annotation_iterator(self):
        """The annotation loop.

        A generator that displays one queued item at a time. Each value sent
        in is the label for the current item; None skips submission. Once
        the queue is exhausted, the "finished" screen is rendered.
        """
        self.children = [self.top_bar, self.labelling_widget]
        self.progressbar.bar_style = ""
        for id_, x in self.queue:
            with self._render_hold_message("Loading..."):
                self.labelling_widget.display(x)
            y = yield
            if y is not None:
                self.queue.submit(id_, y)
            self.progressbar.value = self.queue.progress
        yield self._render_finished()

    def _advance_annotation_loop(self):
        """Resume the annotation loop without submitting a label.

        If the loop has already run to completion, a fresh loop is started
        instead. Otherwise the resulting StopIteration would escape to the
        caller, e.g. a retrain triggered by `orchestrate` after everything
        was labelled.
        """
        try:
            self._annotation_loop.send(None)
        except StopIteration:
            self._begin_annotation()

    def _apply_annotation(self, y):
        """Submit label `y` for the current item and advance the loop."""
        self._annotation_loop.send(y)

    def _undo(self):
        """Go back to the previously labelled item."""
        self.queue.undo()  # unpop the current item
        self.queue.undo()  # unpop and unlabel the previous item
        self._advance_annotation_loop()  # advance to the next item

    def add_features(self, features, labels=None):
        """
        Add data to the widget.

        This adds the data provided to the queue of data to be labelled. You
        can optionally provide labels for each data point.

        Parameters
        ----------
        features : Any
            The data you'd like to add to the labelling widget.
        labels : Any, optional
            The labels for the data you're adding; if you have labels.
        """
        self.queue.enqueue(features, labels)
        if self._annotation_loop is None:
            # Annotation hasn't started yet (e.g. still waiting for a worker
            # ID). The new data is picked up once the loop starts; starting
            # it here would skip the ID prompt.
            return
        # reset the iterator
        self._annotation_loop = self._annotation_iterator()
        self.queue.undo()
        next(self._annotation_loop)

    @contextmanager
    def _render_hold_message(self, message="Rendering..."):
        """Time the enclosed block, showing `message` with a spinner.

        The message is only shown if the same step took more than 0.5s the
        last time it ran.
        """
        timer = self.timers[message]
        if timer.last > 0.5:
            # Only build the widget when it is actually going to be shown.
            spinner = '<i class="fa fa-spinner fa-spin" aria-hidden="true"></i>'
            message_widget = widgets.HTML(
                value=f"<p><b>{message}</b>{spinner}</p>",
                layout=widgets.Layout(padding="0 10%"),
            )
            self.top_bar.children[0].children = [self.progressbar, message_widget]
        try:
            with timer:
                yield
        finally:
            self.top_bar.children[0].children = [self.progressbar]

    def _render_finished(self):
        """Render a celebratory message to the user."""
        self.progressbar.bar_style = "success"
        message = widgets.Box(
            (widgets.HTML(value="<h1>Finished labelling 🎉!</h1>"),),
            layout=widgets.Layout(
                justify_content="center",
                padding="2.5% 0",
                display="flex",
                width="100%",
            ),
        )
        self.children = [self.progressbar, message]

    @property
    def new_labels(self):
        """The labels component (third element) of `self.queue.list_all()`."""
        _, _, labels = self.queue.list_all()
        return labels

    def retrain(self, button=None):
        """Re-train the classifier you passed when creating this widget.

        This calls the fit method of your class with the data that you've
        labelled. It will also score the classifier and display the
        performance. If an acquisition function is set, the unlabelled data
        is re-ordered based on the model's predicted probabilities.

        Parameters
        ----------
        button : widget.Widget, optional
            Optional & ignored; this is passed when invoked by a button.

        Raises
        ------
        ValueError
            If there is no model, or if fitting/evaluation fails for reasons
            other than too few labels or classes.
        """
        if self.model is None:
            raise ValueError("No model to retrain.")

        with self._render_hold_message("Retraining..."):
            _, labelled_X, labelled_y = self.queue.list_completed()

            if len(labelled_y) < 4:
                self.model_performance.value = "Not enough labels to retrain."
                return

            if self.model_preprocess is not None:
                labelled_X, labelled_y = self.model_preprocess(labelled_X, labelled_y)

            # first, fit the model
            try:
                self.model.fit(labelled_X, labelled_y)
            except ValueError as e:
                if str(e).startswith("This solver needs samples of at least 2"):
                    self.model_performance.value = "Not enough classes to retrain."
                    return
                raise

            # now evaluate. by default, using cross validation. in sklearn this
            # clones the model, so it's OK to do after the model fit.
            try:
                if self.eval_method is not None:
                    performance = np.mean(
                        self.eval_method(self.model, labelled_X, labelled_y)
                    )
                else:
                    performance = np.mean(
                        sklearn.model_selection.cross_val_score(
                            self.model, labelled_X, labelled_y, cv=3, error_score=np.nan
                        )
                    )
            except ValueError as e:
                if "n_splits=" in str(e):
                    self.model_performance.value = "Not enough labels to evaluate."
                    return
                raise

            self.model_performance.value = f"Score: {performance:.3f}"

            if self.acquisition_function is not None:
                ids, unlabelled_X = self.queue.list_uncompleted()
                # Nothing left to rank once everything is labelled; calling
                # predict_proba on empty input would raise.
                if len(ids) > 0:
                    if self.model_preprocess is not None:
                        unlabelled_X, _ = self.model_preprocess(unlabelled_X, None)
                    reordering = self.acquisition_function(
                        self.model.predict_proba(unlabelled_X),
                        shuffle_prop=self.shuffle_prop,
                    )
                    # map each unlabelled id to its new position in the queue
                    new_order = OrderedDict(zip(ids, reordering))
                    self.queue.reorder(new_order)

        # Only touch the loop if annotation has started (it may be waiting
        # for a worker ID, in which case no item has been popped).
        if self._annotation_loop is not None:
            self.queue.undo()  # undo the previously popped item
            self._advance_annotation_loop()  # re-pop using the new order

    # orchestrate when not interactively labelling

    def orchestrate(
        self,
        interval_seconds: Optional[float] = None,
        interval_n_labels: int = 0,
        shuffle_prop: float = 0.1,
        max_runs: float = np.inf,
    ):
        """Orchestrate the active learning process.

        This method can either re-train the classifier and re-order the data
        once, or it can run a never-ending loop to re-train the model at
        regular intervals, both in time and in the size of labelled data.

        Parameters
        ----------
        interval_seconds : float, optional
            How often the retraining should occur, in seconds. If this is
            None (the default), the retraining happens at most once, then the
            method returns. This is suitable if you want the retraining
            schedule to be maintained externally, e.g. by a cron job.
        interval_n_labels : int, optional
            How many new data points need to have been labelled in between
            runs in order for the re-training to occur.
        shuffle_prop : float
            What proportion of the data should be randomly sampled on each
            re-training run.
        max_runs : float, int
            How many retraining runs to do at most; iterations skipped for
            lack of new labels don't count. By default infinite.

        Returns
        -------
        None
        """
        if interval_seconds is None:
            self._run_orchestration(
                interval_n_labels=interval_n_labels, shuffle_prop=shuffle_prop
            )
            return

        runs = 0
        while runs < max_runs:
            runs += self._run_orchestration(
                interval_n_labels=interval_n_labels, shuffle_prop=shuffle_prop
            )
            # don't wait a full interval after the last run before returning
            if runs >= max_runs:
                break
            time.sleep(interval_seconds)

    def _run_orchestration(
        self, interval_n_labels: int = 0, shuffle_prop: float = 0.1
    ) -> bool:
        """Retrain once if enough new labels arrived; return whether it did.

        The labelled count at the last successful run is remembered in
        `self._last_n_labelled`, which starts at 0.
        """
        first_orchestration = not hasattr(self, "_last_n_labelled")
        if first_orchestration:
            self._last_n_labelled = 0

        n_new_labels = self.queue._labelled_count() - self._last_n_labelled
        if n_new_labels >= interval_n_labels:
            self._last_n_labelled += n_new_labels
            self.shuffle_prop = shuffle_prop
            self.retrain()  # type: ignore
            print(self.model_performance.value)  # type: ignore
            return True
        return False