use std::fmt::Display;

use rand::Rng;

/// Number of dimensions of the search space.
const DIM: usize = 100;
/// A point (or velocity vector) in the `DIM`-dimensional search space.
type V = [f64; DIM];

/// A function to minimise together with the per-dimension box it is searched in.
struct ObjectiveFunctionStruct {
    /// Human-readable label for the function.
    name: String,
    /// The function being minimised; a lower return value is a better fitness.
    function: fn(&V) -> f64,
    /// Per-dimension lower bound of the box that initial positions are drawn from.
    lower_bound: V,
    /// Per-dimension upper bound of the box that initial positions are drawn from.
    upper_bound: V,
}

/// One member of the swarm: its current state plus the best point it has visited.
struct Particle {
    /// Current coordinates in the search space.
    position: V,
    /// Per-dimension step that is added to `position` on every update.
    velocity: V,
    /// Position at which this particle has recorded its lowest fitness so far.
    personal_best_position: V,
    /// Objective value at `personal_best_position` (lower is better).
    personal_best_fitness: f64,
}

impl Display for Particle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "(Position: {:?}, Velocity: {:?}, Personal Best Position: {:?}, Personal Best Fitness: {})", self.position, self.velocity, self.personal_best_position, self.personal_best_fitness)
    }
}

/// Particle swarm optimiser state: coefficients, particles, the best point found, and the RNG.
struct Swarm {
    /// Inertia weight: scales each particle's previous velocity in the update.
    w: f64,
    /// Cognitive coefficient: weight of the pull toward a particle's personal best.
    c1: f64,
    /// Social coefficient: weight of the pull toward the swarm's global best.
    c2: f64,
    /// All particles in the swarm.
    particles: Vec<Particle>,
    /// Best position seen by any particle so far.
    global_best_position: V,
    /// Objective value at `global_best_position` (lower is better).
    global_best_fitness: f64,
    /// Function being minimised and its search bounds.
    objective_function_struct: ObjectiveFunctionStruct,
    /// Random source for initial positions and the per-dimension random factors.
    rng_thread: rand::rngs::ThreadRng,
}

impl Swarm {
    fn new(
        w: f64,
        c1: f64,
        c2: f64,
        swarm_size: usize,
        objective_function_struct: ObjectiveFunctionStruct,
    ) -> Swarm {
        let mut particles = Vec::new();
        let mut rng_thread = rand::thread_rng();
        let mut global_best_position = [0.0; DIM];
        let mut global_best_fitness = std::f64::MAX;
        for _ in 0..swarm_size {
            let mut particle: Particle = Particle {
                position: std::array::from_fn(|i| {
                    rng_thread.gen_range(
                        objective_function_struct.lower_bound[i]
                            ..objective_function_struct.upper_bound[i],
                    )
                }),
                velocity: [0.0; DIM],
                personal_best_position: [0.0; DIM],
                personal_best_fitness: std::f64::MAX,
            };
            particle.personal_best_position = particle.position;
            particle.personal_best_fitness =
                (objective_function_struct.function)(&particle.position);
            if particle.personal_best_fitness < global_best_fitness {
                global_best_fitness = particle.personal_best_fitness;
                global_best_position = particle.personal_best_position;
            }
            particles.push(particle);
        }
        Swarm {
            w,
            c1,
            c2,
            particles,
            global_best_position,
            global_best_fitness,
            objective_function_struct,
            rng_thread,
        }
    }

    fn update_particles(&mut self) {
        for particle in &mut self.particles {
            for i in 0..DIM {
                particle.velocity[i] = self.w * particle.velocity[i]
                    + self.c1
                        * self.rng_thread.gen_range(0.0..1.0)
                        * (particle.personal_best_position[i] - particle.position[i])
                    + self.c2
                        * self.rng_thread.gen_range(0.0..1.0)
                        * (self.global_best_position[i] - particle.position[i]);
                particle.position[i] += particle.velocity[i];
            }
            let fitness: f64 = (self.objective_function_struct.function)(&particle.position);
            if fitness < self.global_best_fitness {
                particle.personal_best_fitness = fitness;
                particle.personal_best_position = particle.position;
                self.global_best_fitness = fitness;
                self.global_best_position = particle.position;
            } else if fitness < particle.personal_best_fitness {
                particle.personal_best_fitness = fitness;
                particle.personal_best_position = particle.position;
            }
        }
    }

    fn run(&mut self, iterations: usize) {
        for _ in 0..iterations {
            self.update_particles();
        }
    }

    fn print(&self) {
        println!("Global Best Position: {:?}", self.global_best_position);
        println!("Global Best Fitness: {}", self.global_best_fitness);
    }
}

/// Sphere benchmark: the sum of squared coordinates, which is minimal (0) at the origin.
fn sphere_function(x: &V) -> f64 {
    let mut total = 0.0;
    for &coord in x.iter() {
        total += coord.powi(2);
    }
    total
}

fn main() {
    use std::time::Instant;
    let now = Instant::now();
    let objective_function_struct: ObjectiveFunctionStruct = ObjectiveFunctionStruct {
        name: "Sphere Function".to_string(),
        function: sphere_function,
        lower_bound: [-5.12; DIM],
        upper_bound: [5.12; DIM],
    };
    let mut swarm: Swarm = Swarm::new(0.729, 1.49445, 1.49445, 1000, objective_function_struct);
    swarm.run(10000);
    swarm.print();
    let elapsed = now.elapsed();
    println!("Elapsed: {} ms", elapsed.as_millis());
}