"""Classes that wrap protobufs."""
import json
from typing import Callable, Iterable, List, Union

import google.protobuf.json_format as json_format
import numpy
import sklearn.preprocessing
from sklearn.preprocessing import LabelEncoder as SKLabelEncoder

import acton.database
import acton.proto.acton_pb2 as acton_pb
import acton.proto.io
def validate_db(db: acton_pb.Database):
    """Checks that a Database proto names a known database class and a path.

    Parameters
    ----------
    db
        Database to validate.

    Raises
    ------
    ValueError
        If ``db.class_name`` is not a key of ``acton.database.DATABASES``, or
        if ``db.path`` is empty. The class name is checked first.
    """
    registered = acton.database.DATABASES
    requested = db.class_name
    if requested not in registered:
        message = 'Invalid database class name: {}'.format(requested)
        raise ValueError(message)

    if not db.path:
        raise ValueError('Must specify db.path.')
def deserialise_encoder(
        encoder: acton_pb.Database.LabelEncoder
) -> sklearn.preprocessing.LabelEncoder:
    """Deserialises a LabelEncoder protobuf.

    Parameters
    ----------
    encoder
        LabelEncoder protobuf. Each entry of ``encoder.encoding`` pairs an
        integer ``class_int`` with a ``class_label``.

    Returns
    -------
    sklearn.preprocessing.LabelEncoder
        LabelEncoder whose ``classes_`` holds the labels ordered by their
        ``class_int``. A LabelEncoder is always returned; if ``encoder`` has
        no encodings, ``classes_`` is an empty array (not None).
    """
    # Sort (class_int, class_label) pairs so classes_ is ordered by class_int.
    # NOTE(review): this makes position == integer code only if the class_ints
    # are exactly 0..K-1; confirm that writers of this proto guarantee that.
    pairs = sorted(
        (encoding.class_int, encoding.class_label)
        for encoding in encoder.encoding)
    label_encoder = SKLabelEncoder()
    label_encoder.classes_ = numpy.array([label for _, label in pairs])
    return label_encoder
class LabelPool(object):
    """Wrapper for the LabelPool protobuf.

    Attributes
    ----------
    proto : acton_pb.LabelPool
        Protobuf representing the label pool.
    db_kwargs : dict
        Key-value pairs of keyword arguments for the database constructor.
    label_encoder : sklearn.preprocessing.LabelEncoder
        Encodes labels as integers. May be None.
    """

    def __init__(self, proto: Union[str, acton_pb.LabelPool]):
        """
        Parameters
        ----------
        proto
            Path to .proto file, or raw protobuf itself.

        Raises
        ------
        TypeError
            If proto is neither readable by ``read_proto`` nor a LabelPool
            protobuf.
        ValueError
            If the embedded Database protobuf fails ``validate_db``.
        """
        try:
            self.proto = acton.proto.io.read_proto(proto, acton_pb.LabelPool)
        except TypeError:
            # read_proto raised TypeError (presumably because proto is not a
            # path); accept proto directly if it is already a LabelPool.
            if isinstance(proto, acton_pb.LabelPool):
                self.proto = proto
            else:
                raise TypeError('proto should be str or LabelPool protobuf.')
        self._validate_proto()
        # Keyword-argument values are stored JSON-encoded in the protobuf.
        self.db_kwargs = {kwa.key: json.loads(kwa.value)
                          for kwa in self.proto.db.kwarg}
        if len(self.proto.db.label_encoder.encoding) > 0:
            self.label_encoder = deserialise_encoder(
                self.proto.db.label_encoder)
            # The database constructor also receives the encoder.
            self.db_kwargs['label_encoder'] = self.label_encoder
        else:
            self.label_encoder = None
        self._set_default()

    @classmethod
    def deserialise(cls, proto: bytes, json: bool=False) -> 'LabelPool':
        """Deserialises a protobuf into a LabelPool.

        Parameters
        ----------
        proto
            Serialised protobuf.
        json
            Whether the serialised protobuf is in JSON format. (This name
            shadows the ``json`` module inside this method only.)

        Returns
        -------
        LabelPool
        """
        if not json:
            lp = acton_pb.LabelPool()
            lp.ParseFromString(proto)
            return cls(lp)
        return cls(json_format.Parse(proto, acton_pb.LabelPool()))

    @property
    def DB(self) -> Callable[[], acton.database.Database]:
        """Gets a database context manager factory for the specified database.

        Returns
        -------
        Callable[[], acton.database.Database]
            Zero-argument callable that constructs the database class named by
            ``proto.db.class_name`` from ``proto.db.path`` and ``db_kwargs``.
            Intended use: ``with pool.DB() as db: ...``.
        """
        if hasattr(self, '_DB'):
            return self._DB
        self._DB = lambda: acton.database.DATABASES[self.proto.db.class_name](
            self.proto.db.path, **self.db_kwargs)
        return self._DB

    @property
    def ids(self) -> List[int]:
        """Gets a list of IDs.

        Returns
        -------
        List[int]
            List of known IDs (cached after the first call).
        """
        if hasattr(self, '_ids'):
            return self._ids
        self._ids = list(self.proto.id)
        return self._ids

    @property
    def labels(self) -> numpy.ndarray:
        """Gets labels array specified in input.

        Notes
        -----
        The returned array is cached by this object, so later calls do not
        open the database again.

        Returns
        -------
        numpy.ndarray
            Result of ``read_labels([0], ids)`` on the database. Presumably
            T x N x F with only labeller 0 requested; confirm against the
            database's ``read_labels``.
        """
        if hasattr(self, '_labels'):
            return self._labels
        ids = self.ids
        with self.DB() as db:
            # Store the result; previously it was returned without caching,
            # so the hasattr check above never succeeded.
            self._labels = db.read_labels([0], ids)
        return self._labels

    def _validate_proto(self):
        """Checks that the protobuf is valid and enforces constraints.

        Raises
        ------
        ValueError
            If the embedded Database protobuf is invalid.
        """
        validate_db(self.proto.db)

    def _set_default(self):
        """Adds default parameters to the protobuf (currently none)."""
        pass

    @classmethod
    def make(
            cls: type,
            ids: Iterable[int],
            db: acton.database.Database) -> 'LabelPool':
        """Constructs a LabelPool.

        Parameters
        ----------
        ids
            Iterable of instance IDs.
        db
            Database

        Returns
        -------
        LabelPool
        """
        proto = acton_pb.LabelPool()
        # Store the IDs. int() turns e.g. numpy.int64 into int, for protobuf
        # compatibility.
        proto.id.extend(int(id_) for id_ in ids)
        # Store the database.
        proto.db.CopyFrom(db.to_proto())
        return cls(proto)
