Source code for acton.plot

"""Script to plot a dump of predictions."""

import itertools
import sys
from typing import Iterable
from typing.io import BinaryIO

import acton.proto.io
from acton.proto.acton_pb2 import Predictions
import acton.proto.wrappers
import click
import matplotlib.pyplot as plt
import sklearn.metrics


def plot(predictions: Iterable[BinaryIO]):
    """Plot an accuracy curve for each file of dumped predictions.

    For every file, each serialised ``Predictions`` protobuf yielded by
    ``acton.proto.io.read_protos`` is scored against the labels stored in the
    database referenced by the file's first protobuf. One line per file is
    drawn, labelled with that file's metadata, and the figure is shown.

    Parameters
    ----------
    predictions
        Binary file objects containing predictions. Any iterable is accepted
        (tuple, list, generator, ...); it is consumed exactly once.

    Raises
    ------
    ValueError
        If ``predictions`` yields no files.
    """
    # Materialise the input once. The annotation promises only an Iterable,
    # so ``len()`` cannot be relied upon, and the files are walked twice.
    prediction_files = list(predictions)
    if not prediction_files:
        raise ValueError('Must have at least 1 set of predictions.')

    # First pass: read the metadata of every file, then rewind each file so
    # the second pass starts from the beginning again.
    metadata = []
    for proto_file in prediction_files:
        metadata.append(acton.proto.io.read_metadata(proto_file))
        proto_file.seek(0)

    for meta, proto_file in zip(metadata, prediction_files):
        # Read in the first protobuf to get the database file.
        # NOTE(review): the file is not rewound afterwards, so whether this
        # first record also contributes to the curve below depends on how
        # read_protos positions the stream -- confirm against acton.proto.io.
        first_proto = next(acton.proto.io.read_protos(proto_file, Predictions))
        first_wrapped = acton.proto.wrappers.Predictions(first_proto)
        with first_wrapped.DB() as db:
            accuracies = []
            for raw_proto in acton.proto.io.read_protos(
                    proto_file, Predictions):
                wrapped = acton.proto.wrappers.Predictions(raw_proto)
                ids = wrapped.predicted_ids
                scores = wrapped.predictions
                # Only a single entry along the leading axis is supported.
                assert scores.shape[0] == 1
                scores = scores[0]
                labels = db.read_labels([0], ids).ravel()
                predicted_labels = scores.argmax(axis=1).ravel()
                # Quick and dirty hack: turn each argmax index into ASCII
                # bytes before comparing it with the database labels.
                # (presumably the labels are stored as bytes -- TODO confirm)
                predicted_labels = [str(p).encode('ascii')
                                    for p in predicted_labels]
                print(labels, predicted_labels)
                accuracies.append(sklearn.metrics.accuracy_score(
                    labels, predicted_labels))

        plt.plot(accuracies, label=meta.decode('ascii', errors='replace'))

    plt.xlabel('Number of additional labels')
    plt.ylabel('Accuracy score')
    plt.legend()
    plt.show()
@click.command()
@click.argument(
    'predictions',
    type=click.File('rb'),
    nargs=-1,
    required=True,
)
def _plot(predictions: Iterable[BinaryIO]):
    """Plots predictions from a file.

    Parameters
    ----------
    predictions
        Files containing predictions.
    """
    # Command-line wrapper: click opens every PREDICTIONS path in binary mode
    # and passes the resulting file objects on; all the work is done by plot().
    # The docstring above is kept verbatim because click uses it as help text.
    return plot(predictions)


# Run the click command when this module is executed as a script.
if __name__ == '__main__':
    sys.exit(_plot())