Source code for watcher.watcher_api.src.api_src

import json
import uuid
import time
import signal
import atexit
import traceback
from multiprocessing import Process, JoinableQueue, Manager
from multiprocessing.managers import DictProxy
from datetime import datetime, timedelta
import pandas as pd
import torch
from flask import Flask, make_response, request
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_limiter.errors import RateLimitExceeded
from ...models import (
    build_watcher,
    Watcher,
    WatcherOrchestrator,
)
from ...preprocess import preprocess_for_inference
from ...general_params import watcher_config as config

TIMEOUT = 60
EXPIRES_AFTER = 600
STATE_204 = {
    "status": 204,
    "progress": "Empty",
    "errors": ["Request not found"],
    "updated": datetime.now().isoformat(),
}
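# NOTE: Request status lifecycle:
#   202 (queued -> preprocessing -> simulating) moves to 200 (product ready) or
#   400 (invalid request); finished entries are removed when fetched or after
#   EXPIRES_AFTER seconds by the expiration watcher.
# STATE_204 is the fallback payload for unknown or expired simulation IDs; its
# "updated" timestamp is evaluated once, at import time.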


@atexit.register
def cleanup_all():
    print("Shutting down on exit.")
    _graceful_shutdown("atexit", None)


# ==================
# Flask app setup
# ==================
# region: Flask apps
# Initialize Flask app
app = Flask(__name__)
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri="redis://redis:6379",  # container name used as hostname
    app=app,
)
limiter.init_app(app)


# Rate-limit handler for Flask-limiter
@app.errorhandler(RateLimitExceeded)
def ratelimit_handler(e):
    """Handles rate limits."""
    return make_response(
        json.dumps(
            {
                "status": 429,
                "errors": ["Rate limit exceeded."],
                "retry_after_seconds": e.retry_after,
            }
        ),
        429,
    )


# API to add a request in the queue
@app.route("/watcher_api/monte_carlo", methods=["POST"])
@limiter.limit("1 per 5 seconds")
def submit_monte_carlo_request():
    """API to perform Monte Carlo simulation."""
    try:
        # Variables
        request_queue = app.config["REQUEST_QUEUE"]
        status_board = app.config["STATUS_BOARD"]

        # Generate simulation id (UUID)
        simulation_id = uuid.uuid4().hex

        # Get data
        data = request.json.get("data", {})
        patient_id = data.get("patient_id", None)
        n_iter_str = data.get("n_iter", None)
        time_horizon_str = data.get("time_horizon", None)
        sim_start_str = data.get("sim_start", None)

        # Put the request in the queue
        # NOTE: Any missing field defaults to None
        request_queue.put(
            (patient_id, sim_start_str, time_horizon_str, n_iter_str, simulation_id),
            timeout=TIMEOUT,
        )

        # Write the request status
        status_board[simulation_id] = {
            "status": 202,
            "progress": "Waiting for data preprocessing",
            "errors": [],
            "updated": datetime.now().isoformat(),
        }

        # Return the simulation id
        # NOTE: The task status is '202', but the json response is returned with 200
        # because submission was successful
        return make_response(json.dumps({"simulation_id": simulation_id}), 200)

    except Exception as e:
        # Error handling
        error_response = make_response(json.dumps({"errors": str(e)}), 500)
        error_response.headers["Content-Type"] = "application/json"
        traceback.print_exc()
        return error_response
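
# Example client call (illustrative sketch only; the host, port, and field values
# below are assumptions, and `sim_start` must follow config.DATETIME_FORMAT):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/watcher_api/monte_carlo",
#       json={
#           "data": {
#               "patient_id": "12345",               # hypothetical patient ID
#               "sim_start": "2024-01-01 00:00:00",  # assumed datetime format
#               "time_horizon": "30",                # days, positive integer
#               "n_iter": "64",                      # 1 ~ max_batch_size
#           }
#       },
#   )
#   simulation_id = resp.json()["simulation_id"]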


# API to retrieve the product
@app.route("/watcher_api/result/<simulation_id>", methods=["GET"])
@limiter.limit("1 per 1 seconds")
def get_result(simulation_id):
    """Fetches products if ready."""
    # Check request status
    status_board = app.config["STATUS_BOARD"]
    state = status_board.get(simulation_id)
    if state is not None:
        status = state["status"]
        if status in [200, 400]:
            # Pop out the final status
            final_state = status_board.pop(simulation_id, STATE_204)
            # Successful response
            if status == 200:
                # Get products
                product_store = app.config["PRODUCT_STORE"]
                result = product_store.pop(simulation_id, None)
                if result is not None:
                    return make_response(result, 200)
                else:
                    # Products not found (e.g., expired)
                    return make_response(json.dumps(STATE_204), 204)
            # Invalid request (400)
            else:
                return make_response(json.dumps(final_state), 400)

        # Pending response (202)
        else:
            return make_response(json.dumps(status_board[simulation_id]), 202)

    # simulation id not found (e.g., expired)
    else:
        return make_response(json.dumps(STATE_204), 204)
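
# Example polling loop (illustrative sketch only; host and port are assumptions):
#
#   import time
#   import requests
#
#   while True:
#       resp = requests.get(f"http://localhost:8000/watcher_api/result/{simulation_id}")
#       if resp.status_code == 200:
#           records = resp.json()  # simulated timeline as JSON records
#           break
#       if resp.status_code in (204, 400):
#           break                  # unknown/expired ID or invalid request
#       time.sleep(2)              # still 202; stays under the 1-per-second limit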


# endregion


# ==========================
# Simulator background utils
# ==========================
# region: Simulator utils
def _graceful_shutdown(signum, frame):
    print(f"Received signal {signum}. Terminating simulator processes...")

    # 1. Terminate child processes
    started_processes = app.config.get("STARTED_PROC")
    if started_processes:
        for proc in app.config["STARTED_PROC"]:
            if proc.is_alive():
                proc.terminate()
                proc.join(timeout=10)  # wait up to 10 s for the process to exit

    print("All simulator processes terminated.")

    # 2. Terminate orchestrator
    active_orch = app.config.get("ORCH")
    if active_orch is not None:
        try:
            print("Terminating WatcherOrchestrator...")
            active_orch.terminate()
        except Exception as e:
            print(f"Failed to terminate orchestrator cleanly: {e}")


def _setup_signal_handlers():
    signal.signal(signal.SIGTERM, _graceful_shutdown)
    signal.signal(signal.SIGINT, _graceful_shutdown)
    atexit.register(lambda: _graceful_shutdown(signum="atexit", frame=None))


def _preprocess_requests(
    request_queue: JoinableQueue,
    temp_queue: JoinableQueue,
    status_board: DictProxy,
    max_n_iter: int,
    model: Watcher,
    db_schema: str,
):
    """Validates queued requests and preprocesses patient data for simulation."""
    torch.set_num_threads(1)
    while True:
        try:
            errors = []
            # Get one request
            req = request_queue.get()
            patient_id, sim_start_str, time_horizon_str, n_iter_str, simulation_id = req
            # ==========================
            # Input validation
            # ==========================
            # Validate patient ID
            # NOTE: If the patient ID is not in the database, currently, ValueError is raised by `preprocess_for_inference()`
            if patient_id is None:
                errors.append("Patient ID not in data.")
            # Validate sim_start
            if sim_start_str is None:
                errors.append("Time of simulation start not in data.")
                sim_start = None
            else:
                try:
                    sim_start = datetime.strptime(sim_start_str, config.DATETIME_FORMAT)
                except ValueError:
                    errors.append("Time of simulation start has an irregular format.")
                    sim_start = None
            # Validate time horizon
            if time_horizon_str is None:
                errors.append("Time horizon not in data.")
                time_horizon = None
            else:
                try:
                    time_horizon = int(time_horizon_str)
                    if time_horizon <= 0:
                        errors.append("Time horizon must be a positive integer.")
                except (ValueError, TypeError):
                    errors.append("Time horizon must be a positive integer.")
                    time_horizon = None
            # Validate n_iter
            if n_iter_str is None:
                errors.append("Number of iterations not in data.")
                n_iter = None
            else:
                try:
                    n_iter = int(n_iter_str)
                    if (n_iter <= 0) or (n_iter > max_n_iter):
                        errors.append(f"Number of iterations must be between 1 and {max_n_iter}.")
                except (ValueError, TypeError):
                    errors.append("Number of iterations must be a positive integer.")
                    n_iter = None

            # Do not proceed if obvious request errors are detected
            if errors:
                status = 400
                progress = "Aborted."

            # ===================================
            # Try to download and preprocess data
            # ===================================
            else:
                try:
                    timeline, catalog_ids, dob = preprocess_for_inference(
                        patient_id=patient_id,
                        model=model,
                        start=None,
                        end=sim_start_str,  # This arg expects a string
                        db_schema=db_schema,
                    )
                    preprocess_success = True

                # NOTE: If patient_id is not found, `preprocess_for_inference()` raises ValueError
                except ValueError as e:
                    errors.append(str(e))
                    status = 400
                    progress = "Aborted."
                    preprocess_success = False

                # Put in the queue
                if preprocess_success:
                    horizon_start = sim_start - dob
                    temp_queue.put(
                        (
                            timeline,
                            catalog_ids,
                            n_iter,
                            horizon_start,
                            timedelta(days=time_horizon),
                            simulation_id,
                        ),
                        timeout=TIMEOUT,
                    )
                    status = 202
                    progress = "Data preprocessing completed."

            # ===================================
            # Update the request status
            # ===================================
            status_board[simulation_id] = {
                "status": status,
                "progress": progress,
                "errors": errors,
                "updated": datetime.now().isoformat(),
            }
            request_queue.task_done()

        # Catch all other errors
        except Exception as e:
            traceback.print_exc()
            raise RuntimeError("Error during request process") from e
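
# NOTE: Data flow: the POST endpoint enqueues raw request tuples
#   (patient_id, sim_start_str, time_horizon_str, n_iter_str, simulation_id)
# on `request_queue`; `_preprocess_requests` validates them and forwards
#   (timeline, catalog_ids, n_iter, horizon_start, horizon_length, simulation_id)
# on `temp_queue`, which feeds the WatcherOrchestrator through the generator
# created in `start_simulators`.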


def _watch_products(
    orch: WatcherOrchestrator,
    status_board: DictProxy,
    product_store: DictProxy,
):
    """Collects finished products from the orchestrator and stores them in the shared product store."""
    for prod in orch:
        df, simulation_id = prod
        # NOTE:
        #   product_store must be updated first, otherwise, the status_board may be read first while the
        #   product has not yet been saved.
        # Cast everything to strings for consistency
        df = df.astype(str)
        # Add the product to the store
        product_store[simulation_id] = df.to_json(orient="records")

        # Update status
        # NOTE: As mentioned above, this must follow the product_store update
        status_board[simulation_id] = {
            "status": 200,
            "progress": "Simulation completed",
            "errors": [],
            "updated": datetime.now().isoformat(),
        }


def _watch_expired(
    status_board: DictProxy,
    product_store: DictProxy,
):
    """Removes expired requests from memory store."""
    while True:
        # DictProxy -> Dict (make a copy)
        status_dict = dict(status_board)
        # Check expired
        df = pd.DataFrame.from_dict(status_dict, orient="index")
        df["simulation_id"] = df.index
        if not df.empty:
            df["updated"] = pd.to_datetime(df["updated"])
            expired = df["status"].isin([200, 400]) & (
                (datetime.now() - df["updated"]).dt.total_seconds() > EXPIRES_AFTER
            )
            expired_ids = df.loc[expired, "simulation_id"].tolist()
            # Remove expired requests
            for req in expired_ids:
                _ = status_board.pop(req, None)
                _ = product_store.pop(req, None)
        # Wait for the next round
        time.sleep(EXPIRES_AFTER)
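
# NOTE: The sweep runs every EXPIRES_AFTER seconds and removes entries whose
# status is 200 or 400 and whose last update is older than EXPIRES_AFTER, so a
# finished product may remain retrievable for up to roughly 2 * EXPIRES_AFTER.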


# endregion


# ===========================
# Main functions to be called
# ===========================
# region: Main functions
def start_simulators(
    blueprint: str,
    log_dir: str,
    gpu_ids: list[str],
    n_preprocess_workers: int,
    db_schema: str = "public",
    max_batch_size: int = 256,
    max_length: int = 10000,
    return_generated_parts_only: bool = True,
    return_unfinished: bool = False,
) -> Flask:
    """
    Starts all background simulator processes, including:

    - Preprocessing workers
    - WatcherOrchestrator
    - Expiration manager

    This function should be called separately from the Flask app, typically
    before launching Gunicorn or from a separate process.

    **Example Usage**

    .. code-block:: python

        from watcher.watcher_api import start_simulators

        start_simulators(...)

    Args:
        blueprint (str): Model blueprint name.
        log_dir (str): Path to directory for saving logs.
        gpu_ids (list[str]): List of GPU device IDs to assign for simulation.
        n_preprocess_workers (int): Number of multiprocessing workers for data preprocessing.
        db_schema (str, optional): PostgreSQL schema name. Defaults to "public".
        max_batch_size (int, optional): Max batch size for simulation. Defaults to 256.
        max_length (int, optional): Max sequence length for simulation input. Defaults to 10000.
        return_generated_parts_only (bool, optional): If True, returns only generated output. Defaults to True.
        return_unfinished (bool, optional): If True, includes unfinished trajectories. Defaults to False.
    """
    if not hasattr(app, "_simulators_started"):
        print("Please wait for API initialization...")
        # Setup shared objects
        manager = Manager()
        status_board = manager.dict()
        product_store = manager.dict()
        temp_queue = JoinableQueue()
        request_queue = JoinableQueue()

        # Generator -> orchestrator
        def _gen_task(temp_queue):
            while True:
                yield temp_queue.get()

        task_gen = _gen_task(temp_queue)

        print("Building watcher orchestrator...")
        orch = WatcherOrchestrator(
            task_generator=task_gen,
            log_dir=log_dir,
            blueprint=blueprint,
            gpu_ids=gpu_ids,
            max_batch_size=max_batch_size,
            max_length=max_length,
            return_generated_parts_only=return_generated_parts_only,
            return_unfinished=return_unfinished,
            compile_model=False,
            temperature=1,
        )
        orch.start()
        print("Orchestrator started")

        # Store in app config (Flask will pick these up when the app starts)
        app.config["MAX_N_ITER"] = max_batch_size
        app.config["DB_SCHEMA"] = db_schema
        app.config["REQUEST_QUEUE"] = request_queue
        app.config["PRODUCT_STORE"] = product_store
        app.config["STATUS_BOARD"] = status_board
        app.config["STARTED_PROC"] = []

        # Start preprocessors
        dummy_model = build_watcher(blueprint=blueprint, train=False).cpu()
        for _ in range(n_preprocess_workers):
            p = Process(
                target=_preprocess_requests,
                kwargs={
                    "request_queue": request_queue,
                    "temp_queue": temp_queue,
                    "status_board": status_board,
                    "max_n_iter": max_batch_size,
                    "model": dummy_model,
                    "db_schema": db_schema,
                },
                daemon=True,
            )
            p.start()
            # Add to the list of started processes
            app.config["STARTED_PROC"].append(p)

        # Setup orchestrator
        app.config["ORCH"] = orch
        print("Orch done")

        # Start the product watcher
        p_prod = Process(
            target=_watch_products,
            kwargs={
                "orch": orch,
                "status_board": status_board,
                "product_store": product_store,
            },
            daemon=False,
        )
        p_prod.start()
        print("Product watcher started")
        app.config["STARTED_PROC"].append(p_prod)

        # Start the expiration watcher
        p_exp = Process(
            target=_watch_expired,
            kwargs={
                "status_board": status_board,
                "product_store": product_store,
            },
            daemon=True,
        )
        p_exp.start()
        print("Expiration watcher started")
        app.config["STARTED_PROC"].append(p_exp)

        # Flag for simulators started
        app._simulators_started = True
        # Final message
        print("API is ready.")
        # Register handlers
        # _setup_signal_handlers()

    return app
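
# Example launch script (illustrative sketch only; the file name, blueprint name,
# paths, and bind address are assumptions, not prescribed by this module):
#
#   # run_api.py
#   from watcher.watcher_api import start_simulators
#
#   app = start_simulators(
#       blueprint="my_blueprint",      # hypothetical blueprint name
#       log_dir="/var/log/watcher",    # hypothetical log directory
#       gpu_ids=["0"],
#       n_preprocess_workers=2,
#   )
#
#   if __name__ == "__main__":
#       app.run(host="0.0.0.0", port=8000)
#
# After start_simulators() has run, the returned `app` can also be served by a
# WSGI server such as Gunicorn (e.g. `gunicorn run_api:app`), as suggested in
# the docstring above.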
# endregion