Skip to content

Instantly share code, notes, and snippets.

@akelleh
Created February 10, 2019 14:33
Show Gist options
  • Save akelleh/7a31184ce88453599188c568b09e062d to your computer and use it in GitHub Desktop.
Save akelleh/7a31184ce88453599188c568b09e062d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from dowhy.do_samplers.kernel_density_sampler import KernelDensitySampler\n",
"from dowhy.do_why import CausalModel\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from statsmodels.api import OLS"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"import dowhy.datasets\n",
"\n",
"# Simulated dataset: 100 samples, one common cause, no instruments, binary treatment.\n",
"# NOTE(review): beta is presumably the true treatment effect -- confirm against the\n",
"# dowhy.datasets.linear_dataset documentation.\n",
"data = dowhy.datasets.linear_dataset(\n",
"    beta=5,\n",
"    num_common_causes=1,\n",
"    num_instruments=0,\n",
"    num_samples=100,\n",
"    treatment_is_binary=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>X0</th>\n",
" <th>v</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>-1.907772</td>\n",
" <td>0.0</td>\n",
" <td>-2.230260</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.707054</td>\n",
" <td>1.0</td>\n",
" <td>6.059161</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>-1.785862</td>\n",
" <td>0.0</td>\n",
" <td>-2.224435</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>-0.168836</td>\n",
" <td>1.0</td>\n",
" <td>4.101302</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>-0.600276</td>\n",
" <td>1.0</td>\n",
" <td>4.341337</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>0.125460</td>\n",
" <td>1.0</td>\n",
" <td>4.093887</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>-1.866742</td>\n",
" <td>0.0</td>\n",
" <td>-2.086384</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>-1.449010</td>\n",
" <td>0.0</td>\n",
" <td>-1.940293</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>-1.205083</td>\n",
" <td>0.0</td>\n",
" <td>-1.082687</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>-1.771017</td>\n",
" <td>0.0</td>\n",
" <td>-0.743547</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>-0.310801</td>\n",
" <td>0.0</td>\n",
" <td>-0.246174</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>-1.044497</td>\n",
" <td>1.0</td>\n",
" <td>4.259667</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>-2.090359</td>\n",
" <td>0.0</td>\n",
" <td>-3.496719</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>-3.464527</td>\n",
" <td>0.0</td>\n",
" <td>-4.716008</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>0.110692</td>\n",
" <td>0.0</td>\n",
" <td>-0.560330</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>-0.431212</td>\n",
" <td>1.0</td>\n",
" <td>5.258859</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>-0.307991</td>\n",
" <td>0.0</td>\n",
" <td>-1.340262</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>-0.382713</td>\n",
" <td>1.0</td>\n",
" <td>4.597534</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>-0.382750</td>\n",
" <td>0.0</td>\n",
" <td>0.409655</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>0.632260</td>\n",
" <td>0.0</td>\n",
" <td>1.115870</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>-2.797024</td>\n",
" <td>0.0</td>\n",
" <td>-3.380601</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>-0.812726</td>\n",
" <td>1.0</td>\n",
" <td>4.042101</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>0.410965</td>\n",
" <td>1.0</td>\n",
" <td>3.205867</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>-0.998865</td>\n",
" <td>0.0</td>\n",
" <td>-0.794360</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>-0.867534</td>\n",
" <td>0.0</td>\n",
" <td>0.449066</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>-0.495761</td>\n",
" <td>0.0</td>\n",
" <td>0.839552</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>-1.456445</td>\n",
" <td>1.0</td>\n",
" <td>4.015004</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>-2.027530</td>\n",
" <td>0.0</td>\n",
" <td>-2.735320</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>-0.471675</td>\n",
" <td>1.0</td>\n",
" <td>5.787113</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>-0.052103</td>\n",
" <td>0.0</td>\n",
" <td>-0.862957</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>70</th>\n",
" <td>-1.186475</td>\n",
" <td>1.0</td>\n",
" <td>4.807903</td>\n",
" </tr>\n",
" <tr>\n",
" <th>71</th>\n",
" <td>-0.732498</td>\n",
" <td>0.0</td>\n",
" <td>-1.351341</td>\n",
" </tr>\n",
" <tr>\n",
" <th>72</th>\n",
" <td>0.944242</td>\n",
" <td>1.0</td>\n",
" <td>5.756376</td>\n",
" </tr>\n",
" <tr>\n",
" <th>73</th>\n",
" <td>-1.601822</td>\n",
" <td>0.0</td>\n",
" <td>-0.210960</td>\n",
" </tr>\n",
" <tr>\n",
" <th>74</th>\n",
" <td>-1.492551</td>\n",
" <td>0.0</td>\n",
" <td>-2.423871</td>\n",
" </tr>\n",
" <tr>\n",
" <th>75</th>\n",
" <td>-0.812373</td>\n",
" <td>0.0</td>\n",
" <td>-0.281548</td>\n",
" </tr>\n",
" <tr>\n",
" <th>76</th>\n",
" <td>-0.722041</td>\n",
" <td>1.0</td>\n",
" <td>6.346155</td>\n",
" </tr>\n",
" <tr>\n",
" <th>77</th>\n",
" <td>-0.588892</td>\n",
" <td>0.0</td>\n",
" <td>-1.297204</td>\n",
" </tr>\n",
" <tr>\n",
" <th>78</th>\n",
" <td>-0.945065</td>\n",
" <td>0.0</td>\n",
" <td>0.371523</td>\n",
" </tr>\n",
" <tr>\n",
" <th>79</th>\n",
" <td>-0.702234</td>\n",
" <td>1.0</td>\n",
" <td>3.913889</td>\n",
" </tr>\n",
" <tr>\n",
" <th>80</th>\n",
" <td>-2.040678</td>\n",
" <td>0.0</td>\n",
" <td>-1.089276</td>\n",
" </tr>\n",
" <tr>\n",
" <th>81</th>\n",
" <td>-2.382935</td>\n",
" <td>0.0</td>\n",
" <td>-2.293790</td>\n",
" </tr>\n",
" <tr>\n",
" <th>82</th>\n",
" <td>-1.898813</td>\n",
" <td>0.0</td>\n",
" <td>-2.077551</td>\n",
" </tr>\n",
" <tr>\n",
" <th>83</th>\n",
" <td>-0.443291</td>\n",
" <td>0.0</td>\n",
" <td>-1.712443</td>\n",
" </tr>\n",
" <tr>\n",
" <th>84</th>\n",
" <td>-0.037311</td>\n",
" <td>0.0</td>\n",
" <td>-0.742133</td>\n",
" </tr>\n",
" <tr>\n",
" <th>85</th>\n",
" <td>-0.054639</td>\n",
" <td>1.0</td>\n",
" <td>4.450153</td>\n",
" </tr>\n",
" <tr>\n",
" <th>86</th>\n",
" <td>-1.649032</td>\n",
" <td>0.0</td>\n",
" <td>-1.563796</td>\n",
" </tr>\n",
" <tr>\n",
" <th>87</th>\n",
" <td>0.085290</td>\n",
" <td>0.0</td>\n",
" <td>-0.118241</td>\n",
" </tr>\n",
" <tr>\n",
" <th>88</th>\n",
" <td>-1.603933</td>\n",
" <td>1.0</td>\n",
" <td>3.221698</td>\n",
" </tr>\n",
" <tr>\n",
" <th>89</th>\n",
" <td>-1.178101</td>\n",
" <td>0.0</td>\n",
" <td>-0.355170</td>\n",
" </tr>\n",
" <tr>\n",
" <th>90</th>\n",
" <td>-0.389014</td>\n",
" <td>1.0</td>\n",
" <td>4.457656</td>\n",
" </tr>\n",
" <tr>\n",
" <th>91</th>\n",
" <td>-2.252745</td>\n",
" <td>0.0</td>\n",
" <td>-2.413238</td>\n",
" </tr>\n",
" <tr>\n",
" <th>92</th>\n",
" <td>-1.958699</td>\n",
" <td>0.0</td>\n",
" <td>-3.063983</td>\n",
" </tr>\n",
" <tr>\n",
" <th>93</th>\n",
" <td>-2.165287</td>\n",
" <td>0.0</td>\n",
" <td>-1.934050</td>\n",
" </tr>\n",
" <tr>\n",
" <th>94</th>\n",
" <td>-2.802529</td>\n",
" <td>0.0</td>\n",
" <td>-2.556217</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>-0.049369</td>\n",
" <td>0.0</td>\n",
" <td>-0.277302</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>-0.836338</td>\n",
" <td>0.0</td>\n",
" <td>-2.384613</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>-0.938599</td>\n",
" <td>0.0</td>\n",
" <td>-1.528677</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>-2.214974</td>\n",
" <td>0.0</td>\n",
" <td>-2.150973</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>-1.698840</td>\n",
" <td>1.0</td>\n",
" <td>3.217780</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 3 columns</p>\n",
"</div>"
],
"text/plain": [
" X0 v y\n",
"0 -1.907772 0.0 -2.230260\n",
"1 0.707054 1.0 6.059161\n",
"2 -1.785862 0.0 -2.224435\n",
"3 -0.168836 1.0 4.101302\n",
"4 -0.600276 1.0 4.341337\n",
"5 0.125460 1.0 4.093887\n",
"6 -1.866742 0.0 -2.086384\n",
"7 -1.449010 0.0 -1.940293\n",
"8 -1.205083 0.0 -1.082687\n",
"9 -1.771017 0.0 -0.743547\n",
"10 -0.310801 0.0 -0.246174\n",
"11 -1.044497 1.0 4.259667\n",
"12 -2.090359 0.0 -3.496719\n",
"13 -3.464527 0.0 -4.716008\n",
"14 0.110692 0.0 -0.560330\n",
"15 -0.431212 1.0 5.258859\n",
"16 -0.307991 0.0 -1.340262\n",
"17 -0.382713 1.0 4.597534\n",
"18 -0.382750 0.0 0.409655\n",
"19 0.632260 0.0 1.115870\n",
"20 -2.797024 0.0 -3.380601\n",
"21 -0.812726 1.0 4.042101\n",
"22 0.410965 1.0 3.205867\n",
"23 -0.998865 0.0 -0.794360\n",
"24 -0.867534 0.0 0.449066\n",
"25 -0.495761 0.0 0.839552\n",
"26 -1.456445 1.0 4.015004\n",
"27 -2.027530 0.0 -2.735320\n",
"28 -0.471675 1.0 5.787113\n",
"29 -0.052103 0.0 -0.862957\n",
".. ... ... ...\n",
"70 -1.186475 1.0 4.807903\n",
"71 -0.732498 0.0 -1.351341\n",
"72 0.944242 1.0 5.756376\n",
"73 -1.601822 0.0 -0.210960\n",
"74 -1.492551 0.0 -2.423871\n",
"75 -0.812373 0.0 -0.281548\n",
"76 -0.722041 1.0 6.346155\n",
"77 -0.588892 0.0 -1.297204\n",
"78 -0.945065 0.0 0.371523\n",
"79 -0.702234 1.0 3.913889\n",
"80 -2.040678 0.0 -1.089276\n",
"81 -2.382935 0.0 -2.293790\n",
"82 -1.898813 0.0 -2.077551\n",
"83 -0.443291 0.0 -1.712443\n",
"84 -0.037311 0.0 -0.742133\n",
"85 -0.054639 1.0 4.450153\n",
"86 -1.649032 0.0 -1.563796\n",
"87 0.085290 0.0 -0.118241\n",
"88 -1.603933 1.0 3.221698\n",
"89 -1.178101 0.0 -0.355170\n",
"90 -0.389014 1.0 4.457656\n",
"91 -2.252745 0.0 -2.413238\n",
"92 -1.958699 0.0 -3.063983\n",
"93 -2.165287 0.0 -1.934050\n",
"94 -2.802529 0.0 -2.556217\n",
"95 -0.049369 0.0 -0.277302\n",
"96 -0.836338 0.0 -2.384613\n",
"97 -0.938599 0.0 -1.528677\n",
"98 -2.214974 0.0 -2.150973\n",
"99 -1.698840 1.0 3.217780\n",
"\n",
"[100 rows x 3 columns]"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data['dot_graph'] = 'digraph { v ->y;X0-> v;X0-> y;}'\n",
"df = data['df']\n",
"df['y'] = df['y'] + np.random.normal(size=len(df))\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<table class=\"simpletable\">\n",
"<caption>OLS Regression Results</caption>\n",
"<tr>\n",
" <th>Dep. Variable:</th> <td>y</td> <th> R-squared: </th> <td> 0.927</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Model:</th> <td>OLS</td> <th> Adj. R-squared: </th> <td> 0.925</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Method:</th> <td>Least Squares</td> <th> F-statistic: </th> <td> 622.1</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Date:</th> <td>Sun, 10 Feb 2019</td> <th> Prob (F-statistic):</th> <td>2.02e-56</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Time:</th> <td>09:18:58</td> <th> Log-Likelihood: </th> <td> -124.35</td>\n",
"</tr>\n",
"<tr>\n",
" <th>No. Observations:</th> <td> 100</td> <th> AIC: </th> <td> 252.7</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Df Residuals:</th> <td> 98</td> <th> BIC: </th> <td> 257.9</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Df Model:</th> <td> 2</td> <th> </th> <td> </td> \n",
"</tr>\n",
"<tr>\n",
" <th>Covariance Type:</th> <td>nonrobust</td> <th> </th> <td> </td> \n",
"</tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr>\n",
" <td></td> <th>coef</th> <th>std err</th> <th>t</th> <th>P>|t|</th> <th>[0.025</th> <th>0.975]</th> \n",
"</tr>\n",
"<tr>\n",
" <th>X0</th> <td> 1.0777</td> <td> 0.064</td> <td> 16.781</td> <td> 0.000</td> <td> 0.950</td> <td> 1.205</td>\n",
"</tr>\n",
"<tr>\n",
" <th>v</th> <td> 5.1443</td> <td> 0.153</td> <td> 33.686</td> <td> 0.000</td> <td> 4.841</td> <td> 5.447</td>\n",
"</tr>\n",
"</table>\n",
"<table class=\"simpletable\">\n",
"<tr>\n",
" <th>Omnibus:</th> <td> 0.175</td> <th> Durbin-Watson: </th> <td> 1.970</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Prob(Omnibus):</th> <td> 0.916</td> <th> Jarque-Bera (JB): </th> <td> 0.169</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Skew:</th> <td> 0.092</td> <th> Prob(JB): </th> <td> 0.919</td>\n",
"</tr>\n",
"<tr>\n",
" <th>Kurtosis:</th> <td> 2.917</td> <th> Cond. No. </th> <td> 2.44</td>\n",
"</tr>\n",
"</table><br/><br/>Warnings:<br/>[1] Standard Errors assume that the covariance matrix of the errors is correctly specified."
],
"text/plain": [
"<class 'statsmodels.iolib.summary.Summary'>\n",
"\"\"\"\n",
" OLS Regression Results \n",
"==============================================================================\n",
"Dep. Variable: y R-squared: 0.927\n",
"Model: OLS Adj. R-squared: 0.925\n",
"Method: Least Squares F-statistic: 622.1\n",
"Date: Sun, 10 Feb 2019 Prob (F-statistic): 2.02e-56\n",
"Time: 09:18:58 Log-Likelihood: -124.35\n",
"No. Observations: 100 AIC: 252.7\n",
"Df Residuals: 98 BIC: 257.9\n",
"Df Model: 2 \n",
"Covariance Type: nonrobust \n",
"==============================================================================\n",
" coef std err t P>|t| [0.025 0.975]\n",
"------------------------------------------------------------------------------\n",
"X0 1.0777 0.064 16.781 0.000 0.950 1.205\n",
"v 5.1443 0.153 33.686 0.000 4.841 5.447\n",
"==============================================================================\n",
"Omnibus: 0.175 Durbin-Watson: 1.970\n",
"Prob(Omnibus): 0.916 Jarque-Bera (JB): 0.169\n",
"Skew: 0.092 Prob(JB): 0.919\n",
"Kurtosis: 2.917 Cond. No. 2.44\n",
"==============================================================================\n",
"\n",
"Warnings:\n",
"[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n",
"\"\"\""
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = OLS(df['y'], df[['X0', 'v']])\n",
"result = model.fit()\n",
"result.summary()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'digraph { v ->y;X0-> v;X0-> y;}'"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data['dot_graph']"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Error: Pygraphviz cannot be loaded. No module named 'pygraphviz'\n",
"Trying pydot ...\n",
"['X0']\n",
"yes\n",
"{'observed': 'yes'}\n",
"Model to find the causal effect of treatment v on outcome y\n"
]
}
],
"source": [
"causal_model= CausalModel(\n",
" data=data[\"df\"],\n",
" treatment=data[\"treatment_name\"],\n",
" outcome=data[\"outcome_name\"],\n",
" graph=data[\"dot_graph\"])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'observed': 'yes'}\n",
"{'label': 'Unobserved Confounders', 'observed': 'no'}\n",
"There are unobserved common causes. Causal effect cannot be identified.\n",
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] yes\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n"
]
}
],
"source": [
"identified_estimand = causal_model.identify_effect()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"import pymc3 as pm\n",
"import networkx as nx"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pymc3:Auto-assigning NUTS sampler...\n",
"INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n",
"INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n",
"INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n",
"Sampling 4 chains: 100%|██████████| 6000/6000 [00:02<00:00, 2316.81draws/s]\n"
]
}
],
"source": [
"g = causal_model._graph.get_unconfounded_observed_subgraph()\n",
"df = causal_model._data\n",
"\n",
"data_types = {'X0': 'c', 'v': 'b', 'y': 'c'}\n",
"\n",
"\n",
"class CausalBayesianNetwork(object):\n",
" def __init__(self, g, df, data_types):\n",
" self.g = g\n",
" self.df = df\n",
" self.data_types = data_types\n",
" g_fit = nx.DiGraph(self.g)\n",
" _, self.fit_trace = self.fit_causal_model(g_fit, \n",
" self.df, \n",
" self.data_types)\n",
" \n",
" def apply_data_types(self, g, data_types):\n",
" for node in nx.topological_sort(g):\n",
" g.nodes()[node][\"variable_type\"] = data_types[node]\n",
" return g\n",
"\n",
" def apply_parents(self, g):\n",
" for node in nx.topological_sort(g):\n",
" if not g.nodes()[node].get(\"parent_names\"):\n",
" g.nodes()[node][\"parent_names\"] = [parent for parent, _ in g.in_edges(node)]\n",
" return g\n",
"\n",
" def apply_parameters(self, g, df, initialization_trace=None):\n",
" for node in nx.topological_sort(g):\n",
" parent_names = g.nodes()[node][\"parent_names\"]\n",
" if parent_names:\n",
" if not initialization_trace:\n",
" sd = np.array([df[node].std()] + (df[node].std() / df[parent_names].std() ).tolist())\n",
" mu = np.array([df[node].std()] + (df[node].std() / df[parent_names].std() ).tolist())\n",
" node_sd = df[node].std()\n",
" else:\n",
" node_sd = initialization_trace[\"{}_sd\".format(node)].mean()\n",
" mu = initialization_trace[\"beta_{}\".format(node)].mean(axis=0)\n",
" sd = initialization_trace[\"beta_{}\".format(node)].std(axis=0)\n",
" g.nodes()[node][\"parameters\"] = pm.Normal(\"beta_{}\".format(node), mu=mu, sd=sd, shape=len(parent_names)+1)\n",
" g.nodes()[node][\"sd\"] = pm.Exponential(\"{}_sd\".format(node), lam=node_sd)\n",
" return g\n",
"\n",
" def build_bayesian_network(self, g, df):\n",
" for node in nx.topological_sort(g):\n",
" if g.nodes()[node][\"parent_names\"]:\n",
" mu = g.nodes()[node][\"parameters\"][0] # intercept\n",
" mu += pm.math.dot(df[g.nodes()[node][\"parent_names\"]], \n",
" g.nodes()[node][\"parameters\"][1:])\n",
" if g.nodes()[node][\"variable_type\"] == 'c':\n",
" sd = g.nodes()[node][\"sd\"]\n",
" g.nodes()[node][\"variable\"] = pm.Normal(\"{}\".format(node), \n",
" mu=mu, sd=sd, \n",
" observed=df[node])\n",
" elif g.nodes()[node][\"variable_type\"] == 'b':\n",
" g.nodes()[node][\"variable\"] = pm.Bernoulli(\"{}\".format(node), \n",
" logit_p=mu, \n",
" observed=df[node])\n",
" else:\n",
" raise Exception(\"Unrecognized variable type: {}\".format(g.nodes()[node][\"variable_type\"]))\n",
" return g\n",
"\n",
" def fit_causal_model(self, g, df, data_types, initialization_trace=None):\n",
" if nx.is_directed_acyclic_graph(g):\n",
" with pm.Model() as model: \n",
" g = self.apply_data_types(g, data_types)\n",
" g = self.apply_parents(g)\n",
" g = self.apply_parameters(g, df, initialization_trace=initialization_trace)\n",
" g = self.build_bayesian_network(g, df)\n",
" trace = pm.sample(1000, tune=500)\n",
" else:\n",
" raise Exception(\"Graph is not a DAG!\")\n",
" return g, trace\n",
"\n",
" def sample_prior_causal_model(self, g, df, data_types, initialization_trace):\n",
" if nx.is_directed_acyclic_graph(g):\n",
" with pm.Model() as model: \n",
" g = self.apply_data_types(g, data_types)\n",
" g = self.apply_parents(g)\n",
" g = self.apply_parameters(g, df, initialization_trace=initialization_trace)\n",
" g = self.build_bayesian_network(g, df)\n",
" trace = pm.sample_prior_predictive(1)\n",
" else:\n",
" raise Exception(\"Graph is not a DAG!\")\n",
" return g, trace\n",
"\n",
" def do_x_surgery(self, g, x):\n",
" for xi in x.keys():\n",
" g.remove_edges_from([(parent, child) for (parent, child) in g.in_edges(xi)])\n",
" g.nodes()[xi][\"parent_names\"] = []\n",
" return g\n",
"\n",
" def make_intervention_effective(self, df, x):\n",
" df_intervened = df.copy()\n",
" for k, v in x.items():\n",
" df_intervened[k] = v\n",
" return df_intervened\n",
"\n",
" def do(self, x):\n",
" g_for_surgery = nx.DiGraph(self.g)\n",
" g_modified = self.do_x_surgery(g_for_surgery, x)\n",
" df_intervened = self.make_intervention_effective(self.df, x)\n",
" g_modified, trace = self.sample_prior_causal_model(g_modified, \n",
" df_intervened, \n",
" self.data_types, \n",
" initialization_trace=self.fit_trace)\n",
" return trace\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pymc3:Auto-assigning NUTS sampler...\n",
"INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n",
"INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n",
"INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n",
"Sampling 4 chains: 100%|██████████| 6000/6000 [00:02<00:00, 2280.31draws/s]\n",
"WARNING:pymc3:The acceptance probability does not match the target. It is 0.8827395991724155, but should be close to 0.8. Try to increase the number of tuning steps.\n"
]
}
],
"source": [
"causal_bayes_net = CausalBayesianNetwork(g, df, data_types)\n"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# Sample the fitted network under do(v=1) and do(v=0).\n",
"# Despite the df_ prefix, these hold whatever CausalBayesianNetwork.do returns.\n",
"x = dict(v=1.)\n",
"df_1 = causal_bayes_net.do(x)\n",
"x = dict(v=0.)\n",
"df_0 = causal_bayes_net.do(x)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAKkAAAAPBAMAAABtvvLvAAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAzWYQMplU74mrdiK7RN1/7zyFAAAACXBIWXMAAA7EAAAOxAGVKw4bAAACy0lEQVQ4Ea2UzWsTQRjGn2Rjsvk0UBQ8NSpo1UNXPNj20ObiwYsJSqQqpYtaiieVKlLqxypW2lJwLUY9mSCIWg8NinizAU9i0foPWL0I4qGtYr+wrs87s2nzBziHzOwzv/z23ZmdRWB76CjYjMJUFsb44DTHoTwOFYeGHMkB89FN6Q4AilUgRpqyftj2tEtGtaYVEc9zJGmFuYoNwD2OUy6aPc9zEe7j1Wa0V9h9BxSrwIBt/GAmYR5RG+i1eCW4VgSGn8s0jgCXcBLYy/ETFyeAKFqHfwHBczDTQKyftQqrwJ3AdYIMwxkYaeNBiVaFa0VUlGxXgVJlFGgH4ltdOACfO0hrgiWze7FF7lMD3wFn6WFoVmF8AZqlVsG1omZttFDKzrzFDS5kwCWSLGtso4vkAuDUrApcBmZsFSb/WGamzqoV0W2Ps9Bt1op6TTawR1kTDOXmHXkkVxCsiNVnZ/f9prWsw9LPWyTXatWKhBVf1dLkEnB73oJRVlYWrayTDkIraAOtPptcihPNlXVoeg31VqUQ4Qf5UTsf7pq9AhNiNeYYSa2TeYSWuM60sgmbcuOsNefosOfz4nRdrUohZHdFflEAdiG2aJ1W1lia0doKhLO+VdgCDLUCKgy6mOirsyoF7vCxeS++JXmA07lTjrKmqr6VuxVeaIFYNSvgMt8BW4URCyFu5tq6isLGedZqibUH8Zdz3JEdxUL3xTIo07Umqiz5TKGweNxnCVamxKPCDmJj61ZZuGiGRwNfRRpyEKsMcEcqPEM05jihVsA/BeiXY0RWgf4pYBgh/37dCq14heAAOtJ4MzT4EJ8s7KYsUuUDOhzEuFv4hv1ZdpcBxSowZRsXmDFMNiBYBiYsXgmuFfHOoiUHv9Hz/iJUlK+L+XE5g9c2z/T9+cNAy/hd/mPUG4NiFWiMPMv64cFOfl2OXduU0bhWcO7/t39S7fhfQwkz1AAAAABJRU5ErkJggg==\n",
"text/latex": [
"$$5.284710452048118$$"
],
"text/plain": [
"5.284710452048118"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(df_1['y'] - df_0['y']).mean()"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALQAAAAPBAMAAAC/7vi3AAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAiXZmMs1UEN0i77urRJlR0qN3AAAACXBIWXMAAA7EAAAOxAGVKw4bAAADTElEQVQ4EbWUz28UZRjHP7M/Ztrd2eksEi803aXAQQ8wNOEgMen25sHQNTX2gnEJHI0OJiSkB3Y5wIUmbOLByMFuYmRjRN0YE6ImOlpjAocyMTF665RgCCi1WrbLluLyvO9bEv4B5vDM933f73zmned53gFrz07Udeavr+G93feUPkJmfqVY5EzxmhpmErDGd+KMj8Uy3LG7YozOymgTHV5ZKhZ9ZsIjNcPTijc43RS/E1GKSXBrMpgnPRgMtpyPuRTKcCjCeZM/8eAXOBV7VWOcC4bb6DAp9ogvB/9ieFqlW1hVedwLybXsAEcGqSt48rrIbeEmsvZiRD7hO44LF74n3zHG67CEDnfAhdd+amJ4SpGPsHvyeK6B27c6OA14/yiWpKAyUsVag+FDEctiZRecptAXZYxdKKGDD5LWSFYMTylGIgobcs/2cO8XHoZWAP5RtfIr9Q4pAVnZiG/VzPJVniPfEqWNzkM48IEK8t5C2aANLxIP9YTCf0pISnvcWH8e0k2N7pCTXT+ClwTdndof4w4O1hh5fYeUWhs/kV3HOshuBbA4cWubpxQLPplt9LKPNZiVNkGhZfuSjew6TjkbOd2A8wLqhtRnGKoZYwneCnRAPgeushAbnlIsJGSkruqah3f+2IzxNbouM19xbEPSLugB3A3texc/p76G1zbGfNku1XTAWdOIfOMJL994KiFeQjriQNuuaPQ+8doTU/c5IWgewGT8MqnNcKRB+pE2wvG9pYoJqapGe/0nCfb6qoy2KiPslzYJyWycRKMv6Em35/gKfU7QQRuma0NV0uvaqAw3QhOGOtJ+qjSap5X0KGnVfNgJH6okLL69urp5E7bUpFDtpdV/vihL907+Lh/tBm5Ldq2Nal0XXIIQVdW9nuZptd3iYpqCF3JN+Eb0FTnc/0urtJmWGXIR0yF3+VQaoZnp4bWM8XCY6aMD04mcOmmyqj4yWiGFn6s4XYbPFv9uFGZJl4X1mXyE1DY168zIiFwH13cu81vIBPzIXM0Y94WnAnTgXV+OVsLhpuYZxcnxn1VzZOUv0ODVFelZdg0WGb4sorgaS7R+2AoYG62QWVK/J2tFHtBGa0zcOnCpJsaPRm9jeFrJzDO6HgMVjziaVfbIdwAAAABJRU5ErkJggg==\n",
"text/latex": [
"$$0.2847946868997505$$"
],
"text/plain": [
"0.2847946868997505"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"1.96* (df_1['y'] - df_0['y']).std() / np.sqrt(len(df))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment