Created June 21, 2019 08:43
-
-
Save dmitryhd/c79cf89009eb24b52c8b9feb9b6c3821 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import tqdm | |
import numpy as np | |
import multiprocessing as mp | |
def mix_observations(observations: np.ndarray):
    """Return one bootstrap resample of ``observations``.

    Draws ``len(observations)`` indices uniformly with replacement from
    NumPy's global random generator and gathers the matching elements, so the
    resample has the same length as the input.
    """
    n_obs = len(observations)
    return observations[np.random.randint(0, n_obs, size=n_obs)]
def boostrap_mean(observations: np.ndarray):
    """Return the mean of a single bootstrap resample of ``observations``."""
    # NOTE: the "boostrap" spelling is kept because bootstrap_mean_dist calls
    # this function by that name.
    resample = mix_observations(observations)
    return resample.mean()
def bootstrap_mean_dist(
    observations: np.ndarray,
    n_iters: int,
    *,
    show_progress: bool = True,
) -> list:
    """Return ``n_iters`` bootstrap estimates of the mean of ``observations``.

    Each estimate is the mean of one resample drawn by ``boostrap_mean``.

    Args:
        observations: Sample to resample from.
        n_iters: Number of bootstrap resamples to draw.
        show_progress: Wrap the loop in a ``tqdm`` progress bar (the default).
            Pass ``False`` when many of these run concurrently, e.g. inside
            pool workers, where the bars would interleave on the terminal.

    Returns:
        List of ``n_iters`` resample means, in draw order.
    """
    iterations = range(n_iters)
    if show_progress:
        iterations = tqdm.tqdm(iterations)
    # The loop index itself is not needed; each draw calls boostrap_mean once.
    return [boostrap_mean(observations) for _ in iterations]
# Placeholder data: 100 samples of 100 observations each.
#
# This is deliberately a module-level global rather than a pool argument.
# Pool tasks receive only an integer index (see bootstrap_mean_batch), so the
# arrays themselves never have to be pickled for each task.
# - With the ``fork`` start method, workers inherit this list from the parent.
# - With ``spawn``, each worker re-imports the module and rebuilds the list.
#   That is only equivalent here because the data is deterministic.
#
# Build distinct arrays. ``[np.zeros(100)] * 100`` would store 100 references
# to the *same* array, so an in-place write to one sample would change all.
observations = [np.zeros(100) for _ in range(100)]
def bootstrap_mean_batch(observations_id: int, n_iters: int) -> list:
    """Bootstrap the mean of the ``observations_id``-th global sample.

    Intended as a pool task. It takes an integer index instead of the array,
    so only the index has to be sent to the worker. The worker then looks the
    sample up in the module-level ``observations`` list.
    """
    sample = observations[observations_id]
    return bootstrap_mean_dist(sample, n_iters)
def _reseed_worker() -> None:
    """Give a freshly started pool worker its own NumPy random state.

    Under the ``fork`` start method, every worker begins with an identical
    copy of the parent's global ``np.random`` state. Without reseeding, the
    workers would draw identical bootstrap resamples. Calling
    ``np.random.seed()`` with no argument reseeds from fresh OS entropy.
    """
    np.random.seed()


def main(n_processes: int = 10, n_iters: int = 5) -> None:
    """Bootstrap the mean of every sample in ``observations`` on a process pool.

    Prints how many samples were processed and the elapsed wall-clock time.

    Args:
        n_processes: Number of worker processes in the pool.
        n_iters: Bootstrap iterations per sample.
    """
    t0 = time.perf_counter()  # monotonic clock, suitable for timing intervals
    # The pool exists for the whole ``with`` block. Leaving the block
    # terminates the workers; starmap has already returned by then.
    with mp.Pool(processes=n_processes, initializer=_reseed_worker) as pool:
        # One task per sample. Only the integer index is sent to a worker;
        # the worker reads the array from the module-level ``observations``.
        args = [
            (observations_id, n_iters)
            for observations_id in range(len(observations))
        ]
        # The returned distributions are not used; only the timing is reported.
        pool.starmap(bootstrap_mean_batch, iterable=args)
    elapsed = time.perf_counter() - t0
    # ``.2f`` prints two decimal places. A bare ``.2`` would mean two
    # significant digits and switch to exponent notation (e.g. ``1.2e+01``)
    # for runs of 10 s or more.
    print(f'processed {len(observations)} in {elapsed:.2f} seconds')


# The guard is required so that ``spawn``-started workers, which re-import
# this module, do not recursively start their own pools.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment