-
-
Save treyhunner/1644c56401103136520ba1535967a735 to your computer and use it in GitHub Desktop.
Pythonic Code Refactoring Session
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def add(matrix1, matrix2): | |
"""Add corresponding numbers in given 2-D matrices.""" | |
combined = [] | |
for i in range(len(matrix1)): | |
row = [] | |
for j in range(len(matrix1[i])): | |
row.append(matrix1[i][j] + matrix2[i][j]) | |
combined.append(row) | |
return combined |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce | |
def all_same(iterable): | |
for item in iterable: | |
first_item = item | |
break | |
return reduce(lambda x, y: x and y == first_item, iterable, True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def all_same(iterable): | |
first_item = next(iter(iterable), None) | |
return all( | |
item == first_item | |
for item in iterable | |
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def count_words(string): | |
"""Return the number of times each word occurs in the string.""" | |
count = {} | |
for word in string.split(): | |
word = word.strip(',;.!?"()').lower() | |
if word in count: | |
count[word] += 1 | |
else: | |
count[word] = 1 | |
return count |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
old_filename = sys.argv[1] | |
new_filename = sys.argv[2] | |
old_file = open(old_filename) | |
rows = [ | |
line.split('|') | |
for line in old_file.read().splitlines() | |
] | |
new_file = open(new_filename, mode='wt', newline='\r\n') | |
print("\n".join( | |
",".join(row) | |
for row in rows | |
), file=new_file) | |
old_file.close() | |
new_file.close() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from argparse import ArgumentParser, FileType | |
import csv | |
parser = ArgumentParser() | |
parser.add_argument('old_file', type=FileType('rt')) | |
parser.add_argument('new_file', type=FileType('wt')) | |
args = parser.parse_args() | |
reader = csv.reader(args.old_file, delimiter='|') | |
writer = csv.writer(args.new_file, delimiter=',') | |
writer.writerows(reader) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_earliest(date1, date2): | |
mdy1 = date1.split('/') | |
mdy2 = date2.split('/') | |
if mdy1[2] < mdy2[2]: | |
return date1 | |
elif mdy1[2] > mdy2[2]: | |
return date2 | |
elif mdy1[0] < mdy2[0]: | |
return date1 | |
elif mdy1[0] > mdy2[0]: | |
return date2 | |
elif mdy1[1] < mdy2[1]: | |
return date1 | |
elif mdy1[1] > mdy2[1]: | |
return date2 | |
else: | |
return date1 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
def get_earliest(date1, date2): | |
d1 = datetime.strptime(date1, "%m/%d/%Y") | |
d2 = datetime.strptime(date2, "%m/%d/%Y") | |
return date1 if (d1 < d2) else date2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def interleave(*iterables): | |
"""Return iterable of one item at a time from each given iterable.""" | |
interleaved = [] | |
for i in range(len(iterables[0])): | |
for iterable in iterables: | |
interleaved.append(iterable[i]) | |
return interleaved |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def interleave(*iterables): | |
"""Return iterable of one item at a time from each given iterable.""" | |
return ( | |
item | |
for items in zip(*iterables) | |
for item in items | |
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_color_ratios(colors, ratios): | |
"""Return dictionary of color ratios from color and ratio lists.""" | |
assert len(colors) == len(ratios) | |
color_ratios = {} | |
for i in range(len(colors)): | |
color_ratios[colors[i]] = ratios[i] | |
return color_ratios | |
if __name__ == "__main__": | |
test_colors = ["red", "green", "blue"] | |
test_ratios = [0.1, 0.6, 0.3] | |
combined_dict = {'red': 0.1, 'green': 0.6, 'blue': 0.3} | |
assert get_color_ratios(test_colors, test_ratios) == combined_dict |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_color_ratios(colors, ratios): | |
"""Return dictionary of color ratios from color and ratio lists.""" | |
assert len(colors) == len(ratios) | |
return dict(zip(colors, ratios)) | |
if __name__ == "__main__": | |
test_colors = ["red", "green", "blue"] | |
test_ratios = [0.1, 0.6, 0.3] | |
combined_dict = {'red': 0.1, 'green': 0.6, 'blue': 0.3} | |
assert get_color_ratios(test_colors, test_ratios) == combined_dict | |
print("Tests passed") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import sys | |
filename = sys.argv[1] | |
column_numbers = sys.argv[2:] | |
numbers = [] | |
for n in column_numbers: | |
numbers.append(int(n)) | |
with open(filename) as csv_file: | |
reader = csv.reader(csv_file) | |
rows = [row for row in reader] | |
sorted_rows = sorted(rows, key=lambda row: [row[n] for n in numbers]) | |
writer = csv.writer(sys.stdout) | |
writer.writerows(sorted_rows) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sum_timestamps(timestamps): | |
total_time = 0 | |
for time in timestamps: | |
total_time += parse_time(time) | |
return format_time(total_time) | |
def parse_time(time_string): | |
sections = time_string.split(':') | |
if len(sections) == 2: | |
seconds = int(sections[1]) | |
minutes = int(sections[0]) | |
hours = 0 | |
else: | |
seconds = int(sections[2]) | |
minutes = int(sections[1]) | |
hours = int(sections[0]) | |
return hours*3600 + minutes*60 + seconds | |
def format_time(total_seconds): | |
hours = str(int(total_seconds / 3600)) | |
minutes = str(int(total_seconds / 60) % 60) | |
seconds = str(total_seconds % 60) | |
if len(minutes) < 2 and hours != "0": | |
minutes = "0" + minutes | |
if len(seconds) < 2: | |
seconds = "0" + seconds | |
time = minutes + ":" + seconds | |
if hours != "0": | |
time = hours + ":" + time | |
return time |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
TIME_RE = re.compile(r''' | |
^ | |
(?: # Optional hours | |
( \d+ ) | |
: | |
)? | |
( \d+ ) # Minutes | |
: | |
( \d+ ) # Seconds | |
$ | |
''', re.VERBOSE) | |
def sum_timestamps(timestamps): | |
total_time = sum( | |
parse_time(time) | |
for time in timestamps | |
) | |
return format_time(total_time) | |
def parse_time(time_string): | |
hours, minutes, seconds = TIME_RE.search(time_string).groups() | |
if hours is None: | |
hours = 0 | |
return int(hours)*3600 + int(minutes)*60 + int(seconds) | |
def format_time(total_seconds): | |
minutes, seconds = divmod(total_seconds, 60) | |
hours, minutes = divmod(minutes, 60) | |
if hours > 0: | |
return f"{hours}:{minutes:02d}:{seconds:02d}" | |
else: | |
return f"{minutes}:{seconds:02d}" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from copy import deepcopy | |
import unittest | |
from add import add | |
class AddTests(unittest.TestCase): | |
"""Tests for add.""" | |
def test_single_items(self): | |
self.assertEqual(add([[5]], [[-2]]), [[3]]) | |
def test_two_by_two_matrixes(self): | |
m1 = [[6, 6], [3, 1]] | |
m2 = [[1, 2], [3, 4]] | |
m3 = [[7, 8], [6, 5]] | |
self.assertEqual(add(m1, m2), m3) | |
def test_two_by_three_matrixes(self): | |
m1 = [[1, 2, 3], [4, 5, 6]] | |
m2 = [[-1, -2, -3], [-4, -5, -6]] | |
m3 = [[0, 0, 0], [0, 0, 0]] | |
self.assertEqual(add(m1, m2), m3) | |
def test_input_unchanged(self): | |
m1 = [[6, 6], [3, 1]] | |
m2 = [[1, 2], [3, 4]] | |
m1_original = deepcopy(m1) | |
m2_original = deepcopy(m2) | |
add(m1, m2) | |
self.assertEqual(m1, m1_original) | |
self.assertEqual(m2, m2_original) | |
if __name__ == "__main__": | |
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from all_same import all_same | |
class AllSameTests(unittest.TestCase): | |
"""Tests for all_same.""" | |
def test_one_item_number(self): | |
self.assertTrue(all_same([4])) | |
self.assertTrue(all_same([0])) | |
self.assertTrue(all_same([-1])) | |
def test_one_string(self): | |
self.assertTrue(all_same(['hello'])) | |
def test_one_none_value(self): | |
self.assertTrue(all_same([None])) | |
def test_one_tuple(self): | |
self.assertTrue(all_same([()])) | |
self.assertTrue(all_same([(1,)])) | |
self.assertTrue(all_same([(1, 2)])) | |
def test_empty_sequence(self): | |
self.assertTrue(all_same([])) | |
self.assertTrue(all_same(())) | |
self.assertTrue(all_same('')) | |
def test_two_same_item(self): | |
self.assertTrue(all_same([1, 1])) | |
self.assertTrue(all_same([0, 0])) | |
self.assertTrue(all_same(['hello', 'hello'])) | |
self.assertTrue(all_same([-1, -1])) | |
self.assertTrue(all_same([(1, 2), (1, 2)])) | |
self.assertTrue(all_same([None, None])) | |
def test_two_different_items(self): | |
self.assertFalse(all_same(['hello', 'hi'])) | |
self.assertFalse(all_same([-1, 1])) | |
self.assertFalse(all_same([-1, 'hi'])) | |
self.assertFalse(all_same([(1, 3), (1, 2)])) | |
self.assertFalse(all_same(['hello', (4, 5)])) | |
self.assertFalse(all_same([4, None])) | |
self.assertFalse(all_same([None, 4])) | |
def test_many_items(self): | |
self.assertTrue(all_same([1, 1, 1, 1, 1, 1])) | |
self.assertFalse(all_same([1, 1, 1, 1, 2, 1])) | |
self.assertFalse(all_same(['hi', 'hello', 'hey'])) | |
self.assertFalse(all_same(['hello', 'hella', 'hello'])) | |
self.assertTrue(all_same(['hi', 'hi', 'hi', 'hi', 'hi'])) | |
self.assertTrue(all_same(['hello', 'hello', 'hello'])) | |
self.assertTrue(all_same([(1, 2, 3), (1, 2, 3), (1, 2, 3)])) | |
self.assertFalse(all_same([(1, 2, 3), (1, 2, 3), (1, 4, 3)])) | |
def test_nonhashable_values(self): | |
self.assertFalse(all_same([['hi', 'hi'], ['hi', 'hi', 'hi']])) | |
self.assertTrue(all_same([['hi', 'hi'], ['hi', 'hi']])) | |
self.assertTrue(all_same([{1: 2}, {1: 2}])) | |
self.assertFalse(all_same([{1: 2}, {1: 3}])) | |
def test_nonsequences(self): | |
numbers = [1, 3, 5, 7, 9] | |
self.assertTrue(all_same({1})) | |
self.assertFalse(all_same({1, 2})) | |
self.assertFalse(all_same(n**2 for n in numbers)) | |
self.assertTrue(all_same(n % 2 for n in numbers)) | |
@unittest.expectedFailure | |
def test_return_early(self): | |
self.assertFalse(all_same(n**2 for n in [2, 3, {}])) | |
from itertools import count | |
self.assertFalse(all_same(count())) | |
if __name__ == "__main__": | |
unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from count_words import count_words | |
class CountWordsTests(unittest.TestCase): | |
"""Tests for count_words.""" | |
def test_simple_sentence(self): | |
actual = count_words("oh what a day what a lovely day") | |
expected = {'oh': 1, 'what': 2, 'a': 2, 'day': 2, 'lovely': 1} | |
self.assertEqual(actual, expected) | |
def test_apostrophe(self): | |
actual = count_words("don't stop believing") | |
expected = {"don't": 1, 'stop': 1, 'believing': 1} | |
self.assertEqual(actual, expected) | |
def test_capitalization(self): | |
actual = count_words("Oh what a day what a lovely day") | |
expected = {'oh': 1, 'what': 2, 'a': 2, 'day': 2, 'lovely': 1} | |
self.assertEqual(actual, expected) | |
def test_symbols(self): | |
actual = count_words("Oh what a day, what a lovely day!") | |
expected = {'oh': 1, 'what': 2, 'a': 2, 'day': 2, 'lovely': 1} | |
self.assertEqual(actual, expected) | |
if __name__ == "__main__": | |
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager, redirect_stdout | |
from io import StringIO | |
import imp | |
import os | |
import sys | |
from textwrap import dedent | |
from tempfile import NamedTemporaryFile | |
import unittest | |
class FixCSVTests(unittest.TestCase): | |
"""Tests for fix_csv.py""" | |
maxDiff = None | |
def test_pipe_file_to_csv_file(self): | |
old_contents = dedent(""" | |
2012|Lexus|LFA | |
2009|GMC|Yukon XL 1500 | |
1965|Ford|Mustang | |
2005|Hyundai|Sonata | |
1995|Mercedes-Benz|C-Class | |
""").lstrip() | |
expected = dedent(""" | |
2012,Lexus,LFA | |
2009,GMC,Yukon XL 1500 | |
1965,Ford,Mustang | |
2005,Hyundai,Sonata | |
1995,Mercedes-Benz,C-Class | |
""").lstrip() | |
with make_file(old_contents) as old, make_file("") as new: | |
output = run_program('fix_csv.py', args=[old, new]) | |
with open(new) as new_file: | |
new_contents = new_file.read() | |
self.assertEqual(expected, new_contents) | |
self.assertEqual("", output) | |
def test_original_file_is_unchanged(self): | |
old_contents = dedent(""" | |
2012|Lexus|LFA | |
2009|GMC|Yukon XL 1500 | |
""").lstrip() | |
with make_file(old_contents) as old, make_file("") as new: | |
run_program('fix_csv.py', args=[old, new]) | |
with open(old) as old_file: | |
contents = old_file.read() | |
self.assertEqual(old_contents, contents) | |
@unittest.expectedFailure | |
def test_delimiter_in_output(self): | |
old_contents = dedent(""" | |
02|Waylon Jennings|Honky Tonk Heroes (Like Me) | |
04|Kris Kristofferson|To Beat The Devil | |
11|Johnny Cash|Folsom Prison Blues | |
13|Billy Joe Shaver|Low Down Freedom | |
21|Hank Williams III|Mississippi Mud | |
22|David Allan Coe|Willie, Waylon, And Me | |
24|Bob Dylan|House Of The Risin' Sun | |
""").lstrip() | |
expected = dedent(""" | |
02,Waylon Jennings,Honky Tonk Heroes (Like Me) | |
04,Kris Kristofferson,To Beat The Devil | |
11,Johnny Cash,Folsom Prison Blues | |
13,Billy Joe Shaver,Low Down Freedom | |
21,Hank Williams III,Mississippi Mud | |
22,David Allan Coe,"Willie, Waylon, And Me" | |
24,Bob Dylan,House Of The Risin' Sun | |
""").lstrip() | |
with make_file(old_contents) as old, make_file("") as new: | |
output = run_program('fix_csv.py', args=[old, new]) | |
with open(new) as new_file: | |
new_contents = new_file.read() | |
self.assertEqual(expected, new_contents) | |
self.assertEqual("", output) | |
@unittest.expectedFailure | |
def test_call_with_too_many_files(self): | |
with make_file("") as old, make_file("") as new: | |
with self.assertRaises(BaseException): | |
run_program('fix_csv.py', args=[old, new, old]) | |
def run_program(path, args=[]): | |
""" | |
Run program at given path with given arguments. | |
If raises is specified, ensure the given exception is raised. | |
""" | |
old_args = sys.argv | |
assert all(isinstance(a, str) for a in args) | |
try: | |
sys.argv = [path] + args | |
with redirect_stdout(StringIO()) as output: | |
try: | |
if '__main__' in sys.modules: | |
del sys.modules['__main__'] | |
imp.load_source('__main__', path) | |
except SystemExit as e: | |
if e.args != (0,): | |
raise | |
del sys.modules['__main__'] | |
return output.getvalue() | |
finally: | |
sys.argv = old_args | |
@contextmanager | |
def make_file(contents=None): | |
"""Context manager providing name of a file containing given contents.""" | |
with NamedTemporaryFile(mode='wt', encoding='utf-8', delete=False) as f: | |
if contents: | |
f.write(contents) | |
try: | |
yield f.name | |
finally: | |
os.remove(f.name) | |
if __name__ == "__main__": | |
unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from get_earliest import get_earliest | |
class GetEarliestTests(unittest.TestCase): | |
"""Tests for get_earliest.""" | |
def test_same_month_and_day(self): | |
newer = "01/27/1832" | |
older = "01/27/1756" | |
self.assertEqual(get_earliest(newer, older), older) | |
def test_february_29th(self): | |
newer = "02/29/1972" | |
older = "12/21/1946" | |
self.assertEqual(get_earliest(newer, older), older) | |
def test_smaller_month_bigger_day(self): | |
newer = "03/21/1946" | |
older = "02/24/1946" | |
self.assertEqual(get_earliest(older, newer), older) | |
def test_same_month_and_year(self): | |
newer = "06/24/1958" | |
older = "06/21/1958" | |
self.assertEqual(get_earliest(older, newer), older) | |
if __name__ == "__main__": | |
unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import count | |
import unittest | |
from interleave import interleave | |
class InterleaveTests(unittest.TestCase): | |
"""Tests for interleave.""" | |
def assertIterableEqual(self, iterable1, iterable2): | |
self.assertEqual(list(iterable1), list(iterable2)) | |
def test_empty_lists(self): | |
self.assertIterableEqual(interleave([], []), []) | |
def test_single_item_each(self): | |
self.assertIterableEqual(interleave([1], [2]), [1, 2]) | |
def test_two_items_each(self): | |
self.assertIterableEqual(interleave([1, 2], [3, 4]), [1, 3, 2, 4]) | |
def test_four_items_each(self): | |
in1 = [1, 2, 3, 4] | |
in2 = [5, 6, 7, 8] | |
out = [1, 5, 2, 6, 3, 7, 4, 8] | |
self.assertIterableEqual(interleave(in1, in2), out) | |
def test_none_value(self): | |
in1 = [1, 2, 3, None] | |
in2 = [4, 5, 6, 7] | |
out = [1, 4, 2, 5, 3, 6, None, 7] | |
self.assertIterableEqual(interleave(in1, in2), out) | |
# To test the Bonus part of this exercise, comment out the following line | |
@unittest.expectedFailure | |
def test_non_sequences(self): | |
in1 = [1, 2, 3, 4] | |
in2 = (n**2 for n in in1) | |
out = [1, 1, 2, 4, 3, 9, 4, 16] | |
self.assertIterableEqual(interleave(in1, in2), out) | |
if __name__ == "__main__": | |
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from contextlib import contextmanager, redirect_stdout | |
from io import StringIO | |
import imp | |
import os | |
import sys | |
from textwrap import dedent | |
from tempfile import NamedTemporaryFile | |
import unittest | |
class SortByColumnTests(unittest.TestCase): | |
"""Tests for sort_by_column.py""" | |
maxDiff = None | |
def test_sort_by_first_column(self): | |
contents = dedent(""" | |
2012,Lexus,LFA | |
2009,GMC,Yukon XL 1500 | |
1965,Ford,Mustang | |
2005,Hyundai,Sonata | |
1995,Mercedes-Benz,C-Class | |
""").lstrip() | |
expected = dedent(""" | |
1965,Ford,Mustang | |
1995,Mercedes-Benz,C-Class | |
2005,Hyundai,Sonata | |
2009,GMC,Yukon XL 1500 | |
2012,Lexus,LFA | |
""").lstrip().replace('\n', '\r\n') | |
with make_file(contents) as csv_file: | |
output = run_program('sort_by_column.py', args=[csv_file, '0']) | |
self.assertEqual(expected, output) | |
def test_sort_by_second_column(self): | |
contents = dedent(""" | |
2012,Lexus,LFA | |
2009,GMC,Yukon XL 1500 | |
1965,Ford,Mustang | |
2005,Hyundai,Sonata | |
1995,Mercedes-Benz,C-Class | |
""").lstrip() | |
expected = dedent(""" | |
1965,Ford,Mustang | |
2009,GMC,Yukon XL 1500 | |
2005,Hyundai,Sonata | |
2012,Lexus,LFA | |
1995,Mercedes-Benz,C-Class | |
""").lstrip().replace('\n', '\r\n') | |
with make_file(contents) as csv_file: | |
output = run_program('sort_by_column.py', args=[csv_file, '1']) | |
self.assertEqual(expected, output) | |
def test_original_file_is_unchanged(self): | |
old_contents = dedent(""" | |
2012,Lexus,LFA | |
2009,GMC,Yukon XL 1500 | |
""").lstrip() | |
with make_file(old_contents) as filename: | |
run_program('sort_by_column.py', args=[filename, '0']) | |
with open(filename) as csv_file: | |
new_contents = csv_file.read() | |
self.assertEqual(old_contents, new_contents) | |
def test_sorting_with_commas(self): | |
contents = dedent(""" | |
"Hughes, John",Baby's Day Out | |
"Hughes, John",The Breakfast Club | |
"Hughes, Langston",A Dream Deferred | |
"Hughes, Langston",Dreams | |
""").lstrip().replace('\n', '\r\n') | |
expected = dedent(""" | |
"Hughes, Langston",A Dream Deferred | |
"Hughes, John",Baby's Day Out | |
"Hughes, Langston",Dreams | |
"Hughes, John",The Breakfast Club | |
""").lstrip().replace('\n', '\r\n') | |
with make_file(contents) as old: | |
output = run_program('sort_by_column.py', args=[old, '1']) | |
self.assertEqual(expected, output) | |
def test_sort_by_one_column_only(self): | |
contents = dedent(""" | |
11,Johnny Cash,Folsom Prison Blues | |
13,Billy Joe Shaver,Low Down Freedom | |
2,Waylon Jennings,Honky Tonk Heroes (Like Me) | |
2,Hank Williams III,Mississippi Mud | |
4,Kris Kristofferson,To Beat The Devil | |
22,David Allan Coe,"Willie, Waylon, And Me" | |
4,Bob Dylan,House Of The Risin' Sun | |
""").lstrip().replace('\n', '\r\n') | |
expected = dedent(""" | |
11,Johnny Cash,Folsom Prison Blues | |
13,Billy Joe Shaver,Low Down Freedom | |
2,Waylon Jennings,Honky Tonk Heroes (Like Me) | |
2,Hank Williams III,Mississippi Mud | |
22,David Allan Coe,"Willie, Waylon, And Me" | |
4,Kris Kristofferson,To Beat The Devil | |
4,Bob Dylan,House Of The Risin' Sun | |
""").lstrip().replace('\n', '\r\n') | |
with make_file(contents) as old: | |
output = run_program('sort_by_column.py', args=[old, '0']) | |
self.assertEqual(expected, output) | |
def test_sort_by_multiple_columns(self): | |
contents = dedent(""" | |
2005,Lexus,LFA | |
2009,GMC,Yukon XL 1500 | |
1995,Ford,Mustang | |
2005,Hyundai,Sonata | |
1995,Mercedes-Benz,C-Class | |
""").lstrip() | |
expected = dedent(""" | |
1995,Mercedes-Benz,C-Class | |
1995,Ford,Mustang | |
2005,Lexus,LFA | |
2005,Hyundai,Sonata | |
2009,GMC,Yukon XL 1500 | |
""").lstrip().replace('\n', '\r\n') | |
with make_file(contents) as name: | |
output = run_program('sort_by_column.py', args=[name, '0', '2']) | |
self.assertEqual(expected, output) | |
def run_program(path, args=[]): | |
""" | |
Run program at given path with given arguments. | |
If raises is specified, ensure the given exception is raised. | |
""" | |
old_args = sys.argv | |
assert all(isinstance(a, str) for a in args) | |
try: | |
sys.argv = [path] + args | |
with redirect_stdout(StringIO()) as output: | |
try: | |
if '__main__' in sys.modules: | |
del sys.modules['__main__'] | |
imp.load_source('__main__', path) | |
except SystemExit as e: | |
if e.args != (0,): | |
raise | |
return output.getvalue() | |
finally: | |
sys.argv = old_args | |
@contextmanager | |
def make_file(contents=None): | |
"""Context manager providing name of a file containing given contents.""" | |
with NamedTemporaryFile(mode='wt', encoding='utf-8', delete=False) as f: | |
if contents: | |
f.write(contents) | |
try: | |
yield f.name | |
finally: | |
os.remove(f.name) | |
if __name__ == "__main__": | |
unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from sum_timestamps import sum_timestamps | |
class SumTimeStampsTests(unittest.TestCase): | |
"""Tests for sum_timestamps.""" | |
def test_single_timestamp(self): | |
self.assertEqual(sum_timestamps(['02:01']), '2:01') | |
self.assertEqual(sum_timestamps(['2:01']), '2:01') | |
def test_multiple_timestamps(self): | |
self.assertEqual(sum_timestamps(['02:01', '04:05']), '6:06') | |
self.assertEqual(sum_timestamps(['9:38', '4:45', '3:52']), '18:15') | |
def test_many_timestamps(self): | |
times = [ | |
'3:52', '3:29', '3:23', '4:05', '3:24', '2:29', '2:16', '2:44', | |
'1:58', '3:21', '2:51', '2:53', '2:51', '3:32', '3:20', '2:40', | |
'2:50', '3:24', '3:22', '0:42'] | |
self.assertEqual(sum_timestamps(times), '59:26') | |
def test_no_minutes(self): | |
self.assertEqual(sum_timestamps(['00:01', '00:05']), '0:06') | |
self.assertEqual(sum_timestamps(['0:38', '0:15']), '0:53') | |
# To test the Bonus part of this exercise, comment out the following line | |
@unittest.expectedFailure | |
def test_timestamps_over_an_hour(self): | |
times = [ | |
'3:52', '3:29', '3:23', '4:05', '3:24', '2:29', '2:16', '2:44', | |
'1:58', '3:21', '2:51', '2:53', '2:51', '3:32', '3:20', '2:40', | |
'2:50', '3:24', '1:20', '3:22', '3:26', '0:42', '5:20'] | |
self.assertEqual(sum_timestamps(times), '1:09:32') | |
times2 = [ | |
'50:52', '34:29', '36:23', '47:05', '32:24', '20:29', '22:16', | |
'23:44', '19:58', '30:21', '24:51', '22:53', '23:51', '34:32', | |
'36:20', '25:40', '27:50', '39:24', '18:20', '36:22', '4:00', | |
] | |
self.assertEqual(sum_timestamps(times2), '10:12:04') | |
# To test the Bonus part of this exercise, comment out the following line | |
@unittest.expectedFailure | |
def test_allow_optional_hour(self): | |
self.assertEqual(sum_timestamps(['1:02:01', '04:05']), '1:06:06') | |
self.assertEqual( | |
sum_timestamps(['9:05:00', '4:45:10', '3:52']), | |
'13:54:02', | |
) | |
if __name__ == "__main__": | |
unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
from vector import Vector | |
class VectorTests(unittest.TestCase): | |
"""Tests for Vector.""" | |
def test_attributes(self): | |
v = Vector(1, 2, 3) | |
self.assertEqual((v.x, v.y, v.z), (1, 2, 3)) | |
def test_equality_and_inequality(self): | |
self.assertNotEqual(Vector(1, 2, 3), Vector(1, 2, 4)) | |
self.assertEqual(Vector(1, 2, 3), Vector(1, 2, 3)) | |
self.assertFalse(Vector(1, 2, 3) != Vector(1, 2, 3)) | |
v1 = Vector(1, 2, 3) | |
v2 = Vector(1, 2, 4) | |
v3 = Vector(1, 2, 3) | |
self.assertNotEqual(v1, v2) | |
self.assertEqual(v1, v3) | |
def test_iterable_vector(self): | |
x, y, z = Vector(x=1, y=2, z=3) | |
self.assertEqual((x, y, z), (1, 2, 3)) | |
def test_no_weird_extras(self): | |
v1 = Vector(1, 2, 3) | |
v2 = Vector(4, 5, 6) | |
with self.assertRaises(TypeError): | |
len(v1) | |
with self.assertRaises(TypeError): | |
v1 < v2 | |
with self.assertRaises(TypeError): | |
v1 > v2 | |
with self.assertRaises(TypeError): | |
v1 <= v2 | |
with self.assertRaises(TypeError): | |
v1 >= v2 | |
with self.assertRaises(TypeError): | |
v1 + (1, 2, 3) | |
with self.assertRaises(TypeError): | |
(1, 2, 3) + v1 | |
with self.assertRaises(TypeError): | |
v1 - (1, 2, 3) | |
with self.assertRaises(TypeError): | |
v1 * 'a' | |
with self.assertRaises(TypeError): | |
v1 / v2 | |
def test_memory_efficient_attributes(self): | |
v = Vector(1, 2, 3) | |
with self.assertRaises(AttributeError): | |
v.a = 3 | |
with self.assertRaises(AttributeError): | |
v.__dict__ | |
def test_shifting(self): | |
v1 = Vector(1, 2, 3) | |
v2 = Vector(4, 5, 6) | |
v3 = v2 + v1 | |
v4 = v3 - v1 | |
self.assertEqual((v3.x, v3.y, v3.z), (5, 7, 9)) | |
self.assertEqual((v4.x, v4.y, v4.z), (v2.x, v2.y, v2.z)) | |
def test_scaling(self): | |
v1 = Vector(1, 2, 3) | |
v2 = Vector(4, 5, 6) | |
v3 = v1 * 4 | |
v4 = 2 * v2 | |
self.assertEqual((v3.x, v3.y, v3.z), (4, 8, 12)) | |
self.assertEqual((v4.x, v4.y, v4.z), (8, 10, 12)) | |
if __name__ == "__main__": | |
unittest.main(verbosity=2) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Vector: | |
__slots__ = 'x', 'y', 'z' | |
def __init__(self, x, y, z): | |
self.x = x | |
self.y = y | |
self.z = z | |
def __iter__(self): | |
return iter([self.x, self.y, self.z]) | |
def __eq__(self, other): | |
return self.x == other.x and self.y == other.y and self.z == other.z | |
def __ne__(self, other): | |
return self.x != other.x and self.y != other.y and self.z != other.z | |
def __add__(self, other): | |
if not isinstance(other, Vector): | |
return NotImplemented | |
return Vector(self.x + other.x, self.y + other.y, self.z + other.z) | |
def __sub__(self, other): | |
if not isinstance(other, Vector): | |
return NotImplemented | |
return Vector(self.x - other.x, self.y - other.y, self.z - other.z) | |
def __mul__(self, scalar): | |
if not isinstance(scalar, (int, float)): | |
return NotImplemented | |
return Vector(self.x * scalar, self.y * scalar, self.z * scalar) | |
def __rmul__(self, scalar): | |
if not isinstance(scalar, (int, float)): | |
return NotImplemented | |
return Vector(self.x * scalar, self.y * scalar, self.z * scalar) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Vector: | |
__slots__ = 'x', 'y', 'z' | |
def __init__(self, x, y, z): | |
self.x, self.y, self.z = x, y, z | |
def __iter__(self): | |
yield from (self.x, self.y, self.z) | |
def __eq__(self, other): | |
return tuple(self) == tuple(other) | |
def __add__(self, other): | |
if not isinstance(other, Vector): | |
return NotImplemented | |
x1, y1, z1 = self | |
x2, y2, z2 = other | |
return Vector(x1+x2, y1+y2, z1+z2) | |
def __sub__(self, other): | |
if not isinstance(other, Vector): | |
return NotImplemented | |
x1, y1, z1 = self | |
x2, y2, z2 = other | |
return Vector(x1-x2, y1-y2, z1-z2) | |
def __mul__(self, scalar): | |
if not isinstance(scalar, (int, float)): | |
return NotImplemented | |
x, y, z = self | |
return Vector(x*scalar, y*scalar, z*scalar) | |
def __rmul__(self, scalar): | |
if not isinstance(scalar, (int, float)): | |
return NotImplemented | |
return Vector(self.x * scalar, self.y * scalar, self.z * scalar) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment