Created
June 13, 2025 06:41
-
-
Save ochilab/69e4d68fa420d56a1a5fc7794e1f8863 to your computer and use it in GitHub Desktop.
IRT(Item Response Theory)のサンプル
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import pandas as pd | |
| import torch | |
| from py_irt.models.two_param_logistic import TwoParamLog | |
class SafeTwoParamLog(TwoParamLog):
    """``TwoParamLog`` variant whose ``fit_MCMC`` keeps and returns the sampler.

    The override runs NUTS on ``self.model_vague``, stores the fitted
    ``pyro.infer.MCMC`` object on ``self.trace`` and returns it, so callers
    can read the posterior draws with ``get_samples()``.
    """

    def fit_MCMC(self, models, items, responses, num_epochs=1000, warmup_steps=100):
        """Fit the model with NUTS and return the fitted ``MCMC`` object.

        Args:
            models: First positional argument forwarded to
                ``self.model_vague`` (the caller in this file passes the
                per-response subject indices).
            items: Second positional argument forwarded to
                ``self.model_vague`` (per-response item indices in this file).
            responses: Third positional argument forwarded to
                ``self.model_vague`` (the observed responses in this file).
            num_epochs: Number of posterior samples to draw after warm-up.
                Previously this argument was accepted but ignored and 1000
                samples were always drawn; the default keeps that behaviour.
            warmup_steps: Number of warm-up (adaptation) iterations run
                before sampling. Defaults to the previously hard-coded 100.

        Returns:
            The ``pyro.infer.MCMC`` instance after ``run()`` has completed.
            The same object is also stored on ``self.trace``.
        """
        # Local import: this file does not import pyro at module level.
        from pyro.infer import MCMC, NUTS

        nuts_kernel = NUTS(self.model_vague, adapt_step_size=True)
        hmc_posterior = MCMC(
            nuts_kernel,
            num_samples=num_epochs,
            warmup_steps=warmup_steps,
        )
        # MCMC.run() returns None, so it must be called as a separate
        # statement rather than chained onto the constructor call.
        hmc_posterior.run(models, items, responses)
        self.trace = hmc_posterior
        # Return the MCMC object itself (not a TracePosterior); it exposes
        # get_samples() for reading the draws.
        return hmc_posterior
# Prepare the data: one row per (user, item) answer, with a 0/1 response
# (presumably 1 = correct, 0 = incorrect -- confirm against the data source).
data = pd.DataFrame({
    'user_id': ['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'],
    'item_id': ['Q1', 'Q2', 'Q3', 'Q1', 'Q2', 'Q3', 'Q1', 'Q2', 'Q3'],
    'response': [1, 1, 0, 1, 0, 0, 0, 1, 1],
})

# Map each distinct user / item label to a 0-based integer index, in order
# of first appearance.
user2idx = {u: i for i, u in enumerate(data['user_id'].unique())}
item2idx = {q: i for i, q in enumerate(data['item_id'].unique())}

# Convert the columns to tensors: integer indices for subjects and items,
# floats for the observed responses.
subjects = torch.tensor(data['user_id'].map(user2idx).to_numpy(), dtype=torch.long)
items = torch.tensor(data['item_id'].map(item2idx).to_numpy(), dtype=torch.long)
responses = torch.tensor(data['response'].to_numpy(), dtype=torch.float)

# Initialise the model.
model = SafeTwoParamLog(
    priors='vague',
    num_items=len(item2idx),
    num_subjects=len(user2idx),
    verbose=True,
    device='cpu',
)

# MCMC estimation.
mcmc = model.fit_MCMC(subjects, items, responses, num_epochs=1000)

# Fetch the posterior draws for every sampled parameter.
samples = mcmc.get_samples()

# Summarise each parameter by reducing over dimension 0 (assumed to be the
# sample dimension -- confirm against the get_samples() output shape).
# "theta" is per subject while "a" and "b" are per item, so the two groups are
# built separately: a single dict-of-arrays constructor raises ValueError as
# soon as the number of subjects differs from the number of items.
subject_summary = pd.DataFrame({
    "theta_mean": samples["theta"].mean(0).numpy(),
    "theta_std": samples["theta"].std(0).numpy(),
})
item_summary = pd.DataFrame({
    "b_mean": samples["b"].mean(0).numpy(),
    "b_std": samples["b"].std(0).numpy(),
    "a_mean": samples["a"].mean(0).numpy(),
    "a_std": samples["a"].std(0).numpy(),
})

# Side-by-side join on the positional index: row i holds subject i's theta
# statistics next to item i's a/b statistics (the rows are not otherwise
# related). When the counts differ, the shorter group is padded with NaN.
summary_df = pd.concat([subject_summary, item_summary], axis=1)
print(summary_df)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment