November 25, 2025 · 7 min read

PyMC on Remote GPU

If you've ever kicked off a PyMC model and watched the progress bar crawl through thousands of samples, you know the feeling. Hierarchical models with course effects, runner effects, and correlated parameters can take hours on a CPU. And if you're iterating on model structure? That's a lot of coffee.

The good news: PyMC's JAX backend combined with a cloud GPU can cut that time by 10-100x. The even better news: you don't need to own expensive hardware. Services like RunPod let you rent an RTX 4090 for about $0.40/hour—spin up a pod, run your sampling, pull the results, and shut it down.

This post walks through my complete setup for GPU-accelerated Bayesian modeling.

The Architecture

The workflow splits computation across two environments:

  • Local (MacBook): VS Code, notebooks, all my usual editing tools
  • Cloud (RunPod): Docker container with CUDA, JAX, PyMC, and a beefy GPU

Mutagen handles bidirectional file sync between them. I edit code locally, it syncs to the container in real-time, and when sampling finishes I pull the trained models back. Best of both worlds—comfortable local development with cloud GPU execution.

The Dockerfile

Here's the complete Dockerfile that powers the setup:

# Base image for PyMC GPU analysis - CUDA 12.4 (compatible with RunPod driver 565)
# Build: docker buildx build --platform linux/amd64 -t {docker-username}/pymc-base:v4 --push .

FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

# Environment setup for CUDA 12.4
ENV PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    DEBIAN_FRONTEND=noninteractive \
    LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:/usr/local/cuda-12.4/extras/CUPTI/lib64:${LD_LIBRARY_PATH}

# Install Python 3.12 from deadsnakes PPA + core system dependencies
RUN rm -rf /var/lib/apt/lists/* && \
    apt-get update && apt-get install -y software-properties-common && \
    add-apt-repository -y ppa:deadsnakes/ppa && \
    apt-get update && apt-get install -y \
        python3.12 \
        python3.12-venv \
        python3.12-dev \
        git \
        graphviz \
        libgraphviz-dev \
        curl \
        rsync \
        openssh-server \
    && rm -rf /var/lib/apt/lists/*

# Make Python 3.12 the default and install pip via ensurepip (avoids distutils issues)
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 && \
    python -m ensurepip --upgrade && \
    python -m pip install --upgrade pip

# Configure SSH server (key-only authentication)
RUN mkdir /var/run/sshd && \
    sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin prohibit-password/' /etc/ssh/sshd_config && \
    sed -i 's/#PasswordAuthentication yes/PasswordAuthentication no/' /etc/ssh/sshd_config && \
    sed -i 's/#PubkeyAuthentication yes/PubkeyAuthentication yes/' /etc/ssh/sshd_config

# Create SSH directory for authorized keys
RUN mkdir -p /root/.ssh && chmod 700 /root/.ssh

# Install JAX with CUDA 12 support
RUN pip install "jax[cuda12]>=0.4.30"

# Install PyMC stack with dependencies
RUN pip install \
    "pymc>=5.6.1" \
    "numpyro>=0.19.0" \
    "pymc-extras>=0.5.0" \
    "blackjax>=1.3" \
    "arviz>=0.22.0"

# Install base Python dependencies
RUN pip install \
    "duckdb>=1.0.0" \
    "pandas>=2.0.0" \
    "numpy>=1.24.0" \
    "matplotlib>=3.8.0" \
    "seaborn>=0.13.2" \
    "plotly>=5.18.0" \
    "igraph>=0.11.0" \
    "scipy>=1.11.0" \
    "graphviz>=0.20.3" \
    "ipywidgets>=8.1.7" \
    "jupyterlab>=4.0.0" \
    "ipykernel>=6.0.0" \
    "networkx>=3.0" \
    "qrcode[pil]>=8.2" \
    "requests>=2.31.0" \
    "tqdm>=4.66.0" \
    "rich" \
    "nvitop" \
    "psutil" \
    "gpustat"

# Register Python kernel for Jupyter
RUN python -m ipykernel install --name=python3 --display-name="Python 3.12 (PyMC CUDA)"

# Create base working directories
RUN mkdir -p /workspace/analysis /workspace/data

# Install development tools (last layer - changes frequently)
RUN rm -rf /var/lib/apt/lists/* && apt-get update && apt-get install -y \
    tmux \
    htop \
    vim \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /workspace/analysis

# Copy and set up entrypoint script
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

# Expose ports (inherited by derived images)
EXPOSE 22 8888

ENTRYPOINT ["/entrypoint.sh"]

A few things worth noting:

  • NVIDIA CUDA 12.4 base image: Includes cuDNN for accelerated operations. The CUDA version has to be one your cloud provider's driver supports (here, RunPod's driver 565 handles 12.4).
  • Python 3.12 via deadsnakes: Ubuntu 22.04 ships with Python 3.10, but modern PyMC benefits from 3.12's performance improvements.
  • JAX installed before PyMC: Order matters here. JAX with CUDA support needs to be in place before PyMC pulls in its JAX dependencies.
  • SSH + JupyterLab: The entrypoint starts both services, giving you flexibility—SSH in directly or access notebooks through the browser.

The entrypoint script starts the services and confirms GPU detection:

#!/bin/bash

# Start SSH daemon
echo "Starting SSH daemon..."
/usr/sbin/sshd || echo "Warning: sshd failed to start"

# Start JupyterLab in background (accessible on port 8888)
echo "Starting JupyterLab on port 8888..."
cd /workspace/analysis
jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root \
    --ServerApp.token='' --ServerApp.password='' &

# Check JAX GPU detection
echo "Checking JAX GPU detection..."
python -c "import jax; devices = jax.devices(); print(f'JAX detected {len(devices)} device(s): {devices}')"

echo "Services started. Container ready."

# Keep container alive
if [ $# -eq 0 ]; then
    exec tail -f /dev/null
else
    exec "$@"
fi

The VS Code Workflow

VS Code tasks turn multi-step operations into single clicks. Here's my tasks.json:

{
  "version": "2.0.0",
  "inputs": [
    {
      "id": "runpodHost",
      "type": "promptString",
      "description": "RunPod SSH host (e.g., 203.57.40.77)",
      "default": "203.57.40.77"
    },
    {
      "id": "runpodPort",
      "type": "promptString",
      "description": "RunPod SSH port",
      "default": "10225"
    }
  ],
  "tasks": [
    {
      "label": "RunPod: Push Data",
      "type": "shell",
      "command": "rsync -avz --progress ~/UltraSignup/analysis/data/ -e 'ssh -p ${input:runpodPort} -i ~/.ssh/id_ed25519' root@${input:runpodHost}:/workspace/analysis/data/",
      "group": "none"
    },
    {
      "label": "RunPod: Start Mutagen Sync",
      "type": "shell",
      "command": "mutagen sync terminate ultrasignup 2>/dev/null; mutagen sync create ~/UltraSignup/analysis root@${input:runpodHost}:${input:runpodPort}:/workspace/analysis --name=ultrasignup --ignore='.venv,__pycache__,*.pyc,.ipynb_checkpoints' && echo '✅ Mutagen sync started'",
      "group": "none"
    },
    {
      "label": "RunPod: Stop Mutagen Sync",
      "type": "shell",
      "command": "mutagen sync terminate ultrasignup && echo '✅ Mutagen sync stopped'",
      "group": "none"
    },
    {
      "label": "RunPod: Mutagen Status",
      "type": "shell",
      "command": "mutagen sync list",
      "group": "none"
    },
    {
      "label": "RunPod: Pull Models",
      "type": "shell",
      "command": "rsync -avz --progress -e 'ssh -p ${input:runpodPort} -i ~/.ssh/id_ed25519' root@${input:runpodHost}:/workspace/analysis/data/cache/ ~/UltraSignup/analysis/data/cache/",
      "group": "none"
    },
    {
      "label": "RunPod: Full Setup (Data + Sync)",
      "dependsOn": ["RunPod: Push Data", "RunPod: Start Mutagen Sync"],
      "dependsOrder": "sequence",
      "group": "none"
    },
    {
      "label": "RunPod: SSH Connect",
      "type": "shell",
      "command": "ssh -p ${input:runpodPort} -i ~/.ssh/id_ed25519 root@${input:runpodHost}",
      "group": "none",
      "presentation": {
        "focus": true
      }
    }
  ]
}

The key tasks:

  • Push Data: One-time rsync of large data files (databases, datasets) to the container. These don't need continuous sync.
  • Start Mutagen Sync: Bidirectional sync for code and notebooks. Edit locally, changes appear on the container instantly.
  • Pull Models: After sampling completes, pull the trained model artifacts (NetCDF traces, pickled inference data) back to your local machine.
  • Full Setup: Chains Push Data → Start Sync for one-click pod initialization.
  • SSH Connect: Drop into the container for debugging or running commands directly.

The inputs section prompts for the RunPod host and port, which change each time you spin up a new pod. RunPod provides these in the pod details.
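If retyping the host and port every session gets old, the same values can live in a throwaway ~/.ssh/config entry instead. A sketch (the `runpod` alias is an invention for this example; the host, port, and key path are whatever your pod and setup actually use):

```text
Host runpod
    HostName 203.57.40.77
    Port 10225
    User root
    IdentityFile ~/.ssh/id_ed25519
```

With that in place, `ssh runpod` works without prompts, though the entry still needs updating each time a new pod comes up.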

Putting It Together

Here's the workflow in practice:

  1. Spin up a RunPod pod using your custom template (pointing to your Docker image)
  2. Run "Full Setup" task in VS Code—pushes data and starts file sync
  3. Edit notebooks locally—Mutagen syncs changes to the container in real-time
  4. Execute cells—sampling runs on the GPU, 10-100x faster than your laptop
  5. Run "Pull Models" when done—brings trained artifacts back to local
  6. Terminate the pod—stop paying for GPU time
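Step 5 only pays off if the sampling code actually wrote its trace under data/cache/, where the Pull Models rsync looks. ArviZ's NetCDF round trip handles that; a sketch where the filename is arbitrary and az.from_dict stands in for the InferenceData that pm.sample would return:

```python
import numpy as np
import arviz as az

# Stand-in for the InferenceData returned by pm.sample(...)
idata = az.from_dict(posterior={"mu": np.random.randn(2, 500)})

# On the pod: write the trace where the Pull Models rsync will find it
az.to_netcdf(idata, "runner_model.nc")  # in practice: data/cache/runner_model.nc

# On the laptop, after pulling: reload without re-sampling
restored = az.from_netcdf("runner_model.nc")
print(restored.posterior["mu"].shape)
```

NetCDF keeps the full posterior with coordinates and metadata intact, so diagnostics and plotting work locally exactly as they did on the pod.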

A typical 2-hour modeling session costs about $0.80 on an RTX 4090. Compare that to waiting 20+ hours on a MacBook, and the economics are obvious.

What's Next

I'm using this setup to build hierarchical Bayesian models for ultramarathon race analysis—estimating runner abilities, course difficulties, and DNF probabilities from historical race data. The models have thousands of parameters and would be impractical without GPU acceleration.

The full code is available at github.com/justmytwospence/ultrasignup-analysis. Feel free to adapt it for your own projects, and reach out if you have questions!