Skip to content

Instantly share code, notes, and snippets.

@joaoreboucas1
Last active March 10, 2024 15:08
Show Gist options
  • Save joaoreboucas1/d85163ea9899271d8e582d000e0792bb to your computer and use it in GitHub Desktop.
Save joaoreboucas1/d85163ea9899271d8e582d000e0792bb to your computer and use it in GitHub Desktop.
Marcus' idea of dividing a sample via a threshold value and maximizing the difference of the group means
import sys
import numpy as np
from numpy.typing import NDArray
import matplotlib.pyplot as plt
# Seed NumPy's global RNG so every run draws the same samples
np.random.seed(2611)

# Sample size
N = 200_000

# Distribution parameters, all derived from the maximum value
rendimento_max = 1000
rendimento_medio = rendimento_max / 2    # half of the maximum
rendimento_stdev = rendimento_max / 20   # one twentieth of the maximum
# Choose the sampling distribution from the optional command-line argument.
# No argument means "uniform"; more than one argument or an unknown name
# prints the usage text and exits with status 1.
if len(sys.argv) == 1:
    distribution = "uniform"
elif len(sys.argv) == 2:
    distribution = sys.argv[1]
else:
    distribution = None

if distribution == "uniform":
    # Uniform over [0, rendimento_max)
    rendimentos = np.random.uniform(0, rendimento_max, N)
elif distribution == "normal":
    # Single Gaussian centred on rendimento_medio
    rendimentos = np.random.normal(rendimento_medio, rendimento_stdev, N)
elif distribution == "binormal":
    # Two equally sized Gaussians, centred on rendimento_medio and on
    # rendimento_max, concatenated into one sample
    rendimentos_1 = np.random.normal(rendimento_medio, rendimento_stdev, N // 2)
    rendimentos_2 = np.random.normal(rendimento_max, rendimento_stdev, N // 2)
    rendimentos = np.concatenate((rendimentos_1, rendimentos_2))
else:
    # Usage errors go to stderr; sys.exit is used instead of the
    # interactive-only `exit` helper injected by the site module.
    print(f"Usage: python3 {sys.argv[0]} [distribution]", file=sys.stderr)
    print(
        "distribution: one of uniform, normal or binormal (defaults to uniform)",
        file=sys.stderr,
    )
    sys.exit(1)
def split_treshold(samples: NDArray, treshold: float) -> tuple[NDArray, NDArray]:
    """Partition ``samples`` into two arrays around ``treshold``.

    Args:
        samples: Values to partition. Any array-like is accepted; it is
            converted with ``np.asarray`` before masking.
        treshold: Cut-off value separating the two groups.

    Returns:
        A pair ``(low, high)``: ``low`` holds every element ``<= treshold``
        and ``high`` every element ``> treshold``, each in original order.
        Elements for which both comparisons are false (NaN) appear in
        neither array.
    """
    # asarray is a no-op for ndarrays and lets plain sequences be masked too
    samples = np.asarray(samples)
    return (samples[samples <= treshold], samples[samples > treshold])
def get_means(split_arrays: tuple[NDArray, NDArray]) -> tuple[float, float, float]:
    """Return the mean of each group and the gap between the two means.

    Args:
        split_arrays: Pair ``(low, high)`` of arrays, as produced by
            splitting a sample at a threshold.

    Returns:
        A triple ``(low_mean, high_mean, high_mean - low_mean)``. The mean
        of an empty group is NaN, which also makes the difference NaN.
    """
    low, high = split_arrays[0], split_arrays[1]
    # Guard empty groups explicitly: np.mean of an empty array also gives
    # NaN, but only after emitting a RuntimeWarning.
    low_mean = np.mean(low) if np.size(low) else np.nan
    high_mean = np.mean(high) if np.size(high) else np.nan
    # Each mean is computed once and reused for the difference
    return low_mean, high_mean, high_mean - low_mean
# Candidate thresholds: from the smallest sample towards the largest, in steps of 10
tresholds = np.arange(np.amin(rendimentos), np.amax(rendimentos), 10)
means = np.zeros((len(tresholds), 2))
diffs = np.zeros((len(tresholds)))

# For every candidate, split the sample and record both group means and their gap
for row, cutoff in enumerate(tresholds):
    groups = split_treshold(rendimentos, cutoff)
    low_mean, high_mean, gap = get_means(groups)
    means[row, 0] = low_mean
    means[row, 1] = high_mean
    diffs[row] = gap

# Top panel: histogram of the sample; bottom panel: the three curves vs. threshold
fig, axs = plt.subplots(2, 1)
histogram_ax, curves_ax = axs[0], axs[1]

histogram_ax.hist(rendimentos, bins=100)

# Plot order is kept because it determines line colours and legend order
curves = (
    (diffs, "Diferença entre médias"),
    (means[:, 0], "Média do grupo baixo"),
    (means[:, 1], "Média do grupo alto"),
)
for values, label in curves:
    curves_ax.plot(tresholds, values, label=label)

histogram_ax.set_title("Distribuição dos rendimentos")
curves_ax.set_title("Médias dos grupos")
curves_ax.set_xlabel("Treshold de classificação")
curves_ax.legend()

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment