Last active: October 21, 2016 16:10
-
-
Save lebedov/30e397307c8aa6057b848193ba85a82e to your computer and use it in GitHub Desktop.
Filter out certain outputs in a nipype MapNode.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
"""
Filter out certain outputs in a nipype MapNode.

Defines ``FilterMapNode``, a ``nipype.MapNode`` subclass that removes
selected values from chosen input and output fields. Running this file as a
script executes a small demo built on ``nipype.interfaces.utility.Function``.
"""
import logging

import nipype
import nipype.interfaces.utility as utility

# Logger used for FilterMapNode's debug messages about removed values.
iflogger = logging.getLogger('interface')
class FilterMapNode(nipype.MapNode):
    """
    MapNode that removes selected values from input and output fields.

    For every input field name ``name`` in ``in_filt.keys()``, the values in
    ``in_filt[name]`` are removed from that input field before the mapped
    interface runs. For every output field name ``name`` in
    ``out_filt.keys()``, the values in ``out_filt[name]`` are removed from
    that output field after the run.

    Parameters
    ----------
    interface, iterfield, name, serial, nested, **kwargs
        Forwarded to ``nipype.MapNode.__init__``.
    in_filt : dict or None
        Maps input field names to lists of values to drop. ``None`` (the
        default) means no input filtering.
    out_filt : dict or None
        Maps output field names to lists of values to drop. ``None`` (the
        default) means no output filtering.

    Raises
    ------
    ValueError
        At run time, if a filter is not a dict, names a field that does not
        exist, or maps a field to something other than a list.

    Notes
    -----
    Removing values from one iterfield shortens it; presumably every zipped
    iterfield must be filtered consistently so their lengths still match
    (NOTE(review): confirm against nipype's iterfield length check).
    """

    def __init__(self, interface, iterfield, name, serial=False, nested=False,
                 in_filt=None, out_filt=None, **kwargs):
        # Forward serial/nested rather than hard-coding False, so the
        # caller's choices actually reach nipype.MapNode.
        super(FilterMapNode, self).__init__(interface, iterfield, name,
                                            serial=serial, nested=nested,
                                            **kwargs)
        # None defaults avoid one mutable dict being shared by all instances.
        self.in_filt = {} if in_filt is None else in_filt
        self.out_filt = {} if out_filt is None else out_filt

    def _run_interface(self, execute=True, updatehash=False):
        """
        Filter inputs, run the MapNode, then filter the resulting outputs.

        ``execute`` and ``updatehash`` are forwarded to the base class so this
        override accepts the same arguments as ``nipype.MapNode._run_interface``
        (NOTE(review): signature assumed from nipype -- confirm for the
        installed version). The base implementation's return value is
        returned unchanged.
        """
        in_filt = getattr(self, 'in_filt', None)
        if in_filt is not None:
            # Inputs must be pruned *before* the map executes; pruning them
            # afterwards would not change which items were processed.
            self._filter_spec_fields(self._inputs, in_filt, 'in_filt', 'input')

        ret = super(FilterMapNode, self)._run_interface(execute=execute,
                                                         updatehash=updatehash)

        out_filt = getattr(self, 'out_filt', None)
        if out_filt is not None:
            # NOTE(review): some nipype versions return the result object
            # here while others only store it on self._result; handle both.
            result = ret if ret is not None else getattr(self, '_result', None)
            outputs = getattr(result, 'outputs', None)
            if outputs is not None:
                self._filter_spec_fields(outputs, out_filt, 'out_filt',
                                         'output')
        return ret

    @staticmethod
    def _filter_spec_fields(spec, filt, filt_name, kind):
        """
        Drop the values listed in ``filt[field]`` from ``spec.<field>``.

        ``filt_name`` ('in_filt' / 'out_filt') and ``kind`` ('input' /
        'output') are used only in error and log messages. Each filtered
        field is replaced by a new list with the unwanted values removed.
        """
        if not isinstance(filt, dict):
            raise ValueError('%s must be dict' % filt_name)
        for name, filt_list in filt.items():
            try:
                values = getattr(spec, name)
            except AttributeError:
                # Raise (not merely construct) the error; otherwise `values`
                # would be unbound or left over from the previous field.
                raise ValueError('%s must contain existing %s field: %r'
                                 % (filt_name, kind, name))
            if not isinstance(filt_list, list):
                raise ValueError('filt_list must be a list')
            iflogger.debug('Removing %s from %s field %s',
                           filt_list, kind, name)
            setattr(spec, name, [v for v in values if v not in filt_list])
if __name__ == '__main__':
    import sys

    def func(in_file):
        """
        Function that returns None for certain inputs.

        Strings starting with 'xxx' map to None; every other input is
        returned unchanged.
        """
        if isinstance(in_file, str) and in_file.startswith('xxx'):
            return None
        else:
            return in_file

    # Which demo to run, taken from the first command-line argument:
    # 'out' (the default) filters None out of the outputs; any other value
    # filters 'xxxyyy' out of the inputs.
    which = sys.argv[1] if len(sys.argv) > 1 else 'out'
    in_filt = dict()
    out_filt = dict()
    if which == 'out':
        out_filt = {'out_file': [None]}
    else:
        in_filt = {'in_file': ['xxxyyy']}

    n = FilterMapNode(interface=utility.Function(input_names=['in_file'],
                                                 output_names=['out_file'],
                                                 function=func),
                      name='func',
                      iterfield=['in_file'],
                      in_filt=in_filt, out_filt=out_filt)
    n.inputs.in_file = ['aaabbb', 'cccddd', 'xxxyyy', 'pppqqq']
    n.run()
    # Show the outputs that remain after filtering.
    print(n.result.outputs.out_file)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment