PyMC on Remote GPU
If you've ever kicked off a PyMC model and watched the progress bar crawl through thousands of samples, you know the feeling. Hierarchical models with course effects, runner effects, and correlated parameters can take hours on a CPU. And if you're iterating on model structure? That's a lot of coffee.
The good news: PyMC's JAX backend combined with a cloud GPU can cut that time by 10-100x. The even better news: you don't need to own expensive hardware. Services like RunPod let you rent an RTX 4090 for about $0.40/hour—spin up a pod, run your sampling, pull the results, and shut it down.
This post walks through my complete setup for GPU-accelerated Bayesian modeling.
The Architecture
The workflow splits computation across two environments:
- Local (MacBook): VS Code, notebooks, all my usual editing tools
- Cloud (RunPod): Docker container with CUDA, JAX, PyMC, and a beefy GPU
Mutagen handles bidirectional file sync between them. I edit code locally, it syncs to the container in real-time, and when sampling finishes I pull the trained models back. Best of both worlds—comfortable local development with cloud GPU execution.
The Dockerfile
Here's the complete Dockerfile that powers the setup:
```dockerfile
# Base image for PyMC GPU analysis - CUDA 12.4 (compatible with RunPod driver 565)
# Build: docker buildx build --platform linux/amd64 -t {docker-username}/pymc-base:v4 --push .

FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

# Environment setup for CUDA 12.4
ENV PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    DEBIAN_FRONTEND=noninteractive \
    LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:/usr/local/cuda-12.4/extras/CUPTI/lib64:${LD_LIBRARY_PATH}

# Install Python 3.12 from deadsnakes PPA + core system dependencies
RUN rm -rf /var/lib/apt/lists/* && \
    apt-get update && apt-get install -y software-properties-common && \
    add-apt-repository -y ppa:deadsnakes/ppa && \
    apt-get update && apt-get install -y \
    python3.12 \
    python3.12-venv \
    python3.12-dev \
    git \
    graphviz \
    libgraphviz-dev \
    curl \
    rsync \
    openssh-server \
    && rm -rf /var/lib/apt/lists/*

# Make Python 3.12 the default and install pip via ensurepip (avoids distutils issues)
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.12 1 && \
    update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1 && \
    python -m ensurepip --upgrade && \
    python -m pip install --upgrade pip

# Configure SSH server (key-only authentication)
RUN mkdir /var/run/sshd && \
    sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin prohibit-password/' /etc/ssh/sshd_config && \
    sed -i 's/#PasswordAuthentication yes/PasswordAuthentication no/' /etc/ssh/sshd_config && \
    sed -i 's/#PubkeyAuthentication yes/PubkeyAuthentication yes/' /etc/ssh/sshd_config

# Create SSH directory for authorized keys
RUN mkdir -p /root/.ssh && chmod 700 /root/.ssh

# Install JAX with CUDA 12 support
RUN pip install "jax[cuda12]>=0.4.30"

# Install PyMC stack with dependencies
RUN pip install \
    "pymc>=5.6.1" \
    "numpyro>=0.19.0" \
    "pymc-extras>=0.5.0" \
    "blackjax>=1.3" \
    "arviz>=0.22.0"

# Install base Python dependencies
RUN pip install \
    "duckdb>=1.0.0" \
    "pandas>=2.0.0" \
    "numpy>=1.24.0" \
    "matplotlib>=3.8.0" \
    "seaborn>=0.13.2" \
    "plotly>=5.18.0" \
    "igraph>=0.11.0" \
    "scipy>=1.11.0" \
    "graphviz>=0.20.3" \
    "ipywidgets>=8.1.7" \
    "jupyterlab>=4.0.0" \
    "ipykernel>=6.0.0" \
    "networkx>=3.0" \
    "qrcode[pil]>=8.2" \
    "requests>=2.31.0" \
    "tqdm>=4.66.0" \
    "rich" \
    "nvitop" \
    "psutil" \
    "gpustat"

# Register Python kernel for Jupyter
RUN python -m ipykernel install --name=python3 --display-name="Python 3.12 (PyMC CUDA)"

# Create base working directories
RUN mkdir -p /workspace/analysis /workspace/data

# Install development tools (last layer - changes frequently)
RUN rm -rf /var/lib/apt/lists/* && apt-get update && apt-get install -y \
    tmux \
    htop \
    vim \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /workspace/analysis

# Copy and set up entrypoint script
COPY entrypoint.sh /entrypoint.sh
RUN chmod +x /entrypoint.sh

# Expose ports (inherited by derived images)
EXPOSE 22 8888

ENTRYPOINT ["/entrypoint.sh"]
```
A few things worth noting:
- NVIDIA CUDA 12.4 base image: Includes cuDNN for accelerated operations. Make sure this matches your cloud provider's driver version.
- Python 3.12 via deadsnakes: Ubuntu 22.04 ships with Python 3.10, but modern PyMC benefits from 3.12's performance improvements.
- JAX installed before PyMC: Order matters here. Installing the CUDA-enabled JAX wheel first ensures that the later PyMC, NumPyro, and BlackJAX installs resolve against it, rather than pulling in a CPU-only JAX as a transitive dependency.
- SSH + JupyterLab: The entrypoint starts both services, giving you flexibility—SSH in directly or access notebooks through the browser.
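A driver/CUDA mismatch makes JAX fall back to CPU silently, so it's worth verifying the backend interactively from a notebook before kicking off a long run. A minimal check (on a CPU-only laptop this prints `cpu`; on the pod it should report a CUDA device):

```python
import jax
import jax.numpy as jnp

# Which backend did JAX pick? Reports "gpu" (or "cuda") on the pod,
# "cpu" on a machine without a visible GPU.
print(jax.default_backend())
print(jax.devices())

# A tiny computation confirms the device actually executes work.
x = jnp.arange(1024.0)
print(float(x.sum()))  # 523776.0
```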
The entrypoint script starts the services and confirms GPU detection:
```bash
#!/bin/bash

# Start SSH daemon
echo "Starting SSH daemon..."
/usr/sbin/sshd || echo "Warning: sshd failed to start"

# Start JupyterLab in background (accessible on port 8888)
echo "Starting JupyterLab on port 8888..."
cd /workspace/analysis
jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root \
    --ServerApp.token='' --ServerApp.password='' &

# Check JAX GPU detection
echo "Checking JAX GPU detection..."
python -c "import jax; devices = jax.devices(); print(f'JAX detected {len(devices)} device(s): {devices}')"

echo "Services started. Container ready."

# Keep container alive
if [ $# -eq 0 ]; then
    exec tail -f /dev/null
else
    exec "$@"
fi
```
The VS Code Workflow
VS Code tasks turn multi-step operations into single clicks. Here's my tasks.json:
```json
{
  "version": "2.0.0",
  "inputs": [
    {
      "id": "runpodHost",
      "type": "promptString",
      "description": "RunPod SSH host (e.g., 203.57.40.77)",
      "default": "203.57.40.77"
    },
    {
      "id": "runpodPort",
      "type": "promptString",
      "description": "RunPod SSH port",
      "default": "10225"
    }
  ],
  "tasks": [
    {
      "label": "RunPod: Push Data",
      "type": "shell",
      "command": "rsync -avz --progress ~/UltraSignup/analysis/data/ -e 'ssh -p ${input:runpodPort} -i ~/.ssh/id_ed25519' root@${input:runpodHost}:/workspace/analysis/data/",
      "group": "none"
    },
    {
      "label": "RunPod: Start Mutagen Sync",
      "type": "shell",
      "command": "mutagen sync terminate ultrasignup 2>/dev/null; mutagen sync create ~/UltraSignup/analysis root@${input:runpodHost}:${input:runpodPort}:/workspace/analysis --name=ultrasignup --ignore='.venv,__pycache__,*.pyc,.ipynb_checkpoints' && echo '✅ Mutagen sync started'",
      "group": "none"
    },
    {
      "label": "RunPod: Stop Mutagen Sync",
      "type": "shell",
      "command": "mutagen sync terminate ultrasignup && echo '✅ Mutagen sync stopped'",
      "group": "none"
    },
    {
      "label": "RunPod: Mutagen Status",
      "type": "shell",
      "command": "mutagen sync list",
      "group": "none"
    },
    {
      "label": "RunPod: Pull Models",
      "type": "shell",
      "command": "rsync -avz --progress -e 'ssh -p ${input:runpodPort} -i ~/.ssh/id_ed25519' root@${input:runpodHost}:/workspace/analysis/data/cache/ ~/UltraSignup/analysis/data/cache/",
      "group": "none"
    },
    {
      "label": "RunPod: Full Setup (Data + Sync)",
      "dependsOn": ["RunPod: Push Data", "RunPod: Start Mutagen Sync"],
      "dependsOrder": "sequence",
      "group": "none"
    },
    {
      "label": "RunPod: SSH Connect",
      "type": "shell",
      "command": "ssh -p ${input:runpodPort} -i ~/.ssh/id_ed25519 root@${input:runpodHost}",
      "group": "none",
      "presentation": {
        "focus": true
      }
    }
  ]
}
```
The key tasks:
- Push Data: One-time rsync of large data files (databases, datasets) to the container. These don't need continuous sync.
- Start Mutagen Sync: Bidirectional sync for code and notebooks. Edit locally, changes appear on the container instantly.
- Pull Models: After sampling completes, pull the trained model artifacts (NetCDF traces, pickled inference data) back to your local machine.
- Full Setup: Chains Push Data → Start Sync for one-click pod initialization.
- SSH Connect: Drop into the container for debugging or running commands directly.
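The "Pull Models" task only works if artifacts actually land in the cache directory it rsyncs. Saving and reloading a trace is a one-liner on each side with ArviZ; a sketch, using a stand-in trace and a hypothetical filename (in practice you'd write to `/workspace/analysis/data/cache/` on the pod):

```python
import numpy as np
import arviz as az

# Stand-in for a real trace; in practice this is the InferenceData
# object returned by pm.sample(). Shape convention: (chain, draw).
idata = az.from_dict(posterior={"mu": np.zeros((2, 100))})

# On the pod: write to the cache directory that "Pull Models" rsyncs.
idata.to_netcdf("model_v1.nc")

# Locally, after pulling: reload and keep analyzing.
idata_local = az.from_netcdf("model_v1.nc")
print(idata_local.posterior["mu"].shape)  # (2, 100)
```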
The inputs section prompts for the RunPod host and port, which change each time you spin up a new pod. RunPod provides these in the pod details.
Putting It Together
Here's the workflow in practice:
- Spin up a RunPod pod using your custom template (pointing to your Docker image)
- Run "Full Setup" task in VS Code—pushes data and starts file sync
- Edit notebooks locally—Mutagen syncs changes to the container in real-time
- Execute cells—sampling runs on the GPU, 10-100x faster than your laptop
- Run "Pull Models" when done—brings trained artifacts back to local
- Terminate the pod—stop paying for GPU time
A typical 2-hour modeling session costs about $0.80 on an RTX 4090. Compare that to waiting 20+ hours on a MacBook, and the economics are obvious.
What's Next
I'm using this setup to build hierarchical Bayesian models for ultramarathon race analysis—estimating runner abilities, course difficulties, and DNF probabilities from historical race data. The models have thousands of parameters and would be impractical without GPU acceleration.
The full code is available at github.com/justmytwospence/ultrasignup-analysis. Feel free to adapt it for your own projects, and reach out if you have questions!