Executor

class shabda.run.Executor(model, data_iterator, config, model_hparams=None, train_hooks=None, eval_hooks=None, session_config=None)

Bases: object
Class that executes training, evaluation, prediction, export, and other actions of tf.estimator.Estimator.
Args:
- model: An instance of a subclass of ModelBase.
- data_hparams: A dict or an instance of HParams containing the hyperparameters of the data. It must contain train and/or eval fields for the relevant processes. For example, train_and_evaluate() requires both fields.
- config: An instance of tf.estimator.RunConfig, used as the config argument of Estimator.
- model_hparams (optional): A dict or an instance of HParams containing the hyperparameters of the model. If None, model.hparams is used. Used as the params argument of Estimator.
- train_hooks (optional): Iterable of tf.train.SessionRunHook objects to run during training.
- eval_hooks (optional): Iterable of tf.train.SessionRunHook objects to run during evaluation.
- session_config (optional): An instance of tf.ConfigProto, used as the config argument of the TensorFlow session.
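The field requirement for data_hparams can be expressed as a small check. This is a minimal sketch: the helper missing_data_fields, the mode names, and nested keys such as batch_size are hypothetical illustrations, and a plain dict stands in for an HParams instance. It does not reproduce shabda's own validation.

```python
# Hypothetical helper: the rule "train and/or eval fields, both for
# train_and_evaluate()" comes from the Args description; names are illustrative.
REQUIRED_FIELDS = {
    "train": ("train",),
    "evaluate": ("eval",),
    "train_and_evaluate": ("train", "eval"),
}


def missing_data_fields(data_hparams, mode):
    """Return the fields that `mode` needs but `data_hparams` lacks."""
    return [field for field in REQUIRED_FIELDS[mode] if field not in data_hparams]


# A plain dict standing in for an HParams instance; the nested keys are
# placeholders, not shabda's actual data hyperparameters.
data_hparams = {
    "train": {"batch_size": 32},
    "eval": {"batch_size": 64},
}

print(missing_data_fields(data_hparams, "train_and_evaluate"))   # []
print(missing_data_fields({"train": {}}, "train_and_evaluate"))  # ['eval']
```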
Example:
See bin/train.py for detailed usage.
evaluate(steps=None, checkpoint_path=None)

Evaluates the model. See tf.estimator.Estimator.evaluate for more details.

Args:
- steps (int, optional): Number of steps for which to evaluate the model. If None, evaluates until the eval data raises an OutOfRange exception.
- checkpoint_path (str, optional): Path of a specific checkpoint to evaluate. If None, the latest checkpoint in config.model_dir is used. If there are no checkpoints in model_dir, evaluation is run with newly initialized variables instead of variables restored from a checkpoint.
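The step and checkpoint rules for evaluate() can be sketched in plain Python. Both functions are hypothetical stand-ins, not part of shabda or TensorFlow: a list of checkpoint names replaces the contents of config.model_dir, and StopIteration plays the role of TensorFlow's OutOfRange error.

```python
def resolve_checkpoint(checkpoint_path, available_checkpoints):
    """Choose which checkpoint evaluate() would restore.

    `available_checkpoints` is a hypothetical stand-in for the checkpoints
    in config.model_dir, ordered oldest to newest. Returns None when the
    model would be evaluated with newly initialized variables.
    """
    if checkpoint_path is not None:
        return checkpoint_path            # explicit checkpoint wins
    if available_checkpoints:
        return available_checkpoints[-1]  # latest checkpoint in model_dir
    return None                           # no checkpoint: fresh variables


def count_eval_batches(eval_batches, steps=None):
    """Consume up to `steps` batches (None = until the data runs out).

    StopIteration plays the role of TensorFlow's OutOfRange error.
    """
    iterator = iter(eval_batches)
    consumed = 0
    while steps is None or consumed < steps:
        try:
            next(iterator)  # a real evaluator would update metrics here
        except StopIteration:
            break
        consumed += 1
    return consumed


print(resolve_checkpoint(None, ["model.ckpt-100", "model.ckpt-200"]))  # model.ckpt-200
print(count_eval_batches(range(10), steps=3))  # 3
print(count_eval_batches(range(10)))           # 10
```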
train(max_steps=None)

Trains the model. See tf.estimator.Estimator.train for more details.

Args:
- max_steps (int, optional): Total number of steps for which to train the model. If None, trains indefinitely or until the train data raises an OutOfRange exception. If OutOfRange occurs first, training stops before reaching max_steps steps.
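Assuming max_steps is passed through to tf.estimator.Estimator.train, it counts total global steps, so a model restored from a checkpoint that is already at or past max_steps trains no further. The sketch below models that rule together with the early stop on exhausted data. run_training is a hypothetical helper, global_step stands in for the step restored from a checkpoint, and StopIteration stands in for OutOfRange.

```python
def run_training(train_batches, global_step=0, max_steps=None):
    """Return the global step reached after training.

    `global_step` is the (hypothetical) step restored from a checkpoint;
    max_steps bounds the total, not the number of new steps.
    """
    iterator = iter(train_batches)
    while max_steps is None or global_step < max_steps:
        try:
            next(iterator)  # a real trainer would run one optimization step
        except StopIteration:
            break           # data exhausted before reaching max_steps
        global_step += 1
    return global_step


print(run_training(range(1000), max_steps=100))                  # 100
print(run_training(range(1000), global_step=100, max_steps=100)) # 100 (no new steps)
print(run_training(range(30), max_steps=100))                    # 30
```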
train_and_evaluate(max_train_steps=None, eval_steps=None)

Trains and evaluates the model. See tf.estimator.train_and_evaluate for more details.

Args:
- max_train_steps (int, optional): Total number of steps for which to train the model. If None, trains indefinitely or until the train data raises an OutOfRange exception. If OutOfRange occurs first, training stops before reaching max_train_steps steps.
- eval_steps (int, optional): Number of steps for which to evaluate the model. If None, evaluates until the eval data raises an OutOfRange exception.
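A simplified, synchronous picture of how the two arguments interact: training proceeds in rounds, each followed by an evaluation of at most eval_steps batches, until max_train_steps is reached or the train data runs out. This is a sketch under stated assumptions, not the behavior of tf.estimator.train_and_evaluate, which evaluates new checkpoints on a schedule throttled by EvalSpec.throttle_secs. The function train_and_evaluate_sketch and its steps_per_round knob are hypothetical, and training starts from step 0 rather than from a restored checkpoint.

```python
def _take(iterator, limit):
    """Advance `iterator` up to `limit` times (None = until exhausted).

    Returns the number of items consumed; StopIteration stands in for
    TensorFlow's OutOfRange error.
    """
    consumed = 0
    while limit is None or consumed < limit:
        try:
            next(iterator)
        except StopIteration:
            break
        consumed += 1
    return consumed


def train_and_evaluate_sketch(train_batches, make_eval_batches,
                              max_train_steps=None, eval_steps=None,
                              steps_per_round=100):
    """Alternate training rounds with evaluations (hypothetical schedule).

    Returns a list of (global_step, eval batches consumed) pairs.
    """
    train_it = iter(train_batches)
    global_step = 0
    eval_log = []
    while max_train_steps is None or global_step < max_train_steps:
        budget = steps_per_round
        if max_train_steps is not None:
            budget = min(budget, max_train_steps - global_step)
        trained = _take(train_it, budget)
        if trained == 0:
            break  # train data exhausted: nothing new to evaluate
        global_step += trained
        # Each evaluation starts a fresh pass over the eval data.
        eval_log.append((global_step, _take(iter(make_eval_batches()), eval_steps)))
        if trained < budget:
            break  # OutOfRange hit mid-round; training is over
    return eval_log


print(train_and_evaluate_sketch(range(250), lambda: range(20), eval_steps=5))
# [(100, 5), (200, 5), (250, 5)]
```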