class Predictions(object):
    """Wrapper for the Predictions protobuf.

    Attributes
    ----------
    proto : acton_pb.Predictions
        Protobuf representing predictions.
    db_kwargs : dict
        Dictionary of database keyword arguments.
    label_encoder : sklearn.preprocessing.LabelEncoder
        Encodes labels as integers. May be None.
    """

    def __init__(self, proto: Union[str, acton_pb.Predictions]):
        """
        Parameters
        ----------
        proto
            Path to .proto file, or raw protobuf itself.

        Raises
        ------
        TypeError
            If proto is neither readable by ``read_proto`` nor a Predictions
            protobuf.
        ValueError
            If the protobuf fails ``_validate_proto``.
        """
        try:
            self.proto = acton.proto.io.read_proto(
                proto, acton_pb.Predictions)
        except TypeError:
            # read_proto raised TypeError (presumably because proto is not a
            # path); accept proto directly if it is already a Predictions.
            if isinstance(proto, acton_pb.Predictions):
                self.proto = proto
            else:
                raise TypeError('proto should be str or Predictions protobuf.')
        self._validate_proto()
        # Keyword-argument values are stored JSON-encoded in the protobuf.
        self.db_kwargs = {kwa.key: json.loads(kwa.value)
                          for kwa in self.proto.db.kwarg}
        if len(self.proto.db.label_encoder.encoding) > 0:
            self.label_encoder = deserialise_encoder(
                self.proto.db.label_encoder)
            # The database constructor also receives the encoder.
            self.db_kwargs['label_encoder'] = self.label_encoder
        else:
            self.label_encoder = None
        self._set_default()

    @property
    def DB(self) -> Callable[[], acton.database.Database]:
        """Gets a database context manager factory for the specified database.

        Returns
        -------
        Callable[[], acton.database.Database]
            Zero-argument callable that constructs the database class named by
            ``proto.db.class_name`` from ``proto.db.path`` and ``db_kwargs``.
        """
        if hasattr(self, '_DB'):
            return self._DB
        self._DB = lambda: acton.database.DATABASES[self.proto.db.class_name](
            self.proto.db.path, **self.db_kwargs)
        return self._DB

    @property
    def predicted_ids(self) -> List[int]:
        """Gets a list of IDs corresponding to predictions.

        Returns
        -------
        List[int]
            List of IDs corresponding to predictions, in proto order (cached).
        """
        if hasattr(self, '_predicted_ids'):
            return self._predicted_ids
        self._predicted_ids = [prediction.id
                               for prediction in self.proto.prediction]
        return self._predicted_ids

    @property
    def labelled_ids(self) -> List[int]:
        """Gets a list of IDs the predictor knew the label for.

        Returns
        -------
        List[int]
            List of IDs the predictor knew the label for (cached).
        """
        if hasattr(self, '_labelled_ids'):
            return self._labelled_ids
        self._labelled_ids = list(self.proto.labelled_id)
        return self._labelled_ids

    @property
    def predictions(self) -> numpy.ndarray:
        """Gets predictions array specified in input.

        Notes
        -----
        The returned array is cached by this object so future calls will not
        need to recompile the array. The cache is written only once the
        array has been fully built.

        Returns
        -------
        numpy.ndarray
            T x N x D NumPy array of predictions, where T is
            ``n_predictors``, N is the number of stored predictions and D is
            ``n_prediction_dimensions``. When there are no stored
            predictions this is an empty T x 0 x D array.
        """
        if hasattr(self, '_predictions'):
            return self._predictions
        # Each stored prediction is a flattened T x D block for one instance.
        shape = (self.proto.n_predictors,
                 self.proto.n_prediction_dimensions)
        per_instance = [
            acton.proto.io.get_ndarray(prediction.prediction, shape, float)
            for prediction in self.proto.prediction]
        if per_instance:
            # Stack to N x T x D, then move the predictor axis to the front.
            predictions = numpy.array(per_instance).transpose((1, 0, 2))
        else:
            # numpy.array([]) is 1-D and cannot be transposed with 3 axes.
            predictions = numpy.zeros((shape[0], 0, shape[1]))
        self._predictions = predictions
        return self._predictions

    def _validate_proto(self):
        """Checks that the protobuf is valid and enforces constraints.

        Raises
        ------
        ValueError
            If ``n_predictors`` or ``n_prediction_dimensions`` is < 1, or the
            embedded Database protobuf is invalid.
        """
        if self.proto.n_predictors < 1:
            raise ValueError('Number of predictors must be > 0.')
        if self.proto.n_prediction_dimensions < 1:
            raise ValueError('Prediction dimension must be > 0.')
        validate_db(self.proto.db)

    def _set_default(self):
        """Adds default parameters to the protobuf (currently none)."""
        pass

    @classmethod
    def make(
            cls: type,
            predicted_ids: Iterable[int],
            labelled_ids: Iterable[int],
            predictions: numpy.ndarray,
            db: acton.database.Database,
            predictor: str='') -> 'Predictions':
        """Converts NumPy predictions to a Predictions object.

        Parameters
        ----------
        predicted_ids
            Iterable of instance IDs corresponding to predictions.
        labelled_ids
            Iterable of instance IDs used to train the predictor.
        predictions
            T x N x D array of corresponding predictions.
        db
            Database.
        predictor
            Name of predictor used to generate predictions.

        Returns
        -------
        Predictions
        """
        proto = acton_pb.Predictions()
        # Store single data first.
        n_predictors, _, n_prediction_dimensions = predictions.shape
        proto.n_predictors = n_predictors
        proto.n_prediction_dimensions = n_prediction_dimensions
        proto.predictor = predictor
        # Store the database.
        proto.db.CopyFrom(db.to_proto())
        # Store the predictions one instance at a time: after transposing to
        # N x T x D, each row is that instance's T x D block, flattened.
        # NOTE(review): zip stops at the shorter of predicted_ids and the N
        # axis, so a length mismatch is silently truncated.
        for id_, prediction in zip(
                predicted_ids, predictions.transpose((1, 0, 2))):
            prediction_ = proto.prediction.add()
            prediction_.id = int(id_)  # numpy.int64 -> int
            prediction_.prediction.extend(prediction.ravel())
        # Store the labelled IDs. int() takes numpy.int64 to int, for
        # protobuf compatibility.
        proto.labelled_id.extend(int(id_) for id_ in labelled_ids)
        return cls(proto)

    @classmethod
    def deserialise(cls, proto: bytes, json: bool=False) -> 'Predictions':
        """Deserialises a protobuf into Predictions.

        Parameters
        ----------
        proto
            Serialised protobuf.
        json
            Whether the serialised protobuf is in JSON format. (This name
            shadows the ``json`` module inside this method only.)

        Returns
        -------
        Predictions
        """
        if not json:
            predictions = acton_pb.Predictions()
            predictions.ParseFromString(proto)
            return cls(predictions)
        return cls(json_format.Parse(proto, acton_pb.Predictions()))
