Skip to content

Instantly share code, notes, and snippets.

@gabraganca
Created August 23, 2017 15:34
Show Gist options
  • Save gabraganca/d788429b97c206f7ab84a3077aa123ca to your computer and use it in GitHub Desktop.
Save gabraganca/d788429b97c206f7ab84a3077aa123ca to your computer and use it in GitHub Desktop.
Script to make a scatter plot of two variables and color code the dots according to a third variable.
#!/usr/bin/env python
"""
Script to make a scatter plot of two variables and color code the dots
according to a third variable
author: Gustavo Bragança
e-mail: ga.braganca at gmail dot com
"""
# import needed libraries
import os
import argparse
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def load_data(filename):
    """Read the first three columns of a CSV file into a DataFrame.

    Parameters
    ----------
    filename : str
        Path of the CSV file to read. It is converted to an absolute path
        before being handed to pandas.

    Returns
    -------
    df : pandas.core.frame.DataFrame
        The first three columns of the file.
    """
    absolute_path = os.path.abspath(filename)
    # Only the three leading columns are used; any further ones are ignored.
    leading_columns = range(3)
    return pd.read_csv(absolute_path, usecols=leading_columns)
def plot(dataframe, savename=None):
    """Create a scatter plot of two variables color coded by a third.

    The first column of `dataframe` goes on the x axis, the second on the
    y axis, and the third selects the color of each point. The resulting
    figure is written to `savename`.

    Parameters
    ----------
    dataframe : pandas.core.frame.DataFrame
        Data with exactly three columns, in the order x, y, hue. Any other
        number of columns makes the column unpacking below fail.
    savename : str, optional
        Path of the output file. Defaults to ``'plot.pdf'``. When the file
        name carries no extension, ``'.pdf'`` is appended.
    """
    # The column names double as axis labels and as the legend title.
    x_name, y_name, z_name = dataframe.columns

    # x and y are passed by keyword: seaborn 0.12 and later reject them as
    # positional arguments, while older releases accept the keywords too.
    # The legend is added manually because the one FacetGrid draws outside
    # the axes was not being included in the saved file.
    sns.lmplot(x=x_name, y=y_name, hue=z_name, data=dataframe,
               fit_reg=False, legend=False)
    plt.legend(title=z_name)

    # Add the top and right spines back.
    sns.despine(right=False, top=False)

    if savename is None:
        savename = 'plot.pdf'
    elif not os.path.splitext(savename)[1]:
        # splitext only inspects the last path component, so a dot in a
        # directory name (e.g. './out/plot') is not mistaken for an
        # extension.
        savename += '.pdf'

    plt.savefig(savename)
if __name__ == "__main__":
    # Command-line interface: the module docstring is shown verbatim as the
    # description, hence the raw formatter.
    cli = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument('data', help='name of the CSV file')
    cli.add_argument('-s', '--savename', help='Name to save the plot.')
    options = cli.parse_args()

    # Load the CSV and draw the plot in one go.
    plot(load_data(options.data), savename=options.savename)
RA DEC Cluster
1 4 SPAM
2 2 SPAM
3 8 EGG
4 2 EGG
5 4 SPAM
6 1 EGG
7 9 SPAM
8 5 SPAM
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment