Created November 13, 2020 20:29
Save koenvo/3d8a99949e5131a69fceed7664ceeb5f to your computer and use it in GitHub Desktop.
Pass chain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict | |
import streamlit as st | |
from kloppy import datasets, event_pattern_matching as pm | |
from mplsoccer.pitch import Pitch | |
@st.cache(allow_output_mutation=True)
def load_dataset(match_id):
    """Load one StatsBomb match, restricted to pass and shot events.

    The result is memoised by Streamlit's ``st.cache``;
    ``allow_output_mutation=True`` tells Streamlit not to warn if the
    returned dataset object is mutated by the caller.

    Args:
        match_id: StatsBomb match identifier forwarded to ``datasets.load``.

    Returns:
        Whatever ``kloppy.datasets.load`` returns for this match.
    """
    loader_options = {"event_types": ["pass", "shot"]}
    return datasets.load("statsbomb", options=loader_options, match_id=match_id)
def get_results(dataset, pass_count):
    """Find pass chains of length *pass_count* that end in a shot.

    The searched pattern is:

    1. one successful pass, captured as ``first_touch``;
    2. a successful pass by ``first_touch``'s team, repeated
       ``pass_count - 1`` times;
    3. a shot by that same team.

    NOTE(review): behaviour for ``pass_count < 2`` depends on how kloppy
    treats ``pattern * 0`` / negative repeats — unverified.

    Args:
        dataset: Event dataset to search (as produced by ``load_dataset``).
        pass_count: Total number of passes in the chain, including the first.

    Returns:
        The result of ``pm.search`` — presumably the list of matches, each
        exposing the matched ``events``; confirm against kloppy.
    """
    opening_pass = pm.match_pass(
        success=True,
        capture="first_touch",
    )
    follow_up_passes = pm.match_pass(
        success=True,
        team=pm.same_as("first_touch.team"),
    ) * (pass_count - 1)
    closing_shot = pm.match_shot(team=pm.same_as("first_touch.team"))

    chain = opening_pass + follow_up_passes + closing_shot
    return pm.search(dataset, chain)
def get_passes(results, team):
    """Group the pass segments of *team*'s chains by the final event's player.

    For every match in *results* whose first event belongs to *team*, each
    pair of consecutive events yields one segment
    ``(x_start, y_start, x_end, y_end)``. The end point is the location of
    the following event (the next pass, or the closing shot), not a
    recorded pass end location.

    Args:
        results: Iterable of pattern-search matches. Each is expected to
            expose an ordered, sliceable ``events`` sequence whose items have
            ``team``, ``player`` and ``coordinates.x`` / ``coordinates.y``.
        team: Team to keep; compared with ``==`` against the first event's
            team.

    Returns:
        A ``defaultdict(list)`` keyed by the last event's ``player`` (the
        shooter, when the search pattern ends in a shot), whose values are
        lists of segment tuples in match order.
    """
    passes = defaultdict(list)
    for match in results:
        events = match.events
        if events[0].team != team:
            continue
        shooter = events[-1].player
        # Pair each event with its successor: every arrow starts where one
        # event happens and ends where the next one does.
        for start, end in zip(events, events[1:]):
            passes[shooter].append((
                start.coordinates.x,
                start.coordinates.y,
                end.coordinates.x,
                end.coordinates.y,
            ))
    return passes
def main(match_id=15946):
    """Render the Streamlit pass-chain explorer.

    Sidebar widgets choose the chain length, the team and the shooter. The
    selected pass segments are then drawn as arrows on a StatsBomb pitch.

    Args:
        match_id: StatsBomb match to load. Defaults to 15946, the match this
            app was originally hard-coded to.
    """
    all_shooters = "-- all --"

    pass_count = st.sidebar.selectbox(
        "Number of passes",
        [2, 3, 4, 5]
    )
    dataset = load_dataset(match_id)
    results = get_results(dataset, pass_count)

    home_team, away_team = dataset.metadata.teams
    team_name = st.sidebar.selectbox(
        'Team',
        [str(home_team), str(away_team)]
    )
    team = home_team if str(home_team) == team_name else away_team

    passes_per_player = get_passes(results, team)
    # Keep the "all" option first regardless of how player names sort, and
    # list each distinct shooter name only once.
    shooter_options = [all_shooters] + sorted(
        {str(player) for player in passes_per_player}
    )
    shooter_name = st.sidebar.selectbox('Shooter', shooter_options)
    st.write(f"Shooter: {shooter_name}")

    selected_passes = []
    for shooter, passes in passes_per_player.items():
        if shooter_name == all_shooters or str(shooter) == shooter_name:
            selected_passes.extend(passes)

    pitch = Pitch(pitch_type='statsbomb', orientation='horizontal',
                  pitch_color='#22312b', line_color='#c7d5cc', figsize=(16, 11),
                  constrained_layout=True, tight_layout=False)
    fig, ax = pitch.draw()
    if selected_passes:
        x1, y1, x2, y2 = zip(*selected_passes)
        pitch.arrows(x1, y1, x2, y2, width=2,
                     headwidth=10, headlength=10, color='#ad993c', ax=ax,
                     label='completed passes')
    else:
        # zip(*[]) yields nothing, so unpacking it into four names would
        # raise ValueError; show the empty pitch with a note instead.
        st.write("No pass chains found for this selection.")
    st.pyplot(fig)
# Streamlit executes the script as __main__, so the app still starts under
# `streamlit run`; importing this module no longer launches the UI.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.