Executor

class shabda.run.Executor(model, data_iterator, config, model_hparams=None, train_hooks=None, eval_hooks=None, session_config=None)

Bases: object

Class that executes training, evaluation, prediction, export, and other actions of tf.estimator.Estimator.

Args:
model: An instance of a subclass of ModelBase.
data_hparams: A dict or an instance of HParams containing the hyperparameters of the data. It must contain train and/or eval fields for the relevant processes; for example, train_and_evaluate() requires both.
config: An instance of tf.estimator.RunConfig, used as the config argument of Estimator.
model_hparams (optional): A dict or an instance of HParams containing the hyperparameters of the model. If None, model.hparams is used. Used as the params argument of Estimator.
train_hooks (optional): An iterable of tf.train.SessionRunHook objects to run during training.
eval_hooks (optional): An iterable of tf.train.SessionRunHook objects to run during evaluation.
session_config (optional): An instance of tf.ConfigProto, used as the config argument of the TensorFlow session.
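
The model_hparams fallback described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not shabda's implementation: HypotheticalModel stands in for a ModelBase subclass, resolve_estimator_params is a hypothetical helper, and plain dicts stand in for HParams.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class HypotheticalModel:
    # Hypothetical stand-in for a ModelBase subclass; only `hparams` is modeled.
    hparams: Dict[str, Any] = field(default_factory=lambda: {"learning_rate": 1e-3})


def resolve_estimator_params(
    model: HypotheticalModel, model_hparams: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Return the dict that would be passed as the `params` argument of Estimator.

    Hypothetical helper: explicit model_hparams win; otherwise fall back to
    model.hparams, as the Args description states.
    """
    chosen = model.hparams if model_hparams is None else model_hparams
    return dict(chosen)  # copy so callers cannot mutate the model's hparams


model = HypotheticalModel()
print(resolve_estimator_params(model))                          # falls back to model.hparams
print(resolve_estimator_params(model, {"learning_rate": 0.1}))  # explicit override wins
```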

Example:

See bin/train.py for detailed usage.

evaluate(steps=None, checkpoint_path=None)

Evaluates the model. See tf.estimator.Estimator.evaluate for more details.

Args:
steps (int, optional): Number of steps for which to evaluate the model. If None, evaluation runs until the eval data raises an OutOfRange exception.
checkpoint_path (str, optional): Path of a specific checkpoint to evaluate. If None, the latest checkpoint in config.model_dir is used. If model_dir contains no checkpoints, evaluation runs with newly initialized variables instead of variables restored from a checkpoint.
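
The checkpoint_path and steps rules above can be modeled in plain Python. This is a minimal sketch under stated assumptions, not shabda or TensorFlow code: resolve_checkpoint and count_eval_steps are hypothetical helpers, the checkpoint names are made-up strings, the model_dir listing is assumed to be ordered oldest to newest, and StopIteration stands in for OutOfRange.

```python
from typing import Iterable, List, Optional


def resolve_checkpoint(
    checkpoint_path: Optional[str], checkpoints_in_model_dir: List[str]
) -> Optional[str]:
    """Hypothetical helper mirroring the checkpoint_path rules above.

    `checkpoints_in_model_dir` is assumed to be ordered oldest to newest.
    Returning None means: evaluate with freshly initialized variables.
    """
    if checkpoint_path is not None:
        return checkpoint_path               # an explicit checkpoint always wins
    if checkpoints_in_model_dir:
        return checkpoints_in_model_dir[-1]  # latest checkpoint in model_dir
    return None                              # no checkpoint: fresh initialization


def count_eval_steps(batches: Iterable, steps: Optional[int] = None) -> int:
    """Count how many eval batches would be consumed.

    StopIteration stands in for TensorFlow's OutOfRange error here.
    """
    it = iter(batches)
    done = 0
    while steps is None or done < steps:
        try:
            next(it)
        except StopIteration:
            break  # eval data exhausted before reaching `steps`
        done += 1
    return done


print(resolve_checkpoint(None, ["model.ckpt-100", "model.ckpt-200"]))  # latest one
print(count_eval_steps(range(50), steps=10))  # stops after 10 batches
print(count_eval_steps(range(50)))            # steps=None: runs until the data ends
```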

train(max_steps=None)

Trains the model. See tf.estimator.Estimator.train for more details.

Args:
max_steps (int, optional): Total number of steps for which to train the model. If None, training runs forever or until the train data raises an OutOfRange exception. If OutOfRange occurs earlier, training stops before reaching max_steps steps.
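
Because max_steps is a total rather than an increment, a second call with the same value does no further work once that step count has been reached; tf.estimator.Estimator.train documents this behavior for its own max_steps argument. The pure-Python sketch below models that rule under stated assumptions: train_until and its integer global_step are hypothetical, each consumed batch counts as one step, and StopIteration stands in for OutOfRange.

```python
from typing import Iterator, Optional


def train_until(global_step: int, batches: Iterator, max_steps: Optional[int] = None) -> int:
    """Hypothetical model of max_steps semantics; returns the new global step.

    Each consumed batch counts as one training step. Training stops when the
    global step reaches max_steps (a total, not an increment) or when the
    data runs out; with max_steps=None only running out of data stops it.
    """
    while max_steps is None or global_step < max_steps:
        try:
            next(batches)
        except StopIteration:  # stands in for TensorFlow's OutOfRange error
            break
        global_step += 1
    return global_step


step = train_until(0, iter(range(1000)), max_steps=100)
step = train_until(step, iter(range(1000)), max_steps=100)  # already at 100: no new steps
short = train_until(0, iter(range(30)), max_steps=100)      # data ends first, at step 30
```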

train_and_evaluate(max_train_steps=None, eval_steps=None)

Trains and evaluates the model. See tf.estimator.train_and_evaluate for more details.

Args:
max_train_steps (int, optional): Total number of steps for which to train the model. If None, training runs forever or until the train data raises an OutOfRange exception. If OutOfRange occurs earlier, training stops before reaching max_train_steps steps.
eval_steps (int, optional): Number of steps for which to evaluate the model. If None, evaluation runs until the eval data raises an OutOfRange exception.
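
A heavily simplified, single-process picture of train_and_evaluate(): training proceeds toward max_train_steps, and an evaluation of up to eval_steps batches runs after each checkpoint and once more when training stops. This is an assumption-laden sketch, not the real scheduler: in TensorFlow, checkpoint frequency comes from tf.estimator.RunConfig and evaluation timing is also throttled (EvalSpec's throttle_secs). The names checkpoint_every, run_eval, and train_and_evaluate_sketch are hypothetical, and StopIteration stands in for OutOfRange.

```python
from typing import Callable, Iterable, Iterator, List, Optional, Tuple


def run_eval(batches: Iterable, eval_steps: Optional[int]) -> int:
    """Consume up to eval_steps batches (all of them if None); return the count."""
    it = iter(batches)
    done = 0
    while eval_steps is None or done < eval_steps:
        try:
            next(it)
        except StopIteration:  # stands in for OutOfRange on the eval data
            break
        done += 1
    return done


def train_and_evaluate_sketch(
    train_batches: Iterator,
    make_eval_batches: Callable[[], Iterable],
    max_train_steps: Optional[int],
    checkpoint_every: int,  # hypothetical; assumed to be a positive integer
    eval_steps: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Return (global_step, eval_batches_consumed) for every evaluation run."""
    history: List[Tuple[int, int]] = []
    global_step = 0
    while max_train_steps is None or global_step < max_train_steps:
        try:
            next(train_batches)
        except StopIteration:  # train data exhausted before max_train_steps
            break
        global_step += 1
        if global_step % checkpoint_every == 0:
            # A new checkpoint exists, so evaluate it on fresh eval data.
            history.append((global_step, run_eval(make_eval_batches(), eval_steps)))
    if global_step > 0 and (not history or history[-1][0] != global_step):
        # Final evaluation of the last checkpoint when training stops.
        history.append((global_step, run_eval(make_eval_batches(), eval_steps)))
    return history


print(train_and_evaluate_sketch(iter(range(1000)), lambda: range(7),
                                max_train_steps=250, checkpoint_every=100, eval_steps=5))
# [(100, 5), (200, 5), (250, 5)]
```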