jaw.utils package
Submodules
jaw.utils.computation module
- get_device()
Wrap the PyTorch device selection.
- Returns:
a torch device for GPU computation if CUDA is available, otherwise for CPU computation.
- Return type:
torch.device.
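A minimal sketch of what such a wrapper usually looks like, assuming it relies only on `torch.cuda.is_available()`. This is an illustration, not the confirmed jaw implementation.

```python
import torch


def get_device() -> torch.device:
    """Return the CUDA device when a GPU is usable, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

Typical usage is to move the model and each batch to the same device, e.g. `model.to(get_device())`.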
jaw.utils.progress_bar module
Simple progress bar. All credit goes to kuangliu: https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py.
- format_time(seconds)
Convert seconds into days, hours or minutes if needed, and format the time using the units days, hours, minutes, seconds and milliseconds.
- Parameters:
seconds (int) – Seconds elapsed since training started.
- Returns:
the formatted elapsed time as a string.
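A minimal sketch of such a formatter, assuming a compact output that keeps only the two most significant non-zero units with the suffixes `D`, `h`, `m`, `s` and `ms`. The suffixes and the two-unit cut-off are assumptions for illustration, not confirmed jaw behavior.

```python
def format_time(seconds: float) -> str:
    """Format a duration using at most two units among D, h, m, s, ms."""
    # Work in integer milliseconds so the divisions below are exact.
    millis = int(round(seconds * 1000))
    days, millis = divmod(millis, 24 * 3600 * 1000)
    hours, millis = divmod(millis, 3600 * 1000)
    minutes, millis = divmod(millis, 60 * 1000)
    secs, millis = divmod(millis, 1000)

    units = [(days, "D"), (hours, "h"), (minutes, "m"), (secs, "s"), (millis, "ms")]
    # Keep only the two most significant non-zero units.
    parts = [f"{value}{suffix}" for value, suffix in units if value > 0][:2]
    return "".join(parts) or "0ms"
```

For example, `format_time(3661)` gives `"1h1m"`: the remaining second is dropped because only two units are kept.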
- progress_bar(current, total, msg=None)
Print a progress bar.
Note
Call this function once per batch inside the loop over a data loader, e.g. in a training function. Example:

    from jaw.utils.progress_bar import progress_bar

    def train(model, loader, f_loss, optimizer, device):
        ...
        N, tot_loss, correct = 0, 0.0, 0  # running counters over the epoch
        for i, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = f_loss(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate the batch-size-weighted loss and the correct predictions
            N += inputs.shape[0]
            tot_loss += inputs.shape[0] * loss.item()
            predicted_targets = outputs.argmax(dim=1)
            correct += (predicted_targets == targets).sum().item()
            # tot_loss is a weighted sum, so divide by N for the running mean
            progress_bar(i, len(loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (tot_loss / N, 100. * correct / N, correct, N))
        return tot_loss / N, correct / N
- Parameters:
current (int) – Index of the current step, e.g. the batch index in the loop above.
total (int) – Total number of steps, e.g. len(loader).
msg (str, optional) – Message to print next to the progress bar.
- Returns:
the formatted elapsed time as a string.
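A minimal sketch of a single-line text progress bar, written from scratch rather than copied from the kuangliu code. The bar width `BAR_LENGTH`, the line layout, and the extra `stream` parameter (added only so the output can be captured) are hypothetical.

```python
import sys
import time
from typing import Optional, TextIO

BAR_LENGTH = 30  # hypothetical bar width in characters

_start_time = time.time()


def progress_bar(current: int, total: int, msg: Optional[str] = None,
                 stream: TextIO = sys.stdout) -> None:
    """Redraw a one-line bar for step `current` (0-based) out of `total`."""
    global _start_time
    if current == 0:
        _start_time = time.time()  # restart the clock at the start of each pass
    done = int(BAR_LENGTH * (current + 1) / total)
    bar = "=" * done + "." * (BAR_LENGTH - done)
    elapsed = time.time() - _start_time
    line = f" [{bar}] {current + 1}/{total} | {elapsed:.1f}s"
    if msg:
        line += " | " + msg
    # The carriage return redraws the same line; end with a newline on the last step.
    stream.write("\r" + line)
    if current + 1 >= total:
        stream.write("\n")
    stream.flush()
```

Using `\r` instead of a newline keeps the terminal output to one line per epoch.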
jaw.utils.tracking module
- class ModelCheckpoint(filepath, model)
Utility class for saving a model whenever training produces a new best model.
- update(loss)
Check whether the current model is better than the previously saved one, i.e. whether its validation loss is lower. If so, overwrite the saved best model with the current one.
- Parameters:
loss (float) – Validation loss of the current model.
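A minimal sketch of this checkpointing logic, assuming the model is saved with `torch.save` on its `state_dict()` and that "better" means a strictly lower validation loss. The `min_loss` attribute is hypothetical.

```python
import torch
from torch import nn


class ModelCheckpoint:
    """Save `model` to `filepath` each time a new lowest validation loss is seen."""

    def __init__(self, filepath: str, model: nn.Module) -> None:
        self.filepath = filepath
        self.model = model
        self.min_loss = None  # hypothetical attribute: best loss seen so far

    def update(self, loss: float) -> None:
        # The first call always saves; later calls save only on improvement.
        if self.min_loss is None or loss < self.min_loss:
            torch.save(self.model.state_dict(), self.filepath)
            self.min_loss = loss
```

Call `update(val_loss)` once per epoch; the best weights can later be restored with `model.load_state_dict(torch.load(filepath))`.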
- generate_unique_logpath(logdir, prefix)
Generate a unique log path for saving a new model.
- Parameters:
logdir (str) – Path of the directory where the logs will be saved.
prefix (str) – Prefix of the log file name.
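A minimal sketch of one way to build such a path, assuming an incrementing numeric suffix (`prefix_0`, `prefix_1`, ...) appended until the path does not exist yet. The suffix scheme is an assumption, not confirmed jaw behavior.

```python
import os


def generate_unique_logpath(logdir: str, prefix: str) -> str:
    """Return logdir/prefix_<i> for the smallest i not already taken."""
    i = 0
    while True:
        candidate = os.path.join(logdir, f"{prefix}_{i}")
        if not os.path.exists(candidate):
            return candidate
        i += 1
```

Because the function only computes a name, the caller still has to create the file or directory (for instance with `os.makedirs`) before the next call will move on to the next suffix.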