Last active
December 24, 2015 05:09
-
-
Save dstahlke/6748551 to your computer and use it in GitHub Desktop.
IPython extension for transparent use of joblib.Parallel.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE: this is just a proof of concept and breaks when, for example, you have
# square brackets inside a string literal. I'm still trying to find the easiest
# way to handle this properly. Suggestions are welcome.
# | |
# An IPython extension for transparent use of joblib. | |
# To use, type something like this in IPython: | |
# [ x*2 parfor x in range(10) ] | |
# To use custom options, type this at the IPython prompt: | |
# _parfor = joblib.Parallel(n_jobs=-1, verbose=5) | |
# | |
# The skeleton of this is copied from the physics.py extension by Georg Brandl. | |
import re | |
class JobsTransformer(object):
    """IPython prefilter transformer that rewrites ``parfor`` bracket syntax.

    A bracketed expression ``[EXPR parfor VAR in ITERABLE]`` is rewritten
    into a generated helper function ``_parfor_task_N`` (which binds VAR and
    returns EXPR) followed by a call to the ``_parfor`` object in the user
    namespace, fed with ``joblib.delayed`` tasks.  Brackets whose contents
    do not match the ``parfor`` pattern are emitted unchanged.
    """

    priority = 99
    enabled = True

    # Compiled once instead of on every transform() call.  The greedy first
    # group makes EXPR extend up to the last " parfor " inside the brackets;
    # the lazy second group stops VAR at the first " in ".
    _PARFOR_RE = re.compile(r'(.*) parfor (.+?) in (.*)')

    def transform(self, line, continue_prompt):
        """Return *line* with every ``[... parfor ... in ...]`` rewritten.

        Parameters
        ----------
        line : str
            One line of user input.
        continue_prompt : bool
            Supplied by IPython; not used by this transformer.

        Returns
        -------
        str
            Zero or more generated ``def`` blocks followed by the rewritten
            line, every line terminated by a newline.
        """
        fn_counter = 0
        output = []
        # Text seen before each currently-open '['; innermost is last.
        stack = []
        accum = ''
        for c in line:
            if c == '[':
                stack.append(accum)
                accum = ''
            elif c == ']':
                if not stack:
                    # Unmatched ']' (e.g. the end of a list that started on a
                    # previous line, or a ']' inside a string literal): keep
                    # it verbatim rather than crashing on an empty stack.
                    accum += c
                    continue
                match = self._PARFOR_RE.match(accum)
                if match is None:
                    # The text inside the brackets doesn't contain 'parfor'
                    # so it is not of concern to us.  Put the square
                    # brackets back on it and leave it as is.
                    rewritten = '[' + accum + ']'
                else:
                    # The 'parfor' token was detected.  Convert it to a
                    # joblib.Parallel call.
                    fn_counter += 1
                    fn = '_parfor_task_%d' % fn_counter
                    expr, var, arr = (g.strip() for g in match.group(1, 2, 3))
                    # The helper takes one argument and assigns it to VAR, so
                    # tuple targets like "a, b" are unpacked instead of being
                    # turned into a multi-parameter signature.
                    output.append('def %s(_parfor_item):' % fn)
                    output.append('    %s = _parfor_item' % var)
                    output.append('    return %s\n' % expr)
                    # Loop with the user's own VAR so that an "if" filter in
                    # ITERABLE can refer to it; the parenthesised VAR packs a
                    # tuple target back into the single argument.
                    rewritten = ('_parfor(joblib.delayed(%s)((%s)) for %s in %s)'
                                 % (fn, var, var, arr))
                accum = stack.pop() + rewritten
            else:
                accum += c
        # Re-attach text preceding any '[' still open at end of line (e.g. a
        # list continued on the next line) instead of silently dropping it.
        while stack:
            accum = stack.pop() + '[' + accum
        output.append(accum)
        return ''.join(x + '\n' for x in output)


# Single shared instance, registered and unregistered by the load/unload hooks.
jobs_transformer = JobsTransformer()
def load_ipython_extension(ip):
    """Activate the ``parfor`` syntax in the IPython shell *ip*.

    Registers the shared ``jobs_transformer`` with the prefilter manager.
    Then, inside ``ip.user_ns``, it imports ``joblib`` and creates the
    default ``_parfor = joblib.Parallel(n_jobs=-1)`` object that the
    rewritten code calls.

    Parameters
    ----------
    ip : IPython shell instance
        The shell loading this extension.
    """
    ip.prefilter_manager.register_transformer(jobs_transformer)
    # exec(code, namespace) is accepted by both Python 2.7 and Python 3,
    # unlike the Python-2-only ``exec code in namespace`` statement.
    exec(ip.compile('import joblib', '<input>', 'single'), ip.user_ns)
    exec(ip.compile('_parfor = joblib.Parallel(n_jobs=-1)', '<input>', 'single'),
         ip.user_ns)
    print('parfor extension activated.')
def unload_ipython_extension(ip):
    """Deactivate the ``parfor`` syntax by removing the shared transformer.

    Only the prefilter registration is undone; names that were placed in
    ``ip.user_ns`` on load (``joblib``, ``_parfor``) are left untouched.
    """
    manager = ip.prefilter_manager
    manager.unregister_transformer(jobs_transformer)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment