Created
February 10, 2019 14:33
-
-
Save akelleh/7a31184ce88453599188c568b09e062d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"\n", | |
"from dowhy.do_samplers.kernel_density_sampler import KernelDensitySampler\n", | |
"from dowhy.do_why import CausalModel\n", | |
"\n", | |
"import numpy as np\n", | |
"import pandas as pd\n", | |
"\n", | |
"from statsmodels.api import OLS" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import dowhy.datasets\n", | |
"\n", | |
"data = dowhy.datasets.linear_dataset(beta=5,\n", | |
" num_common_causes=1,\n", | |
" num_instruments = 0,\n", | |
" num_samples=100,\n", | |
" treatment_is_binary=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>X0</th>\n", | |
" <th>v</th>\n", | |
" <th>y</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>-1.907772</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.230260</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>0.707054</td>\n", | |
" <td>1.0</td>\n", | |
" <td>6.059161</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>-1.785862</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.224435</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>-0.168836</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.101302</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>-0.600276</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.341337</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>0.125460</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.093887</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>6</th>\n", | |
" <td>-1.866742</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.086384</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>-1.449010</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.940293</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>8</th>\n", | |
" <td>-1.205083</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.082687</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>9</th>\n", | |
" <td>-1.771017</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.743547</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>10</th>\n", | |
" <td>-0.310801</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.246174</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>11</th>\n", | |
" <td>-1.044497</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.259667</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>12</th>\n", | |
" <td>-2.090359</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-3.496719</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>13</th>\n", | |
" <td>-3.464527</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-4.716008</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>14</th>\n", | |
" <td>0.110692</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.560330</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>15</th>\n", | |
" <td>-0.431212</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.258859</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>16</th>\n", | |
" <td>-0.307991</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.340262</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>17</th>\n", | |
" <td>-0.382713</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.597534</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>18</th>\n", | |
" <td>-0.382750</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.409655</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>19</th>\n", | |
" <td>0.632260</td>\n", | |
" <td>0.0</td>\n", | |
" <td>1.115870</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>20</th>\n", | |
" <td>-2.797024</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-3.380601</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>21</th>\n", | |
" <td>-0.812726</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.042101</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>22</th>\n", | |
" <td>0.410965</td>\n", | |
" <td>1.0</td>\n", | |
" <td>3.205867</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>23</th>\n", | |
" <td>-0.998865</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.794360</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>24</th>\n", | |
" <td>-0.867534</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.449066</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>25</th>\n", | |
" <td>-0.495761</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.839552</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>26</th>\n", | |
" <td>-1.456445</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.015004</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>27</th>\n", | |
" <td>-2.027530</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.735320</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>28</th>\n", | |
" <td>-0.471675</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.787113</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>29</th>\n", | |
" <td>-0.052103</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.862957</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>...</th>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" <td>...</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>70</th>\n", | |
" <td>-1.186475</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.807903</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>71</th>\n", | |
" <td>-0.732498</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.351341</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>72</th>\n", | |
" <td>0.944242</td>\n", | |
" <td>1.0</td>\n", | |
" <td>5.756376</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>73</th>\n", | |
" <td>-1.601822</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.210960</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>74</th>\n", | |
" <td>-1.492551</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.423871</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>75</th>\n", | |
" <td>-0.812373</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.281548</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>76</th>\n", | |
" <td>-0.722041</td>\n", | |
" <td>1.0</td>\n", | |
" <td>6.346155</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>77</th>\n", | |
" <td>-0.588892</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.297204</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>78</th>\n", | |
" <td>-0.945065</td>\n", | |
" <td>0.0</td>\n", | |
" <td>0.371523</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>79</th>\n", | |
" <td>-0.702234</td>\n", | |
" <td>1.0</td>\n", | |
" <td>3.913889</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>80</th>\n", | |
" <td>-2.040678</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.089276</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>81</th>\n", | |
" <td>-2.382935</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.293790</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>82</th>\n", | |
" <td>-1.898813</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.077551</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>83</th>\n", | |
" <td>-0.443291</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.712443</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>84</th>\n", | |
" <td>-0.037311</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.742133</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>85</th>\n", | |
" <td>-0.054639</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.450153</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>86</th>\n", | |
" <td>-1.649032</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.563796</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>87</th>\n", | |
" <td>0.085290</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.118241</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>88</th>\n", | |
" <td>-1.603933</td>\n", | |
" <td>1.0</td>\n", | |
" <td>3.221698</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>89</th>\n", | |
" <td>-1.178101</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.355170</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>90</th>\n", | |
" <td>-0.389014</td>\n", | |
" <td>1.0</td>\n", | |
" <td>4.457656</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>91</th>\n", | |
" <td>-2.252745</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.413238</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>92</th>\n", | |
" <td>-1.958699</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-3.063983</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>93</th>\n", | |
" <td>-2.165287</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.934050</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>94</th>\n", | |
" <td>-2.802529</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.556217</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>95</th>\n", | |
" <td>-0.049369</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-0.277302</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>96</th>\n", | |
" <td>-0.836338</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.384613</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>97</th>\n", | |
" <td>-0.938599</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-1.528677</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>98</th>\n", | |
" <td>-2.214974</td>\n", | |
" <td>0.0</td>\n", | |
" <td>-2.150973</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>99</th>\n", | |
" <td>-1.698840</td>\n", | |
" <td>1.0</td>\n", | |
" <td>3.217780</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"<p>100 rows × 3 columns</p>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" X0 v y\n", | |
"0 -1.907772 0.0 -2.230260\n", | |
"1 0.707054 1.0 6.059161\n", | |
"2 -1.785862 0.0 -2.224435\n", | |
"3 -0.168836 1.0 4.101302\n", | |
"4 -0.600276 1.0 4.341337\n", | |
"5 0.125460 1.0 4.093887\n", | |
"6 -1.866742 0.0 -2.086384\n", | |
"7 -1.449010 0.0 -1.940293\n", | |
"8 -1.205083 0.0 -1.082687\n", | |
"9 -1.771017 0.0 -0.743547\n", | |
"10 -0.310801 0.0 -0.246174\n", | |
"11 -1.044497 1.0 4.259667\n", | |
"12 -2.090359 0.0 -3.496719\n", | |
"13 -3.464527 0.0 -4.716008\n", | |
"14 0.110692 0.0 -0.560330\n", | |
"15 -0.431212 1.0 5.258859\n", | |
"16 -0.307991 0.0 -1.340262\n", | |
"17 -0.382713 1.0 4.597534\n", | |
"18 -0.382750 0.0 0.409655\n", | |
"19 0.632260 0.0 1.115870\n", | |
"20 -2.797024 0.0 -3.380601\n", | |
"21 -0.812726 1.0 4.042101\n", | |
"22 0.410965 1.0 3.205867\n", | |
"23 -0.998865 0.0 -0.794360\n", | |
"24 -0.867534 0.0 0.449066\n", | |
"25 -0.495761 0.0 0.839552\n", | |
"26 -1.456445 1.0 4.015004\n", | |
"27 -2.027530 0.0 -2.735320\n", | |
"28 -0.471675 1.0 5.787113\n", | |
"29 -0.052103 0.0 -0.862957\n", | |
".. ... ... ...\n", | |
"70 -1.186475 1.0 4.807903\n", | |
"71 -0.732498 0.0 -1.351341\n", | |
"72 0.944242 1.0 5.756376\n", | |
"73 -1.601822 0.0 -0.210960\n", | |
"74 -1.492551 0.0 -2.423871\n", | |
"75 -0.812373 0.0 -0.281548\n", | |
"76 -0.722041 1.0 6.346155\n", | |
"77 -0.588892 0.0 -1.297204\n", | |
"78 -0.945065 0.0 0.371523\n", | |
"79 -0.702234 1.0 3.913889\n", | |
"80 -2.040678 0.0 -1.089276\n", | |
"81 -2.382935 0.0 -2.293790\n", | |
"82 -1.898813 0.0 -2.077551\n", | |
"83 -0.443291 0.0 -1.712443\n", | |
"84 -0.037311 0.0 -0.742133\n", | |
"85 -0.054639 1.0 4.450153\n", | |
"86 -1.649032 0.0 -1.563796\n", | |
"87 0.085290 0.0 -0.118241\n", | |
"88 -1.603933 1.0 3.221698\n", | |
"89 -1.178101 0.0 -0.355170\n", | |
"90 -0.389014 1.0 4.457656\n", | |
"91 -2.252745 0.0 -2.413238\n", | |
"92 -1.958699 0.0 -3.063983\n", | |
"93 -2.165287 0.0 -1.934050\n", | |
"94 -2.802529 0.0 -2.556217\n", | |
"95 -0.049369 0.0 -0.277302\n", | |
"96 -0.836338 0.0 -2.384613\n", | |
"97 -0.938599 0.0 -1.528677\n", | |
"98 -2.214974 0.0 -2.150973\n", | |
"99 -1.698840 1.0 3.217780\n", | |
"\n", | |
"[100 rows x 3 columns]" | |
] | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"data['dot_graph'] = 'digraph { v ->y;X0-> v;X0-> y;}'\n", | |
"df = data['df']\n", | |
"df['y'] = df['y'] + np.random.normal(size=len(df))\n", | |
"df" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<table class=\"simpletable\">\n", | |
"<caption>OLS Regression Results</caption>\n", | |
"<tr>\n", | |
" <th>Dep. Variable:</th> <td>y</td> <th> R-squared: </th> <td> 0.927</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Model:</th> <td>OLS</td> <th> Adj. R-squared: </th> <td> 0.925</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Method:</th> <td>Least Squares</td> <th> F-statistic: </th> <td> 622.1</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Date:</th> <td>Sun, 10 Feb 2019</td> <th> Prob (F-statistic):</th> <td>2.02e-56</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Time:</th> <td>09:18:58</td> <th> Log-Likelihood: </th> <td> -124.35</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>No. Observations:</th> <td> 100</td> <th> AIC: </th> <td> 252.7</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Df Residuals:</th> <td> 98</td> <th> BIC: </th> <td> 257.9</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Df Model:</th> <td> 2</td> <th> </th> <td> </td> \n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Covariance Type:</th> <td>nonrobust</td> <th> </th> <td> </td> \n", | |
"</tr>\n", | |
"</table>\n", | |
"<table class=\"simpletable\">\n", | |
"<tr>\n", | |
" <td></td> <th>coef</th> <th>std err</th> <th>t</th> <th>P>|t|</th> <th>[0.025</th> <th>0.975]</th> \n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>X0</th> <td> 1.0777</td> <td> 0.064</td> <td> 16.781</td> <td> 0.000</td> <td> 0.950</td> <td> 1.205</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>v</th> <td> 5.1443</td> <td> 0.153</td> <td> 33.686</td> <td> 0.000</td> <td> 4.841</td> <td> 5.447</td>\n", | |
"</tr>\n", | |
"</table>\n", | |
"<table class=\"simpletable\">\n", | |
"<tr>\n", | |
" <th>Omnibus:</th> <td> 0.175</td> <th> Durbin-Watson: </th> <td> 1.970</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Prob(Omnibus):</th> <td> 0.916</td> <th> Jarque-Bera (JB): </th> <td> 0.169</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Skew:</th> <td> 0.092</td> <th> Prob(JB): </th> <td> 0.919</td>\n", | |
"</tr>\n", | |
"<tr>\n", | |
" <th>Kurtosis:</th> <td> 2.917</td> <th> Cond. No. </th> <td> 2.44</td>\n", | |
"</tr>\n", | |
"</table><br/><br/>Warnings:<br/>[1] Standard Errors assume that the covariance matrix of the errors is correctly specified." | |
], | |
"text/plain": [ | |
"<class 'statsmodels.iolib.summary.Summary'>\n", | |
"\"\"\"\n", | |
" OLS Regression Results \n", | |
"==============================================================================\n", | |
"Dep. Variable: y R-squared: 0.927\n", | |
"Model: OLS Adj. R-squared: 0.925\n", | |
"Method: Least Squares F-statistic: 622.1\n", | |
"Date: Sun, 10 Feb 2019 Prob (F-statistic): 2.02e-56\n", | |
"Time: 09:18:58 Log-Likelihood: -124.35\n", | |
"No. Observations: 100 AIC: 252.7\n", | |
"Df Residuals: 98 BIC: 257.9\n", | |
"Df Model: 2 \n", | |
"Covariance Type: nonrobust \n", | |
"==============================================================================\n", | |
" coef std err t P>|t| [0.025 0.975]\n", | |
"------------------------------------------------------------------------------\n", | |
"X0 1.0777 0.064 16.781 0.000 0.950 1.205\n", | |
"v 5.1443 0.153 33.686 0.000 4.841 5.447\n", | |
"==============================================================================\n", | |
"Omnibus: 0.175 Durbin-Watson: 1.970\n", | |
"Prob(Omnibus): 0.916 Jarque-Bera (JB): 0.169\n", | |
"Skew: 0.092 Prob(JB): 0.919\n", | |
"Kurtosis: 2.917 Cond. No. 2.44\n", | |
"==============================================================================\n", | |
"\n", | |
"Warnings:\n", | |
"[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.\n", | |
"\"\"\"" | |
] | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = OLS(df['y'], df[['X0', 'v']])\n", | |
"result = model.fit()\n", | |
"result.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'digraph { v ->y;X0-> v;X0-> y;}'" | |
] | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Echo the DOT graph string that is handed to CausalModel in the next cell.\n", | |
"data['dot_graph']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Error: Pygraphviz cannot be loaded. No module named 'pygraphviz'\n", | |
"Trying pydot ...\n", | |
"['X0']\n", | |
"yes\n", | |
"{'observed': 'yes'}\n", | |
"Model to find the causal effect of treatment v on outcome y\n" | |
] | |
} | |
], | |
"source": [ | |
"causal_model= CausalModel(\n", | |
" data=data[\"df\"],\n", | |
" treatment=data[\"treatment_name\"],\n", | |
" outcome=data[\"outcome_name\"],\n", | |
" graph=data[\"dot_graph\"])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:dowhy.causal_identifier:Common causes of treatment and outcome:{'X0', 'U'}\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'observed': 'yes'}\n", | |
"{'label': 'Unobserved Confounders', 'observed': 'no'}\n", | |
"There are unobserved common causes. Causal effect cannot be identified.\n", | |
"WARN: Do you want to continue by ignoring these unobserved confounders? [y/n] yes\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:dowhy.causal_identifier:Instrumental variables for treatment and outcome:[]\n" | |
] | |
} | |
], | |
"source": [ | |
"# In the recorded run this call stopped to ask on stdin whether to continue\n", | |
"# despite unobserved confounders (answered 'yes'), so an unattended\n", | |
"# Restart-and-Run-All will wait here for input.\n", | |
"identified_estimand = causal_model.identify_effect()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pymc3 as pm\n", | |
"import networkx as nx" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pymc3:Auto-assigning NUTS sampler...\n", | |
"INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n", | |
"INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n", | |
"INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n", | |
"Sampling 4 chains: 100%|██████████| 6000/6000 [00:02<00:00, 2316.81draws/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"g = causal_model._graph.get_unconfounded_observed_subgraph()\n", | |
"df = causal_model._data\n", | |
"\n", | |
"data_types = {'X0': 'c', 'v': 'b', 'y': 'c'}\n", | |
"\n", | |
"\n", | |
"class CausalBayesianNetwork(object):\n", | |
" def __init__(self, g, df, data_types):\n", | |
" self.g = g\n", | |
" self.df = df\n", | |
" self.data_types = data_types\n", | |
" g_fit = nx.DiGraph(self.g)\n", | |
" _, self.fit_trace = self.fit_causal_model(g_fit, \n", | |
" self.df, \n", | |
" self.data_types)\n", | |
" \n", | |
" def apply_data_types(self, g, data_types):\n", | |
" for node in nx.topological_sort(g):\n", | |
" g.nodes()[node][\"variable_type\"] = data_types[node]\n", | |
" return g\n", | |
"\n", | |
" def apply_parents(self, g):\n", | |
" for node in nx.topological_sort(g):\n", | |
" if not g.nodes()[node].get(\"parent_names\"):\n", | |
" g.nodes()[node][\"parent_names\"] = [parent for parent, _ in g.in_edges(node)]\n", | |
" return g\n", | |
"\n", | |
" def apply_parameters(self, g, df, initialization_trace=None):\n", | |
" for node in nx.topological_sort(g):\n", | |
" parent_names = g.nodes()[node][\"parent_names\"]\n", | |
" if parent_names:\n", | |
" if not initialization_trace:\n", | |
" sd = np.array([df[node].std()] + (df[node].std() / df[parent_names].std() ).tolist())\n", | |
" mu = np.array([df[node].std()] + (df[node].std() / df[parent_names].std() ).tolist())\n", | |
" node_sd = df[node].std()\n", | |
" else:\n", | |
" node_sd = initialization_trace[\"{}_sd\".format(node)].mean()\n", | |
" mu = initialization_trace[\"beta_{}\".format(node)].mean(axis=0)\n", | |
" sd = initialization_trace[\"beta_{}\".format(node)].std(axis=0)\n", | |
" g.nodes()[node][\"parameters\"] = pm.Normal(\"beta_{}\".format(node), mu=mu, sd=sd, shape=len(parent_names)+1)\n", | |
" g.nodes()[node][\"sd\"] = pm.Exponential(\"{}_sd\".format(node), lam=node_sd)\n", | |
" return g\n", | |
"\n", | |
" def build_bayesian_network(self, g, df):\n", | |
" for node in nx.topological_sort(g):\n", | |
" if g.nodes()[node][\"parent_names\"]:\n", | |
" mu = g.nodes()[node][\"parameters\"][0] # intercept\n", | |
" mu += pm.math.dot(df[g.nodes()[node][\"parent_names\"]], \n", | |
" g.nodes()[node][\"parameters\"][1:])\n", | |
" if g.nodes()[node][\"variable_type\"] == 'c':\n", | |
" sd = g.nodes()[node][\"sd\"]\n", | |
" g.nodes()[node][\"variable\"] = pm.Normal(\"{}\".format(node), \n", | |
" mu=mu, sd=sd, \n", | |
" observed=df[node])\n", | |
" elif g.nodes()[node][\"variable_type\"] == 'b':\n", | |
" g.nodes()[node][\"variable\"] = pm.Bernoulli(\"{}\".format(node), \n", | |
" logit_p=mu, \n", | |
" observed=df[node])\n", | |
" else:\n", | |
" raise Exception(\"Unrecognized variable type: {}\".format(g.nodes()[node][\"variable_type\"]))\n", | |
" return g\n", | |
"\n", | |
" def fit_causal_model(self, g, df, data_types, initialization_trace=None):\n", | |
" if nx.is_directed_acyclic_graph(g):\n", | |
" with pm.Model() as model: \n", | |
" g = self.apply_data_types(g, data_types)\n", | |
" g = self.apply_parents(g)\n", | |
" g = self.apply_parameters(g, df, initialization_trace=initialization_trace)\n", | |
" g = self.build_bayesian_network(g, df)\n", | |
" trace = pm.sample(1000, tune=500)\n", | |
" else:\n", | |
" raise Exception(\"Graph is not a DAG!\")\n", | |
" return g, trace\n", | |
"\n", | |
" def sample_prior_causal_model(self, g, df, data_types, initialization_trace):\n", | |
" if nx.is_directed_acyclic_graph(g):\n", | |
" with pm.Model() as model: \n", | |
" g = self.apply_data_types(g, data_types)\n", | |
" g = self.apply_parents(g)\n", | |
" g = self.apply_parameters(g, df, initialization_trace=initialization_trace)\n", | |
" g = self.build_bayesian_network(g, df)\n", | |
" trace = pm.sample_prior_predictive(1)\n", | |
" else:\n", | |
" raise Exception(\"Graph is not a DAG!\")\n", | |
" return g, trace\n", | |
"\n", | |
" def do_x_surgery(self, g, x):\n", | |
" for xi in x.keys():\n", | |
" g.remove_edges_from([(parent, child) for (parent, child) in g.in_edges(xi)])\n", | |
" g.nodes()[xi][\"parent_names\"] = []\n", | |
" return g\n", | |
"\n", | |
" def make_intervention_effective(self, df, x):\n", | |
" df_intervened = df.copy()\n", | |
" for k, v in x.items():\n", | |
" df_intervened[k] = v\n", | |
" return df_intervened\n", | |
"\n", | |
" def do(self, x):\n", | |
" g_for_surgery = nx.DiGraph(self.g)\n", | |
" g_modified = self.do_x_surgery(g_for_surgery, x)\n", | |
" df_intervened = self.make_intervention_effective(self.df, x)\n", | |
" g_modified, trace = self.sample_prior_causal_model(g_modified, \n", | |
" df_intervened, \n", | |
" self.data_types, \n", | |
" initialization_trace=self.fit_trace)\n", | |
" return trace\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 59, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pymc3:Auto-assigning NUTS sampler...\n", | |
"INFO:pymc3:Initializing NUTS using jitter+adapt_diag...\n", | |
"INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)\n", | |
"INFO:pymc3:NUTS: [y_sd, beta_y, v_sd, beta_v]\n", | |
"Sampling 4 chains: 100%|██████████| 6000/6000 [00:02<00:00, 2280.31draws/s]\n", | |
"WARNING:pymc3:The acceptance probability does not match the target. It is 0.8827395991724155, but should be close to 0.8. Try to increase the number of tuning steps.\n" | |
] | |
} | |
], | |
"source": [ | |
"causal_bayes_net = CausalBayesianNetwork(g, df, data_types)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 60, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = {'v': 1.}\n", | |
"df_1 = causal_bayes_net.do(x)\n", | |
"x = {'v': 0.}\n", | |
"df_0 = causal_bayes_net.do(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAKkAAAAPBAMAAABtvvLvAAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAzWYQMplU74mrdiK7RN1/7zyFAAAACXBIWXMAAA7EAAAOxAGVKw4bAAACy0lEQVQ4Ea2UzWsTQRjGn2Rjsvk0UBQ8NSpo1UNXPNj20ObiwYsJSqQqpYtaiieVKlLqxypW2lJwLUY9mSCIWg8NinizAU9i0foPWL0I4qGtYr+wrs87s2nzBziHzOwzv/z23ZmdRWB76CjYjMJUFsb44DTHoTwOFYeGHMkB89FN6Q4AilUgRpqyftj2tEtGtaYVEc9zJGmFuYoNwD2OUy6aPc9zEe7j1Wa0V9h9BxSrwIBt/GAmYR5RG+i1eCW4VgSGn8s0jgCXcBLYy/ETFyeAKFqHfwHBczDTQKyftQqrwJ3AdYIMwxkYaeNBiVaFa0VUlGxXgVJlFGgH4ltdOACfO0hrgiWze7FF7lMD3wFn6WFoVmF8AZqlVsG1omZttFDKzrzFDS5kwCWSLGtso4vkAuDUrApcBmZsFSb/WGamzqoV0W2Ps9Bt1op6TTawR1kTDOXmHXkkVxCsiNVnZ/f9prWsw9LPWyTXatWKhBVf1dLkEnB73oJRVlYWrayTDkIraAOtPptcihPNlXVoeg31VqUQ4Qf5UTsf7pq9AhNiNeYYSa2TeYSWuM60sgmbcuOsNefosOfz4nRdrUohZHdFflEAdiG2aJ1W1lia0doKhLO+VdgCDLUCKgy6mOirsyoF7vCxeS++JXmA07lTjrKmqr6VuxVeaIFYNSvgMt8BW4URCyFu5tq6isLGedZqibUH8Zdz3JEdxUL3xTIo07Umqiz5TKGweNxnCVamxKPCDmJj61ZZuGiGRwNfRRpyEKsMcEcqPEM05jihVsA/BeiXY0RWgf4pYBgh/37dCq14heAAOtJ4MzT4EJ8s7KYsUuUDOhzEuFv4hv1ZdpcBxSowZRsXmDFMNiBYBiYsXgmuFfHOoiUHv9Hz/iJUlK+L+XE5g9c2z/T9+cNAy/hd/mPUG4NiFWiMPMv64cFOfl2OXduU0bhWcO7/t39S7fhfQwkz1AAAAABJRU5ErkJggg==\n", | |
"text/latex": [ | |
"$$5.284710452048118$$" | |
], | |
"text/plain": [ | |
"5.284710452048118" | |
] | |
}, | |
"execution_count": 61, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"(df_1['y'] - df_0['y']).mean()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 62, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAALQAAAAPBAMAAAC/7vi3AAAAMFBMVEX///8AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAv3aB7AAAAD3RSTlMAiXZmMs1UEN0i77urRJlR0qN3AAAACXBIWXMAAA7EAAAOxAGVKw4bAAADTElEQVQ4EbWUz28UZRjHP7M/Ztrd2eksEi803aXAQQ8wNOEgMen25sHQNTX2gnEJHI0OJiSkB3Y5wIUmbOLByMFuYmRjRN0YE6ImOlpjAocyMTF665RgCCi1WrbLluLyvO9bEv4B5vDM933f73zmned53gFrz07Udeavr+G93feUPkJmfqVY5EzxmhpmErDGd+KMj8Uy3LG7YozOymgTHV5ZKhZ9ZsIjNcPTijc43RS/E1GKSXBrMpgnPRgMtpyPuRTKcCjCeZM/8eAXOBV7VWOcC4bb6DAp9ogvB/9ieFqlW1hVedwLybXsAEcGqSt48rrIbeEmsvZiRD7hO44LF74n3zHG67CEDnfAhdd+amJ4SpGPsHvyeK6B27c6OA14/yiWpKAyUsVag+FDEctiZRecptAXZYxdKKGDD5LWSFYMTylGIgobcs/2cO8XHoZWAP5RtfIr9Q4pAVnZiG/VzPJVniPfEqWNzkM48IEK8t5C2aANLxIP9YTCf0pISnvcWH8e0k2N7pCTXT+ClwTdndof4w4O1hh5fYeUWhs/kV3HOshuBbA4cWubpxQLPplt9LKPNZiVNkGhZfuSjew6TjkbOd2A8wLqhtRnGKoZYwneCnRAPgeushAbnlIsJGSkruqah3f+2IzxNbouM19xbEPSLugB3A3texc/p76G1zbGfNku1XTAWdOIfOMJL994KiFeQjriQNuuaPQ+8doTU/c5IWgewGT8MqnNcKRB+pE2wvG9pYoJqapGe/0nCfb6qoy2KiPslzYJyWycRKMv6Em35/gKfU7QQRuma0NV0uvaqAw3QhOGOtJ+qjSap5X0KGnVfNgJH6okLL69urp5E7bUpFDtpdV/vihL907+Lh/tBm5Ldq2Nal0XXIIQVdW9nuZptd3iYpqCF3JN+Eb0FTnc/0urtJmWGXIR0yF3+VQaoZnp4bWM8XCY6aMD04mcOmmyqj4yWiGFn6s4XYbPFv9uFGZJl4X1mXyE1DY168zIiFwH13cu81vIBPzIXM0Y94WnAnTgXV+OVsLhpuYZxcnxn1VzZOUv0ODVFelZdg0WGb4sorgaS7R+2AoYG62QWVK/J2tFHtBGa0zcOnCpJsaPRm9jeFrJzDO6HgMVjziaVfbIdwAAAABJRU5ErkJggg==\n", | |
"text/latex": [ | |
"$$0.2847946868997505$$" | |
], | |
"text/plain": [ | |
"0.2847946868997505" | |
] | |
}, | |
"execution_count": 62, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"1.96* (df_1['y'] - df_0['y']).std() / np.sqrt(len(df))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment