Source code for watcher.models.src.simulator

from datetime import timedelta, datetime
import pandas as pd
from .model_loaders import build_watcher
from .generation import monte_carlo, make_sim_dataframe
from ...preprocess import preprocess_for_inference


class Simulator:
    """
    Convenience front end for simulating patient trajectories with Watcher.

    The class holds a Watcher model built for inference and exposes
    :meth:`simulate`, which generates multiple simulated timelines from the
    initial information recorded for a patient.

    Example:
        .. code-block:: python

            from watcher.models import Simulator

            simulator = Simulator(
                blueprint="path/to/blueprint",
                device="cuda:0",  # Your GPU ID
            )

    Attributes:
        blueprint (str): Path to the model blueprint directory.
        device (str): Device the model was moved to (e.g. ``"cuda:0"``).
        model (Watcher): Instantiated Watcher model ready for inference.
    """

    def __init__(self, blueprint: str, device: str):
        """
        Build the Watcher model from a blueprint and move it to ``device``.

        Args:
            blueprint (str): Path to the model blueprint directory.
            device (str): Device for running simulations ('cpu' or 'cuda').
        """
        # train=False: the model is built for inference only.
        watcher = build_watcher(blueprint=blueprint, train=False)
        self.model = watcher.to(device)
        self.device = device
        self.blueprint = blueprint

    def simulate(
        self,
        db_schema: str,
        patient_id: str,
        n_iter: int,
        timeline_start: str,
        simulation_start: str,
        time_horizon: int,
        max_length: int,
        age_as_timestamp: bool = False,
        stride: int = 64,
        temperature: float = 1.0,
    ) -> pd.DataFrame:
        """
        Run Monte Carlo simulations for a single patient.

        The patient's records between ``timeline_start`` and
        ``simulation_start`` are preprocessed, ``n_iter`` simulated timelines
        are generated from them, and the generated parts are returned as a
        pandas DataFrame.

        Example:
            .. code-block:: python

                from watcher.models import Simulator

                simulator = Simulator(
                    blueprint="path/to/blueprint",  # Path to the model blueprint
                    device="cuda:0",  # Your GPU ID
                )
                simulations = simulator.simulate(
                    db_schema="public",
                    patient_id="0001UT",
                    n_iter=256,  # Run 256 simulations
                    timeline_start="2011/01/01 00:00",  # Patient history start time
                    simulation_start="2025/04/26 14:57",  # Simulation start time
                    time_horizon=7,  # Simulation time horizon in days (7 days)
                    max_length=2000,  # Maximum sequence length for each simulation
                    temperature=1.0,
                )

        The details of the simulation results are available at
        :ref:`simulation_result_table`.

        Args:
            db_schema (str): PostgreSQL schema name containing patient records.
                (If you have not specified a schema, the schema name is
                expected to be 'public'.)
            patient_id (str): Patient ID in the database.
            n_iter (int): Number of simulation trajectories to generate.
                Set this to 1 for a single simulation, or to a large number
                (e.g., 256) for a Monte Carlo simulation.
            timeline_start (str): Start date for retrieving patient history
                (format: '%Y/%m/%d %H:%M').
            simulation_start (str): Simulation start time
                (format: '%Y/%m/%d %H:%M').
            time_horizon (int): Length of simulation in days. Simulation is
                completed once the time horizon is reached.
            max_length (int): Maximum allowed sequence length for each
                simulation. If the simulation exceeds this length, the
                simulation result is discarded.
            age_as_timestamp (bool, optional): If True, timestamps (formatted
                as '%Y/%m/%d %H:%M') are placed in the `age` column instead of
                patient age. Defaults to False.
            stride (int, optional): Stride size for sliding window decoding.
                Defaults to 64.
            temperature (float, optional): Sampling temperature for
                stochasticity. Defaults to 1.0.

        Returns:
            pd.DataFrame: A DataFrame containing simulated timelines, or an
            empty DataFrame when the simulation produced nothing.

            - Each simulation is labeled by 'patient_id'
              (e.g., 'simulation0', 'simulation1', etc.).
        """
        # Length of the simulated window.
        horizon_length = timedelta(days=time_horizon)

        # Load and preprocess the clinical records up to the simulation start.
        history, history_catalog_ids, dob = preprocess_for_inference(
            patient_id=patient_id,
            model=self.model,
            start=timeline_start,
            end=simulation_start,
            db_schema=db_schema,
        )

        # The simulation start is handed to the sampler relative to `dob`
        # (presumably the patient's date of birth -- confirm in
        # preprocess_for_inference).
        simulation_start_dt = datetime.strptime(simulation_start, "%Y/%m/%d %H:%M")
        horizon_start = simulation_start_dt - dob

        # Monte Carlo sampling of future timelines.
        generated_timelines, generated_catalog_ids = monte_carlo(
            model=self.model,
            timeline=history,
            catalog_ids=history_catalog_ids,
            n_iter=n_iter,
            time_horizon=horizon_length,
            horizon_start=horizon_start,
            stride=stride,
            temperature=temperature,
            max_length=max_length,
            stop_vocab=None,
            return_generated_parts_only=True,
        )

        # Nothing usable was generated: return an empty frame.
        if not (generated_timelines and generated_catalog_ids):
            return pd.DataFrame()

        sim_frame = make_sim_dataframe(
            interpreter=self.model.interpreter,
            prod_timelines=generated_timelines,
            prod_catalog_ids=generated_catalog_ids,
        )

        if age_as_timestamp:
            # Turn the age offsets into absolute timestamps, rendered as text.
            timestamps = pd.to_timedelta(sim_frame["age"]) + dob
            sim_frame["age"] = timestamps.dt.strftime("%Y/%m/%d %H:%M")

        return sim_frame