accelerators¶
- Accelerator: The Accelerator base class for Lightning PyTorch.
- CPUAccelerator: Accelerator for CPU devices.
- CUDAAccelerator: Accelerator for NVIDIA CUDA devices.
- XLAAccelerator: Accelerator for XLA devices, normally TPUs.
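In typical use the accelerator is selected through the Trainer rather than instantiated directly. A minimal sketch, assuming a hypothetical LightningModule subclass MyModel defined elsewhere:

```python
from lightning.pytorch import Trainer

# "auto" resolves to CUDA, XLA/TPU, or CPU depending on what is available;
# the explicit strings "cpu", "gpu", and "tpu" map to the accelerator classes above.
trainer = Trainer(accelerator="auto", devices=1)
# trainer.fit(MyModel())  # MyModel: hypothetical LightningModule
```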
callbacks¶
- BackboneFinetuning: Finetune a backbone model based on a user-defined learning rate schedule.
- BaseFinetuning: This class implements the base logic for writing your own finetuning callback.
- BasePredictionWriter: Base class to implement how the predictions should be stored.
- BatchSizeFinder: Finds the largest batch size supported by a given model before encountering an out-of-memory (OOM) error.
- Callback: Abstract base class used to build new callbacks.
- DeviceStatsMonitor: Automatically monitors and logs device stats during the training, validation, and test stages.
- EarlyStopping: Monitor a metric and stop training when it stops improving.
- GradientAccumulationScheduler: Change the gradient accumulation factor according to a schedule.
- LambdaCallback: Create a simple callback on the fly using lambda functions.
- LearningRateFinder: Performs a learning rate range test to reduce the guesswork in picking a good starting learning rate.
- LearningRateMonitor: Automatically monitors and logs the learning rate of learning rate schedulers during training.
- ModelCheckpoint: Save the model periodically by monitoring a quantity.
- ModelPruning: Model pruning callback, using PyTorch's prune utilities.
- ModelSummary: Generates a summary of all layers in a LightningModule.
- OnExceptionCheckpoint: Used to save a checkpoint on exception.
- ProgressBar: The base class for progress bars in Lightning.
- RichModelSummary: Generates a summary of all layers in a LightningModule with rich text formatting.
- RichProgressBar: Create a progress bar with rich text formatting.
- StochasticWeightAveraging: Implements the Stochastic Weight Averaging (SWA) callback to average a model.
- ThroughputMonitor: Computes and logs throughput with the Throughput utility.
- Timer: Tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.
- TQDMProgressBar: This is the default progress bar used by Lightning.
- WeightAveraging: A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) after each training step.
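Callbacks are composed by passing a list to the Trainer. A minimal sketch; the monitored metric name "val_loss" and the MyModel/MyDataModule classes are assumptions, not part of this reference:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # stop when the monitored metric has not improved for 3 validation rounds
    EarlyStopping(monitor="val_loss", mode="min", patience=3),
    # keep only the best checkpoint according to the same metric
    ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1),
]
trainer = Trainer(max_epochs=20, callbacks=callbacks)
# trainer.fit(MyModel(), datamodule=MyDataModule())  # hypothetical classes
```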
cli¶
- LightningCLI: Implementation of a configurable command-line tool for pytorch-lightning.
- LightningArgumentParser: Extension of jsonargparse's ArgumentParser for pytorch-lightning.
- SaveConfigCallback: Saves a LightningCLI config to the log_dir when training starts.
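A minimal sketch of a LightningCLI entry point; MyModel, MyDataModule, the file name main.py, and the model.learning_rate option are hypothetical:

```python
# main.py (hypothetical entry point)
from lightning.pytorch.cli import LightningCLI


def cli_main():
    # Generates fit/validate/test/predict subcommands and builds CLI options
    # from the __init__ signatures of the given model and datamodule classes.
    LightningCLI(MyModel, MyDataModule)  # hypothetical classes defined in this project


if __name__ == "__main__":
    cli_main()
    # e.g. python main.py fit --trainer.max_epochs=10 --model.learning_rate=0.01
```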
core¶
- CheckpointHooks: Hooks to be used with checkpointing.
- DataHooks: Hooks to be used for data-related logic.
- ModelHooks: Hooks to be used in a LightningModule.
- LightningDataModule: A DataModule standardizes the training, validation, and test splits, data preparation, and transforms.
- LightningOptimizer: Wraps the user's optimizer and correctly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.
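A minimal sketch of a LightningDataModule, using random tensors as a stand-in dataset; every name below is hypothetical:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningDataModule


class RandomDataModule(LightningDataModule):
    """Hypothetical datamodule: random tensors stand in for a real dataset."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage: str):
        # called on every process; build the splits here
        self.train_set = TensorDataset(torch.randn(256, 32), torch.randn(256, 1))
        self.val_set = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)
```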
loggers¶
- Logger: Abstract base class used to build new loggers.
- CometLogger: Comet logger.
- CSVLogger: CSV logger.
- MLFlowLogger: MLflow logger.
- NeptuneLogger: Neptune logger.
- TensorBoardLogger: TensorBoard logger.
- WandbLogger: Weights and Biases logger.
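Loggers are attached through the Trainer and can be combined. A minimal sketch with two of the built-in loggers; the directory and run names are arbitrary:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger

# Metrics logged via self.log() in the LightningModule are sent to every
# attached logger.
loggers = [
    CSVLogger(save_dir="logs", name="csv_run"),         # plain .csv metrics files
    TensorBoardLogger(save_dir="logs", name="tb_run"),  # TensorBoard event files
]
trainer = Trainer(logger=loggers, max_epochs=5)
```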
plugins¶
precision¶
- DeepSpeedPrecision: Precision plugin for DeepSpeed integration.
- DoublePrecision: Plugin for training with double (torch.float64) precision.
- HalfPrecision: Plugin for training with half precision.
- FSDPPrecision: Precision plugin for training with Fully Sharded Data Parallel (FSDP).
- MixedPrecision: Plugin for Automatic Mixed Precision (AMP) training with torch.autocast.
- Precision: Base class for all plugins handling the precision-specific parts of the training.
- XLAPrecision: Plugin for training with XLA.
- TransformerEnginePrecision: Plugin for training with fp8 precision via NVIDIA's Transformer Engine.
- BitsandbytesPrecision: Plugin for quantizing weights with bitsandbytes.
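A precision plugin is normally selected through the Trainer's precision flag rather than constructed by hand; a minimal sketch:

```python
from lightning.pytorch import Trainer

# The string is translated into one of the plugin classes listed above.
trainer = Trainer(precision="bf16-mixed")  # MixedPrecision with bfloat16 autocast
# Trainer(precision="64-true")             # DoublePrecision
# Trainer(precision="16-true")             # HalfPrecision
```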
environments¶
- ClusterEnvironment: Specification of a cluster environment.
- KubeflowEnvironment: Environment for distributed training using the PyTorchJob operator from Kubeflow.
- LightningEnvironment: The default environment used by Lightning for a single node or free cluster (not managed).
- LSFEnvironment: An environment for running on clusters managed by the LSF resource manager.
- MPIEnvironment: An environment for running on clusters with processes created through MPI.
- SLURMEnvironment: Cluster environment for training on a cluster managed by SLURM.
- TorchElasticEnvironment: Environment for fault-tolerant and elastic training with torchelastic.
- XLAEnvironment: Cluster environment for training on a TPU Pod with the PyTorch/XLA library.
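Cluster environments are usually detected automatically; passing one as a plugin overrides the detection. A minimal sketch for a SLURM cluster (the device and node counts are illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.environments import SLURMEnvironment

# Disable SLURM's automatic requeue on preemption instead of the detected default.
trainer = Trainer(
    accelerator="gpu",
    devices=4,
    num_nodes=2,
    plugins=[SLURMEnvironment(auto_requeue=False)],
)
```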
io¶
- AsyncCheckpointIO: Enables saving checkpoints asynchronously in a thread.
- CheckpointIO: Interface to save/load checkpoints as they are saved through the Strategy.
- TorchCheckpointIO: CheckpointIO that utilizes torch.save() and torch.load() to save and load checkpoints respectively, common for most use cases.
- XLACheckpointIO: CheckpointIO that utilizes xm.save() to save checkpoints for TPU training strategies.
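A checkpoint IO plugin is passed to the Trainer like any other plugin. A minimal sketch that wraps the default torch-based IO so checkpoint writes happen in a background thread:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO

# Checkpoint writes are moved to a background thread instead of blocking training.
trainer = Trainer(plugins=[AsyncCheckpointIO()])
```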
others¶
- LayerSync: Abstract base class for creating plugins that wrap layers of a model with synchronization logic for multiprocessing.
- TorchSyncBatchNorm: A plugin that wraps all batch normalization layers of a model with synchronization logic for multiprocessing.
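The TorchSyncBatchNorm plugin is normally enabled through the Trainer's sync_batchnorm flag rather than instantiated directly; a minimal sketch:

```python
from lightning.pytorch import Trainer

# With sync_batchnorm=True the Trainer converts BatchNorm layers to their
# synchronized counterparts when training on multiple devices.
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp", sync_batchnorm=True)
```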
profiler¶
- AdvancedProfiler: This profiler uses Python's cProfile to record more detailed information about time spent in each function call recorded during a given action.
- PassThroughProfiler: This class should be used when you don't want the (small) overhead of profiling.
- Profiler: If you wish to write a custom profiler, you should inherit from this class.
- PyTorchProfiler: This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model, both on the CPU and GPU.
- SimpleProfiler: This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action and the total time spent over the entire training run.
- XLAProfiler: The XLA Profiler helps you debug and optimize training workload performance for your models using Cloud TPU performance tools.
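A profiler is attached through the Trainer, either by shorthand string or by instance. A minimal sketch; the output directory and file name are arbitrary:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.profilers import AdvancedProfiler

# Shorthand strings are also accepted, e.g. Trainer(profiler="simple").
profiler = AdvancedProfiler(dirpath="profiling", filename="perf_report")
trainer = Trainer(profiler=profiler, max_epochs=1)
```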
trainer¶
- Trainer: Customize every aspect of training via flags.
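A minimal sketch of constructing and running a Trainer; MyModel and MyDataModule are hypothetical LightningModule/LightningDataModule subclasses:

```python
from lightning.pytorch import Trainer

trainer = Trainer(
    max_epochs=10,        # stop after 10 epochs
    accelerator="auto",   # pick CUDA/TPU/CPU automatically
    devices=1,
    log_every_n_steps=50,
)
# trainer.fit(MyModel(), datamodule=MyDataModule())  # hypothetical classes
```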
strategies¶
- DDPStrategy: Strategy for multi-process single-device training on one or multiple nodes.
- DeepSpeedStrategy: Provides capabilities to run training using the DeepSpeed library, with training optimizations for large, billion-parameter models.
- FSDPStrategy: Strategy for Fully Sharded Data Parallel training provided by torch.distributed.
- ModelParallelStrategy: Enables user-defined parallelism applied to a model.
- ParallelStrategy: Strategy for training with multiple processes in parallel.
- SingleDeviceStrategy: Strategy that handles communication on a single device.
- SingleDeviceXLAStrategy: Strategy for training on a single XLA device.
- Strategy: Base class for all strategies that change the behaviour of the training, validation, and test loops.
- XLAStrategy: Strategy for training multiple TPU devices using the torch_xla.distributed.xla_multiprocessing.spawn() method.
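The strategy is usually selected with a string such as strategy="ddp"; instantiating the class allows extra arguments to be passed through. A minimal sketch (the device count is illustrative):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=4,
    # extra arguments are forwarded to torch's DistributedDataParallel
    strategy=DDPStrategy(find_unused_parameters=False),
)
```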
tuner¶
- Tuner: Tuner class to tune your model.
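A minimal sketch of running the Tuner before training; MyModel is a hypothetical LightningModule:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

trainer = Trainer(max_epochs=5)
tuner = Tuner(trainer)
# model = MyModel()                            # hypothetical LightningModule
# tuner.scale_batch_size(model, mode="power")  # largest batch size that fits in memory
# lr_finder = tuner.lr_find(model)             # learning rate range test
# print(lr_finder.suggestion())
```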
utilities¶
- deepspeed: Utilities that can be used with DeepSpeed.
- memory: Utilities related to memory.
- parsing: Utilities used for parameter parsing.
- rank_zero: Utilities that can be used for calling functions on a particular rank.
- seed: Utilities to help with reproducibility of models.
- warnings: Warning-related utilities.
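A minimal sketch using the seed and rank_zero utilities:

```python
from lightning.pytorch import seed_everything
from lightning.pytorch.utilities import rank_zero_info

# Seed Python, NumPy, and PyTorch RNGs; workers=True also seeds dataloader workers.
seed_everything(42, workers=True)

# Printed once from rank 0 in a multi-process run.
rank_zero_info("starting training")
```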
- lightning.pytorch.utilities.measure_flops(model, forward_fn, loss_fn=None)¶
 Utility to compute the total number of FLOPs used by a module during training or during inference.
It’s recommended to create a meta-device model for this:
Example:

```python
with torch.device("meta"):
    model = MyModel()
    x = torch.randn(2, 32)

model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)

model_loss = lambda y: y.sum()
fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
```
- Parameters:
  - model: The model whose FLOPs should be measured.
  - forward_fn: A function that runs a forward pass of the model.
  - loss_fn: An optional function that computes a loss from the output of forward_fn; if provided, the backward pass is included in the count.
- Return type:
  - int