/usr/share/pyshared/mvpa2/base/learner.py is in python-mvpa2 2.2.0-4ubuntu2.
This file is owned by root:root, with mode 0o644.
The actual contents of the file can be viewed below.
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
# See COPYING file distributed along with the PyMVPA package for the
# copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Implementation of a common trainable processing object (Learner)."""
__docformat__ = 'restructuredtext'
import time

from mvpa2.base import warning
from mvpa2.base.node import Node
from mvpa2.base.state import ConditionalAttribute
from mvpa2.base.types import is_datasetlike
from mvpa2.base.dochelpers import _repr_attrs

if __debug__:
    from mvpa2.base import debug
class LearnerError(Exception):
    """Root of the exception hierarchy raised by Learners.

    Catch this type to handle any Learner-specific failure; the more
    specific errors in this module derive from it.
    """
class DegenerateInputError(LearnerError):
    """Learner exception raised when the input data is degenerate.

    Degenerate means the data has no features or no samples.
    """
class FailedToTrainError(LearnerError):
    """Learner exception raised when training did not succeed."""
class FailedToPredictError(LearnerError):
    """Learner exception raised when a learner fails to make predictions.

    This usually happens if the learner was trained on degenerate data
    without any complaint being raised at training time.
    """
class Learner(Node):
    """Common trainable processing object.

    A `Learner` is a `Node` that can (maybe has to) be trained on a dataset,
    before it can perform its function.

    Calling a learner (``learner(ds)``) first checks the training status:
    an untrained learner is either auto-trained on the input (if
    ``auto_train`` is set) or a ``RuntimeError`` is raised; an already
    trained learner is retrained only if ``force_train`` is set.
    """

    training_time = ConditionalAttribute(enabled=True,
        doc="Time (in seconds) it took to train the learner")

    def __init__(self, auto_train=False, force_train=False, **kwargs):
        """
        Parameters
        ----------
        auto_train : bool
          Flag whether the learner will automatically train itself on the
          input dataset when called untrained.
        force_train : bool
          Flag whether the learner will enforce training on the input dataset
          upon every call.
        **kwargs
          All arguments are passed to the baseclass.
        """
        Node.__init__(self, **kwargs)
        # a fresh learner always starts out untrained
        self.__is_trained = False
        self.__auto_train = auto_train
        self.__force_train = force_train

    def __repr__(self, prefixes=[]):
        """Return the base class representation extended by the training flags.

        Parameters
        ----------
        prefixes : list
          Leading items of the representation, handed on to the base class
          with the ``auto_train``/``force_train`` entries appended.
        """
        # NOTE: the list default is safe here -- it is never mutated, `+`
        # builds a new list for every call.
        return super(Learner, self).__repr__(
            prefixes=prefixes
            + _repr_attrs(self, ['auto_train', 'force_train'], default=False))

    def train(self, ds):
        """Train the learner on a dataset.

        The default implementation calls ``_pretrain()``, ``_train()``, and
        finally ``_posttrain()``. The time spent between the end of
        ``_pretrain()`` and the start of ``_posttrain()`` is stored in the
        ``training_time`` conditional attribute, and the learner is flagged
        as trained afterwards.

        Parameters
        ----------
        ds: Dataset
          Training dataset.

        Returns
        -------
        None

        Raises
        ------
        DegenerateInputError
          If `ds` is dataset-like and has no features or no samples.
        """
        got_ds = is_datasetlike(ds)

        # TODO remove first condition if all Learners get only datasets
        if got_ds and (ds.nfeatures == 0 or len(ds) == 0):
            raise DegenerateInputError(
                "Cannot train learner on degenerate data %s" % ds)
        if __debug__:
            debug("LRN", "Training learner %(lrn)s on dataset %(dataset)s",
                  msgargs={'lrn':self, 'dataset': ds})

        self._pretrain(ds)

        # remember the time when started training
        t0 = time.time()

        if got_ds:
            # things might have happened during pretraining, hence the
            # feature count is checked again
            if ds.nfeatures > 0:
                self._train(ds)
            else:
                # `warning` is provided by the module-level
                # `from mvpa2.base import warning`
                warning("Trying to train on dataset with no features present")
                if __debug__:
                    debug("LRN",
                          "No features present for training, no actual training " \
                          "is called")
        else:
            # in this case we claim to have no idea and simply try to train
            self._train(ds)

        # store timing
        self.ca.training_time = time.time() - t0

        # and post-proc
        self._posttrain(ds)

        # finally flag as trained
        self._set_trained()

        if __debug__:
            debug("LRN", "Finished training learner %(lrn)s on dataset %(dataset)s",
                  msgargs={'lrn':self, 'dataset': ds})

    def untrain(self):
        """Reverts changes in the state of this node caused by previous training
        """
        # flag the learner as untrained
        # important to do that before calling the implementation in the derived
        # class, as it might decide that an object remains trained
        self._set_trained(False)
        # call subclass untrain before reset() to allow it to access current
        # attributes
        self._untrain()
        # TODO evaluate whether this should also reset the nodes collections, or
        # whether that should be done by a more general reset() method
        self.reset()

    def _untrain(self):
        """Subclass hook called by ``untrain()``; does nothing by default."""
        pass

    def _pretrain(self, ds):
        """Preparations prior training.

        By default, does nothing.

        Parameters
        ----------
        ds: Dataset
          Original training dataset.

        Returns
        -------
        None
        """
        pass

    def _train(self, ds):
        """Subclass hook doing the actual training; does nothing by default.

        Parameters
        ----------
        ds: Dataset
          Training dataset.
        """
        pass

    def _posttrain(self, ds):
        """Finalizing the training.

        By default, does nothing.

        Parameters
        ----------
        ds: Dataset
          Original training dataset.

        Returns
        -------
        None
        """
        pass

    def _set_trained(self, status=True):
        """Set the Learner's training status

        Derived classes use this to set the Learner's status to trained
        (True) or untrained (False).

        Parameters
        ----------
        status : bool
          New training status.
        """
        self.__is_trained = status

    def __call__(self, ds):
        """Apply the learner to `ds`, training it first when required.

        Parameters
        ----------
        ds: Dataset
          Input that is passed on to the base class ``__call__``; also used
          as training data if (re)training is triggered.

        Returns
        -------
        Whatever the base class ``__call__`` returns for `ds`.

        Raises
        ------
        RuntimeError
          If the learner is untrained and ``auto_train`` is disabled.
        """
        # overwrite __call__ to perform a rigorous check whether the learner was
        # trained before use and auto-train
        if self.is_trained:
            # already trained
            if self.force_train:
                if __debug__:
                    debug('LRN', "Forcing training of %s on %s",
                          (self, ds))
                # but retraining is enforced
                self.train(ds)
            elif __debug__:
                debug('LRN', "Skipping training of already trained %s on %s",
                      (self, ds))
        else:
            # not trained
            if self.auto_train:
                # auto training requested
                if __debug__:
                    debug('LRN', "Auto-training %s on %s",
                          (self, ds))
                self.train(ds)
            else:
                # we always have to have trained before using a learner
                raise RuntimeError("%s needs to be trained before it can be "
                                   "used and auto training is disabled."
                                   % str(self))
        return super(Learner, self).__call__(ds)

    # Read/write access to the training status; the flags given to the
    # constructor are exposed read-only.
    is_trained = property(fget=lambda x:x.__is_trained, fset=_set_trained,
                          doc="Whether the Learner is currently trained.")
    auto_train = property(fget=lambda x:x.__auto_train,
                          doc="Whether the Learner performs automatic training "
                              "when called untrained.")
    force_train = property(fget=lambda x:x.__force_train,
                           doc="Whether the Learner enforces training upon every "
                               "call.")
|