class Recommendations(object):
    """Wrapper for the Recommendations protobuf.

    Attributes
    ----------
    proto : acton_pb.Recommendations
        Protobuf representing recommendations.
    db_kwargs : dict
        Key-value pairs of keyword arguments for the database constructor.
    label_encoder : sklearn.preprocessing.LabelEncoder
        Encodes labels as integers. May be None.
    """

    def __init__(self, proto: Union[str, acton_pb.Recommendations]):
        """
        Parameters
        ----------
        proto
            Path to .proto file, or raw protobuf itself.

        Raises
        ------
        TypeError
            If proto is neither readable by ``read_proto`` nor a
            Recommendations protobuf.
        ValueError
            If the embedded Database protobuf fails ``validate_db``.
        """
        try:
            self.proto = acton.proto.io.read_proto(
                proto, acton_pb.Recommendations)
        except TypeError:
            # read_proto raised TypeError (presumably because proto is not a
            # path); accept proto directly if it is already a Recommendations.
            if isinstance(proto, acton_pb.Recommendations):
                self.proto = proto
            else:
                raise TypeError(
                    'proto should be str or Recommendations protobuf.')
        self._validate_proto()
        # Keyword-argument values are stored JSON-encoded in the protobuf.
        self.db_kwargs = {kwa.key: json.loads(kwa.value)
                          for kwa in self.proto.db.kwarg}
        if len(self.proto.db.label_encoder.encoding) > 0:
            self.label_encoder = deserialise_encoder(
                self.proto.db.label_encoder)
            # The database constructor also receives the encoder.
            self.db_kwargs['label_encoder'] = self.label_encoder
        else:
            self.label_encoder = None
        self._set_default()

    @classmethod
    def deserialise(cls, proto: bytes, json: bool=False) -> 'Recommendations':
        """Deserialises a protobuf into Recommendations.

        Parameters
        ----------
        proto
            Serialised protobuf.
        json
            Whether the serialised protobuf is in JSON format. (This name
            shadows the ``json`` module inside this method only.)

        Returns
        -------
        Recommendations
        """
        if not json:
            recommendations = acton_pb.Recommendations()
            recommendations.ParseFromString(proto)
            return cls(recommendations)
        return cls(json_format.Parse(proto, acton_pb.Recommendations()))

    @property
    def DB(self) -> Callable[[], acton.database.Database]:
        """Gets a database context manager factory for the specified database.

        Returns
        -------
        Callable[[], acton.database.Database]
            Zero-argument callable that constructs the database class named by
            ``proto.db.class_name`` from ``proto.db.path`` and ``db_kwargs``.
        """
        if hasattr(self, '_DB'):
            return self._DB
        self._DB = lambda: acton.database.DATABASES[self.proto.db.class_name](
            self.proto.db.path, **self.db_kwargs)
        return self._DB

    @property
    def recommendations(self) -> List[int]:
        """Gets a list of recommended IDs.

        Returns
        -------
        List[int]
            List of recommended IDs (cached after the first call).
        """
        if hasattr(self, '_recommendations'):
            return self._recommendations
        self._recommendations = list(self.proto.recommended_id)
        return self._recommendations

    @property
    def labelled_ids(self) -> List[int]:
        """Gets a list of labelled IDs.

        Returns
        -------
        List[int]
            List of labelled IDs (cached after the first call).
        """
        if hasattr(self, '_labelled_ids'):
            return self._labelled_ids
        self._labelled_ids = list(self.proto.labelled_id)
        return self._labelled_ids

    def _validate_proto(self):
        """Checks that the protobuf is valid and enforces constraints.

        Raises
        ------
        ValueError
            If the embedded Database protobuf is invalid.
        """
        validate_db(self.proto.db)

    def _set_default(self):
        """Adds default parameters to the protobuf (currently none)."""
        pass

    @classmethod
    def make(
            cls: type,
            recommended_ids: Iterable[int],
            labelled_ids: Iterable[int],
            recommender: str,
            db: acton.database.Database) -> 'Recommendations':
        """Constructs a Recommendations.

        Parameters
        ----------
        recommended_ids
            Iterable of recommended instance IDs.
        labelled_ids
            Iterable of labelled instance IDs used to make recommendations.
        recommender
            Name of the recommender used to make recommendations.
        db
            Database.

        Returns
        -------
        Recommendations
        """
        proto = acton_pb.Recommendations()
        # Store single data first.
        proto.recommender = recommender
        # Store the IDs. int() turns e.g. numpy.int64 into int, for protobuf
        # compatibility.
        proto.recommended_id.extend(int(id_) for id_ in recommended_ids)
        proto.labelled_id.extend(int(id_) for id_ in labelled_ids)
        # Store the database.
        proto.db.CopyFrom(db.to_proto())
        return cls(proto)