# Mesa spatial SIR model with visualization server.
#
# Originally published as GitHub gist dmasad/e080016635dc8ca914ec6439d97287ea
# (created November 4, 2021 16:05).
from mesa import Agent, Model | |
from mesa.time import RandomActivation | |
from mesa.space import MultiGrid | |
from mesa.datacollection import DataCollector | |
from mesa.visualization.ModularVisualization import ModularServer | |
from mesa.visualization.modules import CanvasGrid, ChartModule | |
from mesa.visualization.UserParam import UserSettableParameter | |
# Model | |
# =============================================================== | |
class SpatialSIRModel(Model):
    """Spatial SIR (Susceptible / Infected / Recovered) epidemic model.

    ``n_agents`` agents are scattered at random on a toroidal ``MultiGrid``
    and activated in random order (``RandomActivation``) each step. The run
    ends (``running`` is set to False) once no agent is infected.
    """

    def __init__(self, n_agents, initial_infected=1, infection_radius=1,
                 width=20, height=20):
        """Build the grid, the agents and the data collector.

        Args:
            n_agents: Number of agents to create.
            initial_infected: How many agents start out infected. Values
                larger than ``n_agents`` infect every agent; values <= 0
                infect none.
            infection_radius: Radius (in cells) of the neighbourhood an
                infected agent searches when it tries to infect others.
            width: Grid width in cells (defaults to the original 20).
            height: Grid height in cells (defaults to the original 20).
        """
        # Let Mesa's Model set up its own base-class bookkeeping; later Mesa
        # releases rely on this call having been made.
        super().__init__()
        self.schedule = RandomActivation(self)
        self.grid = MultiGrid(width, height, True)  # True -> toroidal grid
        self.infection_radius = infection_radius
        self.datacollector = DataCollector(
            model_reporters={
                "Susceptible": 'susceptible',
                "Infected": 'infected',
                "Recovered": 'recovered'},
            agent_reporters={"coordinates": "pos",
                             "status": "status"})

        # Create agents and place each one on a uniformly random cell.
        for i in range(n_agents):
            agent = SpatialSIRAgent(i, self)
            self.schedule.add(agent)
            x = self.random.randrange(self.grid.width)
            y = self.random.randrange(self.grid.height)
            self.grid.place_agent(agent, (x, y))

        # Seed the epidemic with the first agents added. Clamp to the
        # population size so an oversized request cannot raise IndexError.
        agents = self.agents
        for i in range(min(initial_infected, len(agents))):
            agents[i].status = "Infected"

        self.running = True
        # Record the initial state so the time series starts at step 0.
        self.datacollector.collect(self)

    def _count(self, status):
        """Return how many agents currently have the given ``status``."""
        return sum(1 for agent in self.agents if agent.status == status)

    @property
    def susceptible(self):
        """Number of agents whose status is "Susceptible"."""
        return self._count('Susceptible')

    @property
    def infected(self):
        """Number of agents whose status is "Infected"."""
        return self._count('Infected')

    @property
    def recovered(self):
        """Number of agents whose status is "Recovered"."""
        return self._count('Recovered')

    @property
    def agents(self):
        """The agents registered with the scheduler."""
        return self.schedule.agents

    def step(self):
        """Advance one tick: activate every agent, record state, check for end."""
        self.schedule.step()
        # Collect *after* the agents act so the recorded row (and the live
        # chart, which shows the latest row) matches the grid just produced,
        # including the final state of the run.
        self.datacollector.collect(self)
        if self.infected == 0:
            self.running = False
class SpatialSIRAgent(Agent):
    """A single individual in the spatial SIR model.

    ``status`` is one of "Susceptible", "Infected" or "Recovered".
    """

    # Per-step probabilities used by ``step``. They are class attributes, so a
    # subclass or a single instance can override them. The defaults are the
    # original hard-coded 0.25.
    infection_prob = 0.25
    recovery_prob = 0.25

    def __init__(self, unique_id, model):
        super().__init__(unique_id, model)
        self.status = "Susceptible"

    def move(self):
        """Move to a uniformly random cell of the Moore neighbourhood (current cell excluded)."""
        possible_steps = self.model.grid.get_neighborhood(
            self.pos,
            moore=True,
            include_center=False
        )
        new_pos = self.random.choice(possible_steps)
        self.model.grid.move_agent(self, new_pos)

    def infect(self):
        """Infect every susceptible agent within ``model.infection_radius``.

        The search includes agents sharing this agent's cell
        (``include_center=True``). Agents that are not "Susceptible",
        including this agent itself, are left unchanged.
        """
        exposed = self.model.grid.get_neighbors(
            self.pos,
            moore=True,
            include_center=True,
            radius=self.model.infection_radius
        )
        for agent in exposed:
            if agent.status == "Susceptible":
                # Takes effect immediately. An agent infected here that has
                # not yet acted this tick will act as infected when its turn comes.
                agent.status = "Infected"

    def step(self):
        """Move, then (if infected) possibly infect neighbours and possibly recover."""
        self.move()
        if self.status == "Infected":
            if self.random.random() < self.infection_prob:
                self.infect()
            if self.random.random() < self.recovery_prob:
                self.status = "Recovered"
# Server | |
# =============================================================== | |
# Display colour for each agent status. The canvas and the chart both use it.
COLORS = dict(Susceptible="blue", Infected="red", Recovered="green")


def sir_model_portrayal(cell):
    """Describe how CanvasGrid should draw one agent.

    Returns None when given None. Otherwise returns a filled circle at the
    agent's grid position, coloured by its status via ``COLORS``.
    """
    if cell is not None:
        px, py = cell.pos
        return {
            "Shape": "circle",
            "r": 0.9,
            "Filled": "true",
            "Layer": 0,
            "x": px,
            "y": py,
            "Color": COLORS[cell.status],
        }
    return None
if __name__ == "__main__":
    # A 20x20-cell grid, the size SpatialSIRModel builds, drawn on a
    # 500x500 px canvas.
    cells, pixels = 20, 500
    grid_element = CanvasGrid(sir_model_portrayal, cells, cells, pixels, pixels)

    # One chart series per SIR status, reusing the canvas colours.
    series = [dict(Label=label, Color=color) for label, color in COLORS.items()]
    chart_element = ChartModule(series)

    # Slider: initial value 1, min 1, max 10, step 1.
    radius_slider = UserSettableParameter("slider", "Infection Radius",
                                          1, 1, 10, 1)
    model_params = dict(n_agents=100,
                        initial_infected=10,
                        infection_radius=radius_slider)

    server = ModularServer(SpatialSIRModel, [grid_element, chart_element],
                           "Spatial SIR Model", model_params)
    server.launch()
# To discuss this code, sign in to GitHub and comment on the original gist.