{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "(spline)=\n", "# Splines\n", "\n", ":::{post} June 4, 2022 \n", ":tags: patsy, regression, spline \n", ":category: beginner\n", ":author: Joshua Cook\n", ":::" ] }, { "cell_type": "markdown", "metadata": { "tags": [] }, "source": [ "## Introduction\n", "\n", "Often, the model we want to fit is not a perfect line between some $x$ and $y$.\n", "Instead, the parameters of the model are expected to vary over $x$.\n", "There are multiple ways to handle this situation, one of which is to fit a *spline*.\n", "A spline is effectively a sum of multiple individual curves (piecewise polynomials), each fit to a different section of $x$ and tied together at their boundaries, which are often called *knots*.\n", "\n", "Below is a full working example of how to fit a spline using PyMC. The data and model are taken from [*Statistical Rethinking* 2e](https://xcelab.net/rm/statistical-rethinking/) by [Richard McElreath](https://xcelab.net/rm/) {cite:p}`mcelreath2018statistical`.\n", "\n", "For more information on this method of non-linear modeling, I suggest beginning with [chapter 5 of Bayesian Modeling and Computation in Python](https://bayesiancomputationbook.com/markdown/chp_05.html) {cite:p}`martin2021bayesian`."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pymc as pm\n", "\n", "from patsy import build_design_matrices, dmatrix" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%config InlineBackend.figure_format = \"retina\"\n", "\n", "seed = sum(map(ord, \"splines\"))\n", "rng = np.random.default_rng(seed)\n", "az.style.use(\"arviz-darkgrid\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cherry blossom data\n", "\n", "The data for this example is the day of the year (`doy`) on which the cherry trees first bloomed in each year (`year`).\n", "For convenience, years missing a `doy` value were dropped (in general, simply dropping incomplete records is a poor way to deal with missing data!)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

|  | year | doy | temp | temp_upper | temp_lower |
|---|---|---|---|---|---|
| count | 787.000000 | 787.00000 | 787.000000 | 787.000000 | 787.000000 |
| mean | 1533.395172 | 104.92122 | 6.100356 | 6.937560 | 5.263545 |
| std | 291.122597 | 6.25773 | 0.683410 | 0.811986 | 0.762194 |
| min | 851.000000 | 86.00000 | 4.690000 | 5.450000 | 2.610000 |
| 25% | 1318.000000 | 101.00000 | 5.625000 | 6.380000 | 4.770000 |
| 50% | 1563.000000 | 105.00000 | 6.060000 | 6.800000 | 5.250000 |
| 75% | 1778.500000 | 109.00000 | 6.460000 | 7.375000 | 5.650000 |
| max | 1980.000000 | 124.00000 | 8.300000 | 12.100000 | 7.740000 |


|  | year | doy | temp | temp_upper | temp_lower |
|---|---|---|---|---|---|
| 0 | 812 | 92.0 | NaN | NaN | NaN |
| 1 | 815 | 105.0 | NaN | NaN | NaN |
| 2 | 831 | 96.0 | NaN | NaN | NaN |
| 3 | 851 | 108.0 | 7.38 | 12.10 | 2.66 |
| 4 | 853 | 104.0 | NaN | NaN | NaN |
| 5 | 864 | 100.0 | 6.42 | 8.69 | 4.14 |
| 6 | 866 | 106.0 | 6.44 | 8.11 | 4.77 |
| 7 | 869 | 95.0 | NaN | NaN | NaN |
| 8 | 889 | 104.0 | 6.83 | 8.48 | 5.19 |
| 9 | 891 | 109.0 | 6.98 | 8.96 | 5.00 |

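Several early rows in the raw table have a `doy` but no `temp`, so it matters which rows are dropped. Below is a minimal pandas sketch that uses the first five rows of that table. It contrasts dropping only the rows whose `doy` is missing with dropping every row that has any missing value. The small frame and its variable names are for illustration only and are not taken from the notebook's own code.

```python
import numpy as np
import pandas as pd

# Illustrative frame: the first five raw rows (temp_upper/temp_lower omitted)
raw = pd.DataFrame(
    {
        "year": [812, 815, 831, 851, 853],
        "doy": [92.0, 105.0, 96.0, 108.0, 104.0],
        "temp": [np.nan, np.nan, np.nan, 7.38, np.nan],
    }
)

# Drop only rows whose outcome (doy) is missing: every row is kept here
doy_only = raw.dropna(subset=["doy"])

# Drop rows with a missing value in any column: only the 851 row survives
complete_cases = raw.dropna()

print(len(doy_only), len(complete_cases))  # 5 1
print(complete_cases["year"].tolist())  # [851]
```

For a model of `doy` against `year` alone, the first call matches what the text describes. The second would also discard usable outcomes because of gaps in an unrelated column.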

|  | mean | sd | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| a | 103.743 | 0.733 | 102.370 | 105.147 | 0.021 | 0.015 | 1276.0 | 2009.0 | 1.0 |
| w[0] | -1.827 | 2.251 | -6.169 | 2.267 | 0.029 | 0.028 | 5907.0 | 3251.0 | 1.0 |
| w[1] | -1.397 | 2.129 | -5.272 | 2.697 | 0.034 | 0.027 | 3952.0 | 3177.0 | 1.0 |
| w[2] | -1.294 | 1.953 | -5.002 | 2.329 | 0.036 | 0.026 | 2938.0 | 2630.0 | 1.0 |
| w[3] | 3.610 | 1.492 | 0.879 | 6.559 | 0.028 | 0.020 | 2866.0 | 3473.0 | 1.0 |
| w[4] | 0.037 | 1.515 | -2.648 | 3.062 | 0.029 | 0.020 | 2736.0 | 3097.0 | 1.0 |
| w[5] | 2.685 | 1.673 | -0.436 | 5.924 | 0.031 | 0.022 | 2929.0 | 3015.0 | 1.0 |
| w[6] | -1.479 | 1.596 | -4.494 | 1.409 | 0.032 | 0.022 | 2569.0 | 3082.0 | 1.0 |
| w[7] | -1.578 | 1.505 | -4.573 | 1.134 | 0.029 | 0.022 | 2660.0 | 2458.0 | 1.0 |
| w[8] | 5.536 | 1.545 | 2.757 | 8.435 | 0.030 | 0.021 | 2690.0 | 2868.0 | 1.0 |
| w[9] | -0.126 | 1.565 | -2.918 | 2.900 | 0.030 | 0.022 | 2637.0 | 2754.0 | 1.0 |
| w[10] | 1.120 | 1.614 | -1.962 | 4.012 | 0.030 | 0.021 | 2910.0 | 2822.0 | 1.0 |
| w[11] | 4.663 | 1.542 | 1.739 | 7.449 | 0.029 | 0.020 | 2844.0 | 3233.0 | 1.0 |
| w[12] | 0.147 | 1.587 | -2.901 | 3.083 | 0.033 | 0.023 | 2385.0 | 3011.0 | 1.0 |
| w[13] | 2.820 | 1.555 | -0.152 | 5.696 | 0.028 | 0.020 | 3130.0 | 3304.0 | 1.0 |
| w[14] | 2.839 | 1.577 | -0.085 | 5.838 | 0.030 | 0.021 | 2800.0 | 2889.0 | 1.0 |
| w[15] | 0.509 | 1.642 | -2.515 | 3.596 | 0.033 | 0.024 | 2445.0 | 2960.0 | 1.0 |
| w[16] | -2.807 | 1.859 | -6.355 | 0.548 | 0.033 | 0.024 | 3207.0 | 3229.0 | 1.0 |
| w[17] | -6.127 | 1.951 | -9.550 | -2.197 | 0.036 | 0.025 | 2959.0 | 2921.0 | 1.0 |
| w[18] | -6.111 | 1.911 | -9.586 | -2.420 | 0.030 | 0.022 | 4173.0 | 2946.0 | 1.0 |
| sigma | 5.951 | 0.148 | 5.664 | 6.224 | 0.002 | 0.001 | 6452.0 | 2891.0 | 1.0 |

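The posterior summary lists an intercept `a`, nineteen spline weights `w[0]` to `w[18]`, and a noise scale `sigma`. The introduction describes a spline as a sum of piecewise polynomials joined at knots. The minimal numpy sketch below makes that concrete: it evaluates a cubic B-spline basis with the Cox-de Boor recursion and forms the mean as $\mu_i = a + \sum_k w_k B_k(\text{year}_i)$.

Several assumptions here do not come from this section:

- the helper `bspline_basis` is hypothetical;
- boundary knots sit at the data minimum and maximum;
- the 15 interior knots are placed at evenly spaced quantiles;
- the years, intercept, and weights are synthetic.

With 15 interior knots and degree 3, the basis has 15 + 3 + 1 = 19 columns, the same count as the weights in the summary. In practice, patsy's `dmatrix` with a `bs()` term (imported at the top of this notebook) is the more convenient way to build such a matrix. The hand-rolled version only makes the recursion visible.

```python
import numpy as np


def bspline_basis(x, interior_knots, degree=3):
    """Hypothetical helper: evaluate a clamped B-spline basis at x.

    Boundary knots are placed at min(x) and max(x) and repeated degree + 1
    times. Returns an array of shape (len(x), len(interior_knots) + degree + 1).
    """
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    t = np.concatenate([[lo] * (degree + 1), interior_knots, [hi] * (degree + 1)])

    # Degree 0: indicator of each half-open knot span [t[i], t[i + 1])
    B = np.zeros((x.size, t.size - 1))
    for i in range(t.size - 1):
        B[:, i] = (t[i] <= x) & (x < t[i + 1])
    # The right endpoint belongs to the last non-empty span
    last = np.nonzero(t[:-1] < t[1:])[0][-1]
    B[x == hi, last] = 1.0

    # Cox-de Boor recursion: raise the degree one step at a time,
    # skipping terms whose denominator is zero (repeated knots)
    for d in range(1, degree + 1):
        B_next = np.zeros((x.size, t.size - 1 - d))
        for i in range(t.size - 1 - d):
            left = t[i + d] - t[i]
            right = t[i + d + 1] - t[i + 1]
            if left > 0:
                B_next[:, i] += (x - t[i]) / left * B[:, i]
            if right > 0:
                B_next[:, i] += (t[i + d + 1] - x) / right * B[:, i + 1]
        B = B_next
    return B


# Synthetic years and 15 interior knots at evenly spaced quantiles (assumptions)
demo_rng = np.random.default_rng(0)
year = np.sort(demo_rng.uniform(850, 1980, size=200))
num_knots = 15
knots = np.quantile(year, np.linspace(0, 1, num_knots + 2)[1:-1])
B = bspline_basis(year, knots, degree=3)
print(B.shape)  # (200, 19)

# Spline mean: intercept plus a weighted sum of basis columns.
# The intercept and weights are made up, not the posterior values above.
a = 104.0
w = demo_rng.normal(0, 3, size=B.shape[1])
mu = a + B @ w
```

Each row of `B` sums to one, and at most four neighboring columns are non-zero at any year. As a result, each weight `w[k]` mainly shifts the curve within one local stretch of years rather than across the whole range.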