watcher.training package

watcher.training.train_watcher(dataset_dir: str, output_dir: str, max_gpus: int = 4, embedding_dim: int = 1024, num_layers: int = 16, num_heads: int = 16, ff_hidden_dim: int = 3072, dropout_rate: float = 0.1, total_epochs: int = 100, weight_decay: float = 0.01, max_lr: float = 0.0001, min_lr: float = 1e-05, lr_scheduler_enabled: bool = True, lr_warmup: float = 0.01, lr_decay: float = 0.9, batch_size_step_scale: int = 8, batch_size_active_phase: float = 0.03, max_batch_size: int = 64, min_batch_size: int = 16, batch_schedule_enabled: bool = True, dataloader_workers: int = 2, checkpoint_interval: int = 5, validation_interval: int = 5, early_stopping_epochs: int = 5, snapshot_path: str | None = None, initial_weight_path: str | None = None, precision: Literal['float32', 'float16', 'bfloat16'] = 'bfloat16', update: bool = False, restart_limit: int = 10, debug: bool = False, debug_chunks: int = 10, trainer_class: object | None = None) → str

Train a Watcher model.

Training runs until total_epochs is reached or until early stopping is triggered because the validation loss has not improved for early_stopping_epochs consecutive epochs.
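The early-stopping rule can be pictured as a patience counter over recorded validation losses. The sketch below is a hypothetical illustration, not watcher's implementation: the helper name should_stop_early is invented, and it assumes patience is measured in epochs (not in validation runs, which happen only every validation_interval epochs) and that any strictly lower loss counts as an improvement.

```python
from typing import Sequence, Tuple


def should_stop_early(history: Sequence[Tuple[int, float]], patience_epochs: int) -> bool:
    """Return True if validation loss has not improved for `patience_epochs` epochs.

    `history` holds (epoch, validation_loss) pairs in chronological order.
    Hypothetical helper: not part of the watcher package.
    """
    if not history:
        return False
    best_loss = float("inf")
    best_epoch = history[0][0]
    for epoch, loss in history:
        if loss < best_loss:  # a strict improvement resets the patience window
            best_loss = loss
            best_epoch = epoch
    last_epoch = history[-1][0]
    return last_epoch - best_epoch >= patience_epochs


# Validation every 5 epochs; the best loss so far was recorded at epoch 10.
history = [(5, 2.10), (10, 1.95), (15, 1.97), (20, 1.96)]
print(should_stop_early(history, patience_epochs=10))
```

Tracking the epoch of the best loss, rather than counting validation runs, keeps the rule independent of validation_interval under this assumption.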

Warning

  • Training requires GPU devices with large memory capacity. We recommend using NVIDIA A100 or newer GPUs.

  • If training fails, the cause may be an OutOfMemoryError. In that case, reduce the batch size (max_batch_size and min_batch_size) or use a smaller model configuration (for example, a smaller embedding_dim, num_layers, or ff_hidden_dim).
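To gauge how much a smaller model configuration saves, a rough parameter count helps. The sketch below assumes a standard transformer block (Q/K/V/output projections and a two-layer feedforward network, all with biases, plus two layer norms) and ignores embedding and output layers; the actual Watcher architecture may differ, so treat the result as an estimate. The function name estimate_block_params is hypothetical. Under this assumption num_heads does not change the count, because the heads split embedding_dim rather than adding weights.

```python
def estimate_block_params(embedding_dim: int, num_layers: int, ff_hidden_dim: int) -> int:
    """Rough parameter count for a stack of standard transformer blocks.

    Assumes Q/K/V/output projections with biases, a two-layer feedforward
    network with biases, and two LayerNorms (scale + shift) per block.
    Embedding and output layers are excluded. Hypothetical helper.
    """
    d, d_ff = embedding_dim, ff_hidden_dim
    attention = 4 * (d * d + d)                    # Q, K, V and output projections
    feedforward = d * d_ff + d_ff + d_ff * d + d   # up- and down-projection
    layer_norms = 2 * (2 * d)                      # two LayerNorms per block
    return num_layers * (attention + feedforward + layer_norms)


print(estimate_block_params(1024, 16, 3072))  # 167968768
```

With the default configuration this gives about 168M parameters before embeddings; halving num_layers halves the estimate. Activation memory, which also grows with batch size, is not covered by this count.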

Note

  • Training may take from several hours to several days, depending on dataset size and hardware specifications.

  • For reference, in an experiment with approximately 370,000 patients using four NVIDIA A100 80GB GPUs, training completed in about 48 hours.

Example

Pretraining

from watcher.training import train_watcher

best_weights_path = train_watcher(
    dataset_dir="/path/to/prepared_dataset",
    output_dir="/path/to/save_training_outputs",
    max_gpus=4,
    embedding_dim=1024,
    num_layers=16,
    num_heads=16,
    ff_hidden_dim=3072,
    dropout_rate=0.1,
    total_epochs=100,
    weight_decay=0.01,
    max_lr=1e-4,
    min_lr=1e-5,
    lr_scheduler_enabled=True,
    lr_warmup=0.01,
    lr_decay=0.9,
    batch_size_step_scale=8,
    batch_size_active_phase=0.03,
    max_batch_size=64,
    min_batch_size=16,
    batch_schedule_enabled=True,
    dataloader_workers=2,
    checkpoint_interval=5,
    validation_interval=5,
    early_stopping_epochs=10,
    snapshot_path=None,
    initial_weight_path=None,
    precision="bfloat16",
    restart_limit=10,
)

print(f"Best model weights saved at: {best_weights_path}")
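The pretraining example enables the learning-rate scheduler. Beyond the parameter descriptions, the exact schedule shape is not documented here, so the sketch below is an assumption: a linear warm-up from min_lr to max_lr over the first lr_warmup fraction of training, a cosine decay back to min_lr over the next lr_decay fraction, and a constant min_lr afterward. The function lr_at is hypothetical; only the disabled case (constant min_lr) follows documented behavior.

```python
import math


def lr_at(progress: float, max_lr: float = 1e-4, min_lr: float = 1e-5,
          lr_warmup: float = 0.01, lr_decay: float = 0.9,
          enabled: bool = True) -> float:
    """Learning rate at `progress` (fraction of training completed, 0.0 to 1.0).

    Hypothetical schedule: linear warm-up, cosine decay, then a constant floor.
    """
    if not enabled:
        return min_lr  # documented: min_lr is used throughout when disabled
    if progress < lr_warmup:
        # Linear ramp from min_lr up to max_lr.
        return min_lr + (max_lr - min_lr) * progress / lr_warmup
    if progress < lr_warmup + lr_decay:
        # Cosine decay from max_lr down to min_lr (never reached if lr_decay == 0).
        decay_progress = (progress - lr_warmup) / lr_decay
        cosine = 0.5 * (1.0 + math.cos(math.pi * decay_progress))
        return min_lr + (max_lr - min_lr) * cosine
    return min_lr


for p in (0.0, 0.005, 0.01, 0.46, 0.95):
    print(f"progress={p:.3f}  lr={lr_at(p):.2e}")
```

With the example values (lr_warmup=0.01, lr_decay=0.9) the rate peaks at max_lr early and settles at min_lr for the last 9% of training under this assumed shape.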

Fine-tuning

from watcher.training import train_watcher

best_weights_path = train_watcher(
    dataset_dir="/path/to/prepared_dataset",  # Use the same dataset as pretraining
    output_dir="/path/to/save_finetuning_outputs",
    max_gpus=4,
    embedding_dim=1024,  # Use the same model hyperparameters as pretraining
    num_layers=16,       # Must match pretraining
    num_heads=16,        # Must match pretraining
    ff_hidden_dim=3072,  # Must match pretraining
    dropout_rate=0.1,
    total_epochs=20,
    weight_decay=0.01,
    max_lr=1e-5,
    min_lr=1e-5,
    lr_scheduler_enabled=False,  # Constant LR = min_lr
    batch_schedule_enabled=False,  # Constant batch size
    max_batch_size=64,
    min_batch_size=64,
    dataloader_workers=2,
    checkpoint_interval=5,
    validation_interval=5,
    early_stopping_epochs=5,
    initial_weight_path="/path/to/pretrained_model.pt",  # Set pretrained weight path
    precision="bfloat16",
    update=True,  # Flag to use fine-tuning dataset
    restart_limit=10,
)

print(f"Best model weights saved at: {best_weights_path}")
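Because embedding_dim, num_layers, num_heads, and ff_hidden_dim must match pretraining, it can help to derive the fine-tuning arguments from the pretraining arguments instead of retyping them. The helper below is a hypothetical convenience (finetune_kwargs and ARCHITECTURE_KEYS are not part of watcher): it copies the pretraining configuration, switches to fine-tuning mode with update=True and initial_weight_path, and rejects overrides that would change the architecture.

```python
from typing import Any, Dict

# Hyperparameters that must be identical between pretraining and fine-tuning.
ARCHITECTURE_KEYS = ("embedding_dim", "num_layers", "num_heads", "ff_hidden_dim")


def finetune_kwargs(pretrain_kwargs: Dict[str, Any], pretrained_weights: str,
                    output_dir: str, **overrides: Any) -> Dict[str, Any]:
    """Build train_watcher keyword arguments for fine-tuning.

    Hypothetical helper (not part of watcher). `pretrain_kwargs` must spell out
    every architecture key so that mismatches can be detected.
    """
    missing = [k for k in ARCHITECTURE_KEYS if k not in pretrain_kwargs]
    if missing:
        raise ValueError(f"Pretraining kwargs lack architecture keys: {missing}")
    changed = [k for k in ARCHITECTURE_KEYS
               if k in overrides and overrides[k] != pretrain_kwargs[k]]
    if changed:
        raise ValueError(f"Architecture must match pretraining: {changed}")

    kwargs = dict(pretrain_kwargs)  # copy, so the pretraining dict is untouched
    kwargs.update(
        output_dir=output_dir,
        initial_weight_path=pretrained_weights,  # required when update=True
        update=True,
        snapshot_path=None,  # start a fresh run rather than resuming pretraining
    )
    kwargs.update(overrides)
    return kwargs


pretrain = {
    "dataset_dir": "/path/to/prepared_dataset",
    "output_dir": "/path/to/save_training_outputs",
    "embedding_dim": 1024,
    "num_layers": 16,
    "num_heads": 16,
    "ff_hidden_dim": 3072,
    "max_lr": 1e-4,
    "min_lr": 1e-5,
}
finetune = finetune_kwargs(
    pretrain,
    pretrained_weights="/path/to/pretrained_model.pt",
    output_dir="/path/to/save_finetuning_outputs",
    total_epochs=20,
    max_lr=1e-5,
    lr_scheduler_enabled=False,
    batch_schedule_enabled=False,
    min_batch_size=64,
    max_batch_size=64,
)
# best_weights_path = train_watcher(**finetune)
```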

The following directory structure is created during training:

output_dir
├── main_training_report.json     # Training summary
├── profiling/
│   └── ...
├── snapshots/
│   ├── epoch_0/
│   │   ├── training_state.pt
│   │   ├── tensorboard_logs/
│   │   └── watcher_blueprint/
│   │       ├── catalogs/         # CSV files containing model vocabulary
│   │       ├── laboratory_stats/ # CSV files of lab test stats
│   │       ├── model_state.pt    # Model weights
│   │       └── training_report.json
│   └── ...
└── tensorboard_active/
    └── ... (TensorBoard logs)

The watcher_blueprint directory is the main product of training: each blueprint contains everything needed to re-instantiate the Watcher model. To monitor training progress with TensorBoard, use the tensorboard_active directory as the log directory (for example, tensorboard --logdir /path/to/save_training_outputs/tensorboard_active).
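When resuming with snapshot_path, the most recent snapshot is usually the one to pick. The sketch below finds the snapshots/epoch_N directory with the highest N in the layout shown above. It is a hypothetical helper (latest_snapshot_dir is not part of watcher), and because this page does not say whether snapshot_path expects the epoch directory or the training_state.pt file inside it, the helper returns the directory and leaves that choice to the caller.

```python
import re
from pathlib import Path
from typing import Optional

# Snapshot directories are named epoch_0, epoch_5, ... in the layout above.
EPOCH_DIR = re.compile(r"epoch_(\d+)")


def latest_snapshot_dir(output_dir: str) -> Optional[Path]:
    """Return <output_dir>/snapshots/epoch_N with the largest N, or None.

    Hypothetical helper (not part of watcher).
    """
    snapshots = Path(output_dir) / "snapshots"
    if not snapshots.is_dir():
        return None
    candidates = []
    for child in snapshots.iterdir():
        match = EPOCH_DIR.fullmatch(child.name)
        if match and child.is_dir():
            candidates.append((int(match.group(1)), child))
    if not candidates:
        return None
    # Compare epochs numerically so that epoch_10 ranks above epoch_5.
    return max(candidates, key=lambda item: item[0])[1]


latest = latest_snapshot_dir("/path/to/save_training_outputs")
if latest is not None:
    print("Latest snapshot:", latest)
    print("Model weights:", latest / "watcher_blueprint" / "model_state.pt")
```

Sorting on the parsed integer avoids the lexicographic trap where "epoch_10" would sort before "epoch_5".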

Parameters:
  • dataset_dir (str) – Path to the dataset created by watcher.preprocess.create_dataset().

  • output_dir (str) – Directory where training results are saved.

  • max_gpus (int, optional) – Maximum number of GPUs to use for training.

  • embedding_dim (int, optional) – Dimensionality of the model embeddings (d_model).

  • num_layers (int, optional) – Number of transformer blocks in the model.

  • num_heads (int, optional) – Number of attention heads per transformer layer.

  • ff_hidden_dim (int, optional) – Hidden layer size of the feedforward network (d_ff).

  • dropout_rate (float, optional) – Dropout rate applied during training.

  • total_epochs (int, optional) – Maximum number of training epochs.

  • weight_decay (float, optional) – Weight decay for regularization.

  • max_lr (float, optional) – Peak learning rate.

  • min_lr (float, optional) – Minimum learning rate.

  • lr_scheduler_enabled (bool, optional) – Whether to use a learning rate scheduler. If False, min_lr is used throughout training.

  • lr_warmup (float, optional) – Warm-up duration as a fraction of total training data.

  • lr_decay (float, optional) – Learning rate decay duration after warm-up.

  • batch_size_step_scale (int, optional) – Scaling factor for batch size scheduling.

  • batch_size_active_phase (float, optional) – Fraction of data (0 to 1.0) during which the batch size increases.

  • max_batch_size (int, optional) – Maximum batch size.

  • min_batch_size (int, optional) – Initial batch size.

  • batch_schedule_enabled (bool, optional) – If False, batch size is fixed to max_batch_size throughout training.

  • dataloader_workers (int, optional) – Number of worker processes for data loading.

  • checkpoint_interval (int, optional) – Epoch interval at which training snapshots are saved.

  • validation_interval (int, optional) – Epoch interval at which validation is performed.

  • early_stopping_epochs (int, optional) – Number of consecutive epochs without validation loss improvement required to stop training early.

  • snapshot_path (str, optional) – Path to a training snapshot to resume from. Ensure that all training parameters match the previous run.

  • initial_weight_path (str, optional) – Path to pretrained weights for model initialization. Required if update=True.

  • precision (Literal["float32", "float16", "bfloat16"], optional) – Floating point precision for training.

  • update (bool, optional) – If True, fine-tunes the model using the current dataset.

  • restart_limit (int, optional) – Maximum number of automatic training restarts after runtime errors.

  • debug (bool, optional) – If True, enables debug mode, in which each dataloader yields only debug_chunks samples.

  • debug_chunks (int, optional) – Number of samples yielded per loader in debug mode.

  • trainer_class (object, optional) – [Deprecated] Custom trainer class.
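The batch-size parameters describe a ramp from min_batch_size to max_batch_size during the first batch_size_active_phase fraction of the data. The exact rule, and in particular what batch_size_step_scale controls, is not documented on this page, so the sketch below is an assumption: a linear ramp whose values are rounded down to multiples of batch_size_step_scale and clamped to the [min_batch_size, max_batch_size] range. The function batch_size_at is hypothetical; only the disabled case (fixed max_batch_size) follows documented behavior.

```python
def batch_size_at(progress: float, min_batch_size: int = 16, max_batch_size: int = 64,
                  batch_size_active_phase: float = 0.03, batch_size_step_scale: int = 8,
                  enabled: bool = True) -> int:
    """Batch size after `progress` (fraction of training data consumed, 0.0 to 1.0).

    Hypothetical rule, not watcher's implementation: a linear ramp over the
    active phase, rounded down to a multiple of batch_size_step_scale and
    clamped to [min_batch_size, max_batch_size].
    """
    if not enabled or progress >= batch_size_active_phase:
        # Documented behavior when disabled; assumed behavior after the ramp.
        return max_batch_size
    fraction = progress / batch_size_active_phase
    raw = min_batch_size + (max_batch_size - min_batch_size) * fraction
    stepped = int(raw // batch_size_step_scale) * batch_size_step_scale
    return max(min_batch_size, min(max_batch_size, stepped))


print(batch_size_at(0.0), batch_size_at(0.5), batch_size_at(0.0, enabled=False))
```

Quantizing to a fixed step keeps the number of distinct batch shapes small, which is one plausible reason for a step-scale parameter; that motivation is likewise an assumption.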

Returns:

best_weights (str) – Path to the best-performing model weights.

Return type:

str
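The restart_limit parameter bounds how often training is relaunched after a runtime error. The wrapper below is a hypothetical sketch of such a policy, not watcher's implementation (run_with_restarts is an invented name): it calls a training function, retries on RuntimeError up to restart_limit times, and re-raises the last error once the limit is exhausted. It does not reload snapshots; how watcher resumes state between restarts is not described here.

```python
from typing import Callable, TypeVar

T = TypeVar("T")


def run_with_restarts(train_once: Callable[[], T], restart_limit: int = 10) -> T:
    """Call `train_once`, restarting after RuntimeError up to `restart_limit` times.

    Hypothetical helper (not part of watcher).
    """
    restarts = 0
    while True:
        try:
            return train_once()
        except RuntimeError:
            if restarts >= restart_limit:
                raise  # limit exhausted: surface the original error
            restarts += 1


attempts = {"n": 0}


def flaky_training() -> str:
    """Stand-in for a training run that fails twice before succeeding."""
    attempts["n"] += 1
    if attempts["n"] <= 2:
        raise RuntimeError("transient failure")
    return "/path/to/best_weights"


print(run_with_restarts(flaky_training, restart_limit=10))
```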