example.fashionMNIST package

Submodules

example.fashionMNIST.simple_MLP_classifier module

class SimpleMLPClassifier(model, loss, train_loader, val_loader, test_loader, train_func, eval_func)

Typical example implementation of a JAWTrainer. We want to train a small classifier on the Fashion-MNIST dataset (https://github.com/zalandoresearch/fashion-mnist).

Here, we want to be able to launch a training run with our previously written classes (data loaders, models, losses, training and evaluation processes).

Note

Here we only implement the methods required by a JAWTrainer, but you are free to declare new parameters.

All parameters that are unlikely to change between two training runs are registered inside the constructor.
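A minimal sketch of what such a constructor might look like, assuming it simply stores its arguments as attributes. The real class derives from JAWTrainer, whose base-class API is not documented here, so the base class is omitted and the name `SimpleMLPClassifierSketch` is hypothetical.

```python
from typing import Any, Callable, Iterable


class SimpleMLPClassifierSketch:
    """Hypothetical stand-in for SimpleMLPClassifier (JAWTrainer base class omitted)."""

    def __init__(
        self,
        model: Any,
        loss: Callable,
        train_loader: Iterable,
        val_loader: Iterable,
        test_loader: Iterable,
        train_func: Callable,
        eval_func: Callable,
    ) -> None:
        # Everything that should stay fixed between two training runs is
        # registered once, here in the constructor.
        self.model = model
        self.loss = loss
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.train_func = train_func
        self.eval_func = eval_func
```

Run-specific settings (epochs, device, logdir, prefix) are left out of the constructor and passed to launch_training instead.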

Note

Please note that the literature sometimes places the test process before validation, but the validation process usually precedes the test one. The same method is also generally used for both the validation and the testing process.
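Because validation and testing usually share one method, a single evaluation function can serve both loops; only the loader changes. The sketch below is a hypothetical eval_func, assuming the signature `(model, loss, loader, device)` and batches of `(inputs, targets)` pairs. It is written without PyTorch so it runs on its own; the comments mark where a PyTorch version would differ.

```python
from typing import Any, Callable, Iterable, Tuple


def evaluate(
    model: Any,
    loss: Callable,
    loader: Iterable[Tuple[Any, Any]],
    device: Any = None,
) -> float:
    """Hypothetical eval_func: mean loss of `model` over every batch in `loader`."""
    # Switch to inference mode when the model supports it (torch.nn.Module.eval()).
    if hasattr(model, "eval"):
        model.eval()
    total, batches = 0.0, 0
    # With PyTorch, wrap this loop in `with torch.no_grad():` to skip gradient tracking.
    for inputs, targets in loader:
        # With PyTorch, move the batch first: inputs, targets = inputs.to(device), targets.to(device)
        total += float(loss(model(inputs), targets))
        batches += 1
    return total / batches if batches else 0.0
```

The same function is then called twice, e.g. `evaluate(model, loss, val_loader, device)` during training and `evaluate(model, loss, test_loader, device)` once training is over.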

Parameters:
  • model (torch.nn.Module) – Model to train.

  • loss (torch.nn.Module) – Loss used to train and evaluate the model.

  • train_loader (torch.utils.data.DataLoader) – DataLoader used to load the training data.

  • val_loader (torch.utils.data.DataLoader) – DataLoader used to load the validation data, which is evaluated during training.

  • test_loader (torch.utils.data.DataLoader) – DataLoader used to load the test data, which is used for the final evaluation of the model.

  • train_func (FunctionType) – Function that handles the training loop.

  • eval_func (FunctionType) – Function that handles the evaluation loop (validation and test).

launch_training(epochs, device, logdir, prefix)

Method where you define your custom training workflow.

Parameters:
  • epochs (int) – Total number of training epochs.

  • device (torch.device) – PyTorch device used for this training.

  • logdir (str) – Name of the directory where your models and training info will be saved.

  • prefix (str) – Prefix of the training saving directory.

Return type:

None

Returns:

None.
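A minimal sketch of one possible training workflow: train and validate every epoch, test once at the end, then save the training info. It is written as a free function so it runs on its own; in the real class the first seven arguments would come from the attributes set in the constructor. The name `launch_training_sketch`, the `(model, loss, loader, device)` signature of train_func and eval_func, and the layout of the saving directory (`logdir/prefix/history.json`) are all assumptions.

```python
import json
import os
from typing import Any, Callable, Dict, Iterable, List


def launch_training_sketch(
    model: Any,
    loss: Callable,
    train_loader: Iterable,
    val_loader: Iterable,
    test_loader: Iterable,
    train_func: Callable,
    eval_func: Callable,
    epochs: int,
    device: Any,
    logdir: str,
    prefix: str,
) -> str:
    """Hypothetical workflow; returns the saving directory."""
    # Assumption: the saving directory is named after `prefix` inside `logdir`.
    save_dir = os.path.join(logdir, prefix)
    os.makedirs(save_dir, exist_ok=True)

    history: List[Dict[str, Any]] = []
    for epoch in range(1, epochs + 1):
        # Assumed signature for both functions: (model, loss, loader, device) -> float.
        train_loss = train_func(model, loss, train_loader, device)
        val_loss = eval_func(model, loss, val_loader, device)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})

    # The test set is evaluated once, after training, with the same eval_func.
    test_loss = eval_func(model, loss, test_loader, device)

    # A PyTorch version would also save the weights here, e.g. with
    # torch.save(model.state_dict(), os.path.join(save_dir, "model.pt")).
    with open(os.path.join(save_dir, "history.json"), "w") as f:
        json.dump({"history": history, "test_loss": test_loss}, f)
    return save_dir
```

Keeping the per-epoch validation separate from the single final test mirrors the usual order described in the notes above.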

main(args)

Training definition. This is where we pass our custom classes to the previously written launch_training() method.

Parameters:

args (dict) – Arguments given on the command line.

Return type:

None

Returns:

None.
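Since main() receives the command-line arguments as a dict, one way to produce that dict is argparse followed by vars(). The sketch below assumes hypothetical flag names that mirror the launch_training parameters; the real script may use different ones.

```python
import argparse
from typing import Any, Dict, List, Optional


def parse_args(argv: Optional[List[str]] = None) -> Dict[str, Any]:
    """Hypothetical helper: turn command-line flags into the `args` dict for main()."""
    parser = argparse.ArgumentParser(description="Train a small Fashion-MNIST MLP classifier.")
    # Flag names are assumptions chosen to mirror launch_training's parameters.
    parser.add_argument("--epochs", type=int, default=10, help="Total number of training epochs.")
    parser.add_argument("--device", default="cpu", help='Device name, e.g. "cpu" or "cuda".')
    parser.add_argument("--logdir", default="logs", help="Directory where models and logs are saved.")
    parser.add_argument("--prefix", default="simple_mlp", help="Prefix of the saving directory.")
    # vars() converts the argparse.Namespace into a plain dict.
    return vars(parser.parse_args(argv))
```

The script's entry point would then call `main(parse_args())`, and main() would turn `args["device"]` into a `torch.device` before passing it to launch_training.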