torchgan.trainer
This subpackage provides end-to-end training of the Generator and Discriminator models, with strong visualization capabilities using tensorboardX. Most use cases can be handled by the default trainer itself, but if you need to subclass the trainer for any reason, follow these docs closely.
Trainer
class torchgan.trainer.Trainer(models, losses_list, metrics_list=None, device=torch.device(...), ncritic=None, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)
Base class for all Trainers for various GANs.
Features provided by this base Trainer:
- Loss and metrics logging via the Logger class.
- Generation of image samples.
- Saving models at the end of every epoch and loading previously saved models.
- High flexibility: hyperparameters can be changed simply by adjusting the arguments.
Most of the functionality provided by the Trainer can be customized simply by passing different arguments. You can train anything from a simple DCGAN to complex CycleGANs without ever having to subclass this Trainer.
Parameters:
- models (dict) – A dictionary mapping the variable names that store the generator, the discriminator and any other model you might want to define to the function and arguments needed to construct each model. Refer to the examples to see how to define complex models using this API.
- losses_list (list) – A list of the loss functions that need to be minimized. For a list of pre-defined losses, see torchgan.losses. All losses in the list must be a subclass of at least GeneratorLoss or DiscriminatorLoss.
- metrics_list (list, optional) – A list of metric functions that need to be logged. For a list of pre-defined metrics, see torchgan.metrics. All metrics in the list must be a subclass of EvaluationMetric.
- device (torch.device, optional) – Device on which the operations are carried out. If you are using a CPU-only machine, make sure you change it for proper functioning.
- ncritic (int, optional) – If set, the discriminator is trained that many times more often than the generator.
- sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
- epochs (int, optional) – Total number of epochs for which the models are to be trained.
- checkpoints (str, optional) – Path prefix under which the models are saved. For example, if checkpoints is ./model/gan, the models are saved as ./model/gan0.model and so on.
- retain_checkpoints (int, optional) – Total number of checkpoints to retain. For example, if the value is set to 3, at most 3 models are saved, after which the saved models start being overwritten.
- recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
- log_dir (str, optional) – Directory for TensorBoard logs. It is ignored if TENSORBOARD_LOGGING is 0.
- test_noise (torch.Tensor, optional) – If provided then it will be used as the noise for image sampling.
- nrow (int, optional) – Number of rows in which the image is to be stored.
Any other argument that you need to store in the object can be simply passed via keyword arguments.
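The checkpoint naming and retention described for checkpoints and retain_checkpoints can be sketched as follows. This is a minimal sketch, assuming that checkpoint files are rotated by taking the epoch modulo retain_checkpoints; the helper checkpoint_path is hypothetical and not part of torchgan.

```python
def checkpoint_path(checkpoints: str, epoch: int, retain_checkpoints: int) -> str:
    """Return the file path for the checkpoint written at `epoch`.

    Hypothetical helper: assumes a ring of `retain_checkpoints` files, so once
    that many checkpoints exist, earlier files are overwritten in turn.
    """
    if retain_checkpoints < 1:
        raise ValueError("retain_checkpoints must be at least 1")
    slot = epoch % retain_checkpoints  # index of the file to (over)write
    return f"{checkpoints}{slot}.model"


# With retain_checkpoints=3, epochs 0..4 map to gan0, gan1, gan2, gan0, gan1.
paths = [checkpoint_path("./model/gan", e, 3) for e in range(5)]
print(paths)
```

Under this assumption the disk never holds more than retain_checkpoints model files, at the cost of losing older epochs.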
Example
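>>> from torch.optim import Adam
>>> from torchgan.models import DCGANGenerator, DCGANDiscriminator
>>> from torchgan.losses import MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
>>> from torchgan.trainer import Trainer
>>> dcgan = Trainer(
...     {
...         "generator": {
...             "name": DCGANGenerator,
...             "args": {"out_channels": 1, "step_channels": 16},
...             "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}},
...         },
...         "discriminator": {
...             "name": DCGANDiscriminator,
...             "args": {"in_channels": 1, "step_channels": 16},
...             "optimizer": {"var": "opt_discriminator", "name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}},
...         },
...     },
...     [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
...     sample_size=64,
...     epochs=20,
... )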
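The models dictionary in the example pairs each model with its constructor, its constructor arguments and an optimizer specification. The sketch below shows one way such a dictionary could be turned into objects. The helper build_models is hypothetical, not torchgan's internal code; it assumes that an optimizer without a "var" key is stored as optimizer_<model name>, and the stand-in classes ToyModel and ToyOptimizer replace DCGANGenerator and Adam so the snippet runs without PyTorch.

```python
def build_models(models: dict) -> dict:
    """Instantiate every model and its optimizer from a torchgan-style spec.

    Returns a flat mapping from attribute name to object, e.g.
    {"generator": <model>, "optimizer_generator": <optimizer>}.
    """
    built = {}
    for key, spec in models.items():
        model = spec["name"](**spec.get("args", {}))
        built[key] = model
        opt_spec = spec.get("optimizer")
        if opt_spec is not None:
            # Assumption: the default attribute name is "optimizer_<key>".
            opt_name = opt_spec.get("var", f"optimizer_{key}")
            built[opt_name] = opt_spec["name"](model.parameters(), **opt_spec.get("args", {}))
    return built


# Stand-ins so the sketch runs without PyTorch.
class ToyModel:
    def __init__(self, channels):
        self.channels = channels

    def parameters(self):
        return [self.channels]


class ToyOptimizer:
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr


spec = {
    "generator": {"name": ToyModel, "args": {"channels": 16},
                  "optimizer": {"name": ToyOptimizer, "args": {"lr": 0.0002}}},
    "discriminator": {"name": ToyModel, "args": {"channels": 16},
                      "optimizer": {"var": "opt_discriminator", "name": ToyOptimizer,
                                    "args": {"lr": 0.0002}}},
}
objects = build_models(spec)
print(sorted(objects))
```

Keeping constructors and arguments as data, rather than pre-built objects, is what lets a single dictionary describe anything from a DCGAN to a multi-model setup.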
complete(**kwargs)
Marks the end of training. It saves the final model and turns off the logger.
Note
Calling this function is not strictly necessary, but if it is not called the logger is kept alive in the background, so calling it is considered good practice.
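Because complete is optional but recommended, a simple way to make sure it always runs, even when training is interrupted by an exception, is a try/finally block. The context manager completing below is a hypothetical convenience wrapper, not part of torchgan; it only assumes the trainer object has a complete() method, and FakeTrainer is a stand-in.

```python
from contextlib import contextmanager


@contextmanager
def completing(trainer):
    """Yield `trainer` and call `trainer.complete()` on exit, even after an error."""
    try:
        yield trainer
    finally:
        trainer.complete()


# Stand-in trainer that records whether complete() was called.
class FakeTrainer:
    def __init__(self):
        self.completed = False

    def complete(self, **kwargs):
        self.completed = True


t = FakeTrainer()
try:
    with completing(t):
        raise KeyboardInterrupt  # simulate the user stopping training
except KeyboardInterrupt:
    pass
print(t.completed)  # True
```

With a real trainer this would read `with completing(dcgan): dcgan(data_loader)`.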
eval_ops(**kwargs)
Runs all evaluation operations at the end of every epoch by calling every metric function passed to the Trainer.
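The end-of-epoch evaluation can be pictured as a loop over the metrics that appends each result to a log. In this sketch the metric interface (a callable taking the generator and returning a float) is an assumption; MeanOutput is a toy stand-in for an EvaluationMetric and the generator is a plain function.

```python
from collections import defaultdict


class MeanOutput:
    """Toy metric: the mean of the generator's outputs (stand-in for an EvaluationMetric)."""

    def __call__(self, generator):
        samples = generator()
        return sum(samples) / len(samples)


def eval_ops(metrics_list, generator, logs):
    """Evaluate each metric once and append the result under the metric's class name."""
    for metric in metrics_list:
        logs[type(metric).__name__].append(metric(generator))
    return logs


logs = defaultdict(list)
toy_generator = lambda: [0.0, 1.0, 2.0, 3.0]
for epoch in range(2):  # eval_ops runs once at the end of every epoch
    eval_ops([MeanOutput()], toy_generator, logs)
print(dict(logs))  # {'MeanOutput': [1.5, 1.5]}
```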
load_model(load_path='', load_items=None)
Loads the model along with some necessary information. Items loaded:
- Epoch
- Model States
- Optimizer States
- Loss Information
- Loss Objects
- Metric Objects
- Loss Logs
Warning
An exception is raised if the model could not be loaded. Make sure that the model being loaded was previously saved by the torchgan Trainer itself. Loading any other form of model is currently not supported, though this may be improved in the future.
Parameters:
- load_path (str, optional) – Path from which the model is loaded.
- load_items (str, list, optional) – Variable name(s) of any other items you want to load. If an item cannot be found, a warning is raised and the model starts training from scratch, so make sure the item was saved.
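The loading behaviour described above, including the warning for a missing extra item, can be sketched with pickle standing in for torch.load. The checkpoint layout (a plain dictionary keyed by item name), the helper load_checkpoint and the State stand-in are assumptions for illustration, not torchgan's file format; only the epoch and the extra items are shown.

```python
import os
import pickle
import tempfile
import warnings


def load_checkpoint(obj, load_path, load_items=None):
    """Restore attributes of `obj` from a pickled dict; return True on success.

    Assumption: the checkpoint is a dict with an "epoch" entry plus one entry per
    saved item. If a requested item is missing, warn and leave `obj` untouched,
    so training starts from scratch.
    """
    if isinstance(load_items, str):
        load_items = [load_items]
    with open(load_path, "rb") as f:
        checkpoint = pickle.load(f)
    wanted = ["epoch"] + list(load_items or [])
    missing = [name for name in wanted if name not in checkpoint]
    if missing:
        warnings.warn(f"Items {missing} not found in {load_path}; training from scratch")
        return False
    for name in wanted:
        setattr(obj, name, checkpoint[name])
    return True


class State:  # stand-in for a Trainer
    pass


path = os.path.join(tempfile.mkdtemp(), "gan0.model")
with open(path, "wb") as f:
    pickle.dump({"epoch": 4, "loss_history": [0.9, 0.7]}, f)

s = State()
print(load_checkpoint(s, path, "loss_history"), s.epoch)  # True 4
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(load_checkpoint(State(), path, ["missing_item"]))  # False
```

Checking every requested item before assigning any of them keeps the object in a consistent state when loading fails.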
save_model(epoch, save_items=None)
Saves the model along with some necessary information. Items stored for future reference:
- Epoch
- Model States
- Optimizer States
- Loss Information
- Loss Objects
- Metric Objects
- Loss Logs
The save location is printed when this function is called.
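The save step can be sketched in the same spirit: gather the epoch and the extra save_items into one dictionary, write it to a checkpoint path and print the location. Here pickle stands in for torch.save, and save_checkpoint, the dictionary layout, the State stand-in and the modulo-based file rotation are assumptions; model and optimizer states are omitted for brevity.

```python
import os
import pickle
import tempfile


def save_checkpoint(obj, checkpoints, epoch, retain_checkpoints, save_items=None):
    """Pickle `epoch` plus the attributes named in `save_items`; return the path.

    Assumption: files rotate through `retain_checkpoints` slots, e.g. gan0.model,
    gan1.model, ... Model and optimizer states are omitted for brevity.
    """
    if isinstance(save_items, str):
        save_items = [save_items]
    payload = {"epoch": epoch}
    for name in save_items or []:
        payload[name] = getattr(obj, name)
    path = f"{checkpoints}{epoch % retain_checkpoints}.model"
    print(f"Saving model to {path}")  # the save location is printed
    with open(path, "wb") as f:
        pickle.dump(payload, f)
    return path


class State:  # stand-in for a Trainer with one extra attribute worth saving
    loss_history = [0.9, 0.7]


prefix = os.path.join(tempfile.mkdtemp(), "gan")
saved = save_checkpoint(State(), prefix, epoch=4, retain_checkpoints=3, save_items="loss_history")
with open(saved, "rb") as f:
    print(pickle.load(f))  # {'epoch': 4, 'loss_history': [0.9, 0.7]}
```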
Parameters:
train(data_loader, **kwargs)
Uses the information passed by the user when creating the object to train the model. It iterates over the epochs and the DataLoader, calling the functions that train the models and log the required variables.
Note
Even though __call__ calls this function, it is best not to call train directly. When __call__ is invoked, the batch_size is inferred from the data_loader. Also, the interface of __call__ will not change, so it gives the user a stable API, while the flow of execution of train may change in the future.
Warning
Never change this function in a subclass. It is too delicate, and changing it affects every other function in this Trainer class.
This function controls the execution of all the components of the Trainer: the logger, train_iter, save_model, eval_ops and optim_ops.
Parameters: data_loader (torch.utils.data.DataLoader) – A DataLoader for the trainer to iterate over while training the models.
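The division of labour between __call__, train and the per-epoch hooks can be sketched with a toy class. The method names mirror those listed above, but SketchTrainer and ToyLoader are hypothetical stand-ins, the bodies are placeholders, the per-epoch order of save_model, eval_ops and optim_ops is an assumption, and the calls list exists only to make the control flow visible.

```python
class SketchTrainer:
    """Toy trainer showing how __call__ hands off to train (not torchgan code)."""

    def __init__(self, epochs):
        self.epochs = epochs
        self.calls = []

    def __call__(self, data_loader, **kwargs):
        # The stable public entry point: infer batch_size, then delegate.
        self.batch_size = data_loader.batch_size
        self.train(data_loader, **kwargs)

    def train(self, data_loader, **kwargs):
        for epoch in range(self.epochs):
            for batch in data_loader:
                self.real_inputs = batch
                self.train_iter()
            self.save_model(epoch)
            self.eval_ops(**kwargs)
            self.optim_ops()

    def train_iter(self):
        self.calls.append("train_iter")

    def save_model(self, epoch):
        self.calls.append(f"save_model({epoch})")

    def eval_ops(self, **kwargs):
        self.calls.append("eval_ops")

    def optim_ops(self):
        self.calls.append("optim_ops")


class ToyLoader(list):
    """List of batches with a batch_size attribute, like a DataLoader."""

    batch_size = 2


trainer = SketchTrainer(epochs=1)
trainer(ToyLoader([[1, 2], [3, 4]]))
print(trainer.batch_size, trainer.calls)
```

Routing users through __call__ keeps the public entry point fixed while the internals of train remain free to change.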
train_iter()
Calls the train_ops of the loss functions. This is the core function of the Trainer. In most cases you will never need to extend this function; in extreme cases, simply extend train_iter_custom.
Warning
The Trainer needs this function in this exact state to work correctly, so it is highly recommended not to change it even if the Trainer is subclassed.
Returns: A 4-tuple of the generator loss, the discriminator loss, the number of times the generator was trained and the number of times the discriminator was trained.
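train_iter can be pictured as one pass over losses_list per batch: discriminator losses are applied every iteration, while generator losses are applied only every ncritic-th iteration when ncritic is set. The sketch below uses stand-in loss classes (ToyLoss, GeneratorLoss, DiscriminatorLoss) whose train_ops() returns a float; these classes, the gating rule and the iteration argument are assumptions, not torchgan's exact implementation.

```python
class ToyLoss:
    """Minimal loss stand-in: train_ops() returns a fixed float."""

    def __init__(self, value):
        self.value = value

    def train_ops(self):
        return self.value


class GeneratorLoss(ToyLoss):  # stand-in for torchgan.losses.GeneratorLoss
    pass


class DiscriminatorLoss(ToyLoss):  # stand-in for torchgan.losses.DiscriminatorLoss
    pass


def train_iter(losses_list, iteration, ncritic=None):
    """Apply each loss once and return (gen_loss, dis_loss, gen_iters, dis_iters)."""
    lgen = ldis = 0.0
    gen_iter = dis_iter = 0
    for loss in losses_list:
        if isinstance(loss, GeneratorLoss):
            # Assumption: with ncritic set, the generator steps every ncritic iterations.
            if ncritic is None or iteration % ncritic == 0:
                lgen += loss.train_ops()
                gen_iter += 1
        elif isinstance(loss, DiscriminatorLoss):
            ldis += loss.train_ops()
            dis_iter += 1
    return lgen, ldis, gen_iter, dis_iter


losses = [GeneratorLoss(0.5), DiscriminatorLoss(0.25)]
totals = [train_iter(losses, i, ncritic=5) for i in range(10)]
print(sum(t[2] for t in totals), sum(t[3] for t in totals))  # 2 10
```

Over 10 iterations with ncritic=5 the discriminator is trained five times as often as the generator, matching the description of ncritic.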