Created
December 4, 2023 10:35
-
-
Save malte-j/5d846a92159f00f83a1d7db69adaf68a to your computer and use it in GitHub Desktop.
dart thompson sampling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import 'dart:math'; | |
/// Gaussian Thompson sampling over a fixed set of arms.
///
/// Each arm `i` holds a Normal belief `N(means[i], variances[i])` about its
/// true expected reward. [selectArm] draws one value from every belief and
/// returns the arm with the highest draw. [updateObservations] folds an
/// observed reward into that arm's belief using the conjugate Normal–Normal
/// update, assuming rewards are Normal with known variance
/// [observationVariance].
class ThompsonSampling {
  /// Current belief (posterior) mean of each arm's expected reward.
  final List<double> means;

  /// Current belief (posterior) variance of each arm's expected reward.
  ///
  /// This shrinks with every observation of the arm, so well-explored arms
  /// are sampled less noisily.
  final List<double> variances;

  /// Assumed, known variance of a single observed reward.
  final double observationVariance;

  final Random _random;

  /// Creates a sampler with the prior `N(initialMeans[i], initialVariances[i])`
  /// for arm `i`.
  ///
  /// Both lists are copied, so later changes to the arguments do not affect
  /// the sampler. Pass [random] (e.g. a seeded `Random(42)`) for reproducible
  /// draws.
  ///
  /// Throws [ArgumentError] if:
  /// - the two lists differ in length,
  /// - any initial variance is negative or not finite, or
  /// - [observationVariance] is not a positive finite number.
  ThompsonSampling(
    List<double> initialMeans,
    List<double> initialVariances, {
    this.observationVariance = 1.0,
    Random? random,
  })  : means = List.of(initialMeans),
        variances = List.of(initialVariances),
        _random = random ?? Random() {
    if (initialMeans.length != initialVariances.length) {
      throw ArgumentError(
          'initialMeans and initialVariances must have the same length '
          '(${initialMeans.length} != ${initialVariances.length})');
    }
    for (final variance in initialVariances) {
      // A negative variance would make sqrt() produce NaN samples in
      // selectArm.
      if (!variance.isFinite || variance < 0) {
        throw ArgumentError.value(
            variance, 'initialVariances', 'must be finite and non-negative');
      }
    }
    if (!observationVariance.isFinite || observationVariance <= 0) {
      throw ArgumentError.value(observationVariance, 'observationVariance',
          'must be positive and finite');
    }
  }

  /// Folds [newObservation], a reward observed from arm [armIndex], into that
  /// arm's belief.
  ///
  /// Uses the conjugate Normal–Normal update with known noise variance
  /// `s2` ([observationVariance]), written in gain form:
  /// `k = v / (v + s2)`, `m' = m + k * (x - m)`, `v' = (1 - k) * v`.
  /// The gain form stays well defined when the prior variance is 0. In that
  /// case the arm's mean is treated as known and left unchanged.
  ///
  /// Throws a [RangeError] if [armIndex] is out of range.
  void updateObservations(int armIndex, double newObservation) {
    final priorMean = means[armIndex];
    final priorVariance = variances[armIndex];
    // How much weight the new observation gets relative to the belief.
    final gain = priorVariance / (priorVariance + observationVariance);
    means[armIndex] = priorMean + gain * (newObservation - priorMean);
    variances[armIndex] = (1 - gain) * priorVariance;
  }

  /// Draws one sample from each arm's Normal belief and returns the index of
  /// the largest draw. On ties, the lowest index wins.
  ///
  /// Throws a [StateError] if there are no arms.
  int selectArm() {
    if (means.isEmpty) {
      throw StateError('No arms to select from');
    }
    var bestArm = 0;
    var bestSample = double.negativeInfinity;
    for (var arm = 0; arm < means.length; arm++) {
      final sample = means[arm] + sqrt(variances[arm]) * _nextStandardNormal();
      if (sample > bestSample) {
        bestSample = sample;
        bestArm = arm;
      }
    }
    return bestArm;
  }

  /// Returns a draw from `N(0, 1)` via the Box–Muller transform.
  double _nextStandardNormal() {
    // nextDouble() is in [0, 1); flip it to (0, 1] so log() never sees 0.
    final u1 = 1.0 - _random.nextDouble();
    final u2 = _random.nextDouble();
    return sqrt(-2 * log(u1)) * cos(2 * pi * u2);
  }
}
/// Demonstrates the sampler on three arms that share the same prior.
///
/// Arm 0 receives one much higher reward, then a single arm is selected and
/// printed.
void main() {
  // Every arm starts from the same prior: mean 1.0, variance 2.0.
  const armCount = 3;
  final sampler = ThompsonSampling(
    List.filled(armCount, 1.0),
    List.filled(armCount, 2.0),
  );

  // Feed in observed rewards. To simulate more data, add further calls,
  // e.g. `sampler.updateObservations(1, 10.0);` or
  // `sampler.updateObservations(2, 3.0);`.
  sampler.updateObservations(0, 11.0);

  print("Selected Arm: ${sampler.selectArm()}");
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment