Last active
November 7, 2017 17:15
-
-
Save ndevenish/df17d30793a70ce8831d25b2fbd6d746 to your computer and use it in GitHub Desktop.
Diffs the environments from sourcing a script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
""" | |
Works out the environmental changes from sourcing a script. | |
Usage: | |
  source_diff.py [--warn-empty] [--bash | --modules] <script> [<arg> [<arg> ...]]
source_diff.py -h | --help | |
Options: | |
--bash Generate bash scripting output. Default | |
--modules Generate GNU Modules Modulefile syntax output | |
--warn-empty Warn when replacing an empty variable (normally counts as added) | |
""" | |
# (WIP possible future feature) | |
# [-D <definition>]... | |
# -D <definition> Add variables to the output. e.g. -DSOME_VAR=/some/path | |
# will add SOME_VAR to the output, and instances of /some/path | |
# in other variables will be replaced with SOME_VAR. | |
from __future__ import print_function | |
import argparse | |
import os | |
import sys | |
import subprocess | |
import re | |
from abc import ABCMeta, abstractmethod | |
# Look for :, but not :// (so that URLs are not split at their scheme)
re_list_splitter = re.compile(r":(?!\/\/)")
# Look for characters that require quoting in tcl: whitespace, brackets,
# braces, quotes and "$", plus ";" (which would end the command) and "\"
# (which would trigger backslash substitution in a bare word)
re_tcl_string = re.compile(r"[\s\[\]{}\"'$;\\]")
def is_bash_listlike(entry):
    """Report whether ``entry`` contains a ``:`` list separator.

    A colon that is directly followed by ``//`` (as in a URL scheme) is not
    counted as a separator.
    """
    separator = re_list_splitter.search(entry)
    return separator is not None
def index_of_sublist(a, b):
    """Find the first position at which sequence ``b`` occurs inside ``a``.

    :returns: The lowest index ``i`` with ``a[i:i+len(b)] == b``, or ``None``
              when ``b`` does not occur (including when it is longer than
              ``a``). An empty ``b`` is found at index 0.
    """
    width = len(b)
    # Highest start index at which a slice of length ``width`` still fits;
    # negative when b is longer than a, which leaves nothing to scan
    last_start = len(a) - width
    matches = (
        pos for pos in range(last_start + 1) if a[pos:pos + width] == b
    )
    return next(matches, None)
def contains_sublist(a, b):
    """Return True when sequence ``b`` appears contiguously within list ``a``."""
    position = index_of_sublist(a, b)
    return position is not None
class OutputCategories(object):
    """Buckets of already-formatted output lines, one list per change kind."""

    def __init__(self):
        # Whole-variable changes
        self.added = []
        self.removed = []
        self.replaced = []
        # List-style changes: confirmed prefix/suffix edits, and new variables
        # that merely look like lists
        self.listchange = []
        self.assumed_listchange = []
        # Changes that could not be expressed as any of the above
        self.unhandled = []
# The base class is built by calling ABCMeta directly: a ``__metaclass__``
# attribute is silently ignored on Python 3, while this spelling enforces the
# abstract methods on both Python 2 and Python 3.
class OutputFormatter(ABCMeta("_OutputFormatterBase", (object,), {})):
    """Abstract interface for recording environment changes as script text.

    Each recording method stores a formatted line in ``self._output`` (an
    ``OutputCategories`` instance); ``dump`` renders everything collected.

    :param prior_definitions: Stored as ``self.definitions``; not otherwise
                              used by this base class.
    """

    def __init__(self, prior_definitions=None):
        self.definitions = prior_definitions
        self._output = OutputCategories()

    @abstractmethod
    def add(self, key, value):
        "When a variable that did not exist before is created"
        raise NotImplementedError()

    @abstractmethod
    def replace(self, key, value):
        "When a variable is replaced/written over completely"
        raise NotImplementedError()

    @abstractmethod
    def unhandled(self, key, value, comment=""):
        "When we don't know how to handle, just replace with a warning"
        raise NotImplementedError()

    @abstractmethod
    def remove(self, key):
        "When a variable has been unset"
        raise NotImplementedError()

    @abstractmethod
    def expand_list(self, key, prefix=None, postfix=None, assumed=False):
        """A list has been expanded by adding things in front or behind.

        :param key:     Name of the list-like variable.
        :param prefix:  Items added in front of the old value (default: none).
        :param postfix: Items added behind the old value (default: none).
        :param assumed: This is an assumption. Used to annotate output.
        """
        raise NotImplementedError()

    @abstractmethod
    def dump(self):
        "Return all recorded changes as a single string"
        raise NotImplementedError()
def _bash_quote(text):
    """Quote ``text`` so that bash reads it back as one literal word."""
    try:
        from shlex import quote  # Python 3.3+
    except ImportError:
        from pipes import quote  # Python 2
    return quote(text)


class BashFormatter(OutputFormatter):
    """Records environment changes as bash ``export``/``unset`` statements.

    Values are shell-quoted where necessary so that spaces and shell
    metacharacters survive when the generated script is sourced.
    """

    def add(self, key, value):
        "When a variable that did not exist before is created"
        self._output.added.append("export {}={}".format(key, _bash_quote(value)))

    def replace(self, key, value):
        "When a variable is replaced/written over completely"
        self._output.replaced.append("export {}={}".format(key, _bash_quote(value)))

    def unhandled(self, key, value, comment=""):
        "When we don't know how to handle, just replace with a warning"
        out = "export {}={}".format(key, _bash_quote(value))
        if comment:
            out += " # {}".format(comment)
        self._output.unhandled.append(out)

    def remove(self, key):
        "When a variable has been unset"
        self._output.removed.append("unset {}".format(key))

    def expand_list(self, key, prefix=None, postfix=None, assumed=False):
        """Record that items were added in front of and/or behind a list.

        :param prefix:  Items placed before the existing value (default: none).
        :param postfix: Items placed after the existing value (default: none).
        :param assumed: File the line under the "assumed" category instead of
                        the confirmed list changes.
        """
        dest_list = self._output.assumed_listchange if assumed else self._output.listchange
        # Quote each new item on its own; the "$KEY" reference stays bare so
        # that bash still expands it to the previous value
        parts = [_bash_quote(item) for item in prefix or []]
        parts.append("$" + key)
        parts.extend(_bash_quote(item) for item in postfix or [])
        dest_list.append("export {}={}".format(key, ":".join(parts)))

    def dump(self):
        "Return all recorded statements, grouped under explanatory headings"
        sections = [
            ("# Variables added", self._output.added),
            ("# Variables replaced - these had a value before that changed",
             self._output.replaced),
            ("# Variables deleted/unset", self._output.removed),
            ("# Lists prefixed/appended to", self._output.listchange),
            ("# Variables created - but looked like a list; assuming prefix operation",
             self._output.assumed_listchange),
            # Nothing is appended to these lines: a trailing "#" glued onto an
            # export would become part of the exported value
            ("# WARNING: The following were unhandled/unknown/too complex",
             self._output.unhandled),
        ]
        lines = []
        for heading, entries in sections:
            if entries:
                lines.append(heading)
                lines.append("\n".join(sorted(entries)))
                lines.append("")
        return "\n".join(lines)
def _tcl_escape(s):
    """Quote ``s`` so that Tcl reads it back as one word with the same value.

    Strings without any character matched by ``re_tcl_string`` are returned
    unchanged. Otherwise the string is wrapped in braces when that is exact,
    and in double quotes with backslash escapes when it contains braces or
    backslashes.
    """
    if not re_tcl_string.search(s):
        return s
    if not any(char in s for char in "{}\\"):
        # Tcl performs no substitution between braces, so this is exact
        return "{" + s + "}"
    # Between braces a backslash placed before "{" or "}" stays in the value,
    # so brace-quoting would alter such strings. Inside double quotes the
    # backslash escape is consumed, which preserves the value exactly.
    return '"' + re.sub(r'([\\"$\[\]{}])', r"\\\1", s) + '"'
class GNUModulesFormatter(OutputFormatter):
    """Records environment changes as Environment Modules modulefile commands."""

    def add(self, key, value):
        "When a variable that did not exist before is created"
        self._output.added.append("setenv {} {}".format(key, _tcl_escape(value)))

    def replace(self, key, value):
        "When a variable is replaced/written over completely"
        self._output.replaced.append("setenv {} {}".format(key, _tcl_escape(value)))

    def unhandled(self, key, value, comment=""):
        "When we don't know how to handle, just replace with a warning"
        out = "setenv {} {}".format(key, _tcl_escape(value))
        if comment:
            # Tcl only recognises "#" where a command starts, so the command
            # has to be ended with ";" before the comment
            out += " ;# {}".format(comment)
        self._output.unhandled.append(out)

    def remove(self, key):
        "When a variable has been unset"
        self._output.removed.append("unsetenv {}".format(key))

    def expand_list(self, key, prefix=None, postfix=None, assumed=False):
        """Record that items were added in front of and/or behind a list.

        :param prefix:  Items placed before the existing value (default: none).
        :param postfix: Items placed after the existing value (default: none).
        :param assumed: File the lines under the "assumed" category instead of
                        the confirmed list changes.
        """
        dest_list = self._output.assumed_listchange if assumed else self._output.listchange
        prefix = prefix or []
        postfix = postfix or []
        # Don't support single-item empty prefix/postfix
        if len(prefix) == 1 and prefix[0].strip() == "":
            prefix = []
        if len(postfix) == 1 and postfix[0].strip() == "":
            postfix = []
        if prefix:
            dest_list.append("prepend-path {} {}".format(key, _tcl_escape(":".join(prefix))))
        if postfix:
            dest_list.append("append-path {} {}".format(key, _tcl_escape(":".join(postfix))))

    def dump(self):
        "Return all recorded commands, grouped under explanatory headings"
        sections = [
            ("# Variables added", self._output.added),
            ("# Variables replaced - these had a value before that changed",
             self._output.replaced),
            ("# Variables deleted/unset", self._output.removed),
            ("# Lists prefixed/appended to", self._output.listchange),
            ("# Variables created - but looked like a list; assuming prefix operation",
             self._output.assumed_listchange),
            # Nothing is appended to these lines: a trailing "#" glued onto a
            # setenv command would become part of its last argument
            ("# WARNING: The following were unhandled/unknown/too complex",
             self._output.unhandled),
        ]
        lines = []
        for heading, entries in sections:
            if entries:
                lines.append(heading)
                lines.append("\n".join(sorted(entries)))
                lines.append("")
        return "\n".join(lines)
def process_argv():
    """Create the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        description="Works out the environmental changes from sourcing a script.",
    )
    parser.add_argument(
        "--warn-empty",
        action="store_true",
        help="Warn when replacing an empty variable (normally counts as added)",
    )

    # The output styles exclude each other and both store the chosen
    # formatter class in ``formatter``
    output_style = parser.add_mutually_exclusive_group()
    output_style.add_argument(
        "--bash",
        dest="formatter",
        action="store_const",
        const=BashFormatter,
        default=BashFormatter,
        help="Generate bash scripting output. Default",
    )
    output_style.add_argument(
        "--modules",
        dest="formatter",
        action="store_const",
        const=GNUModulesFormatter,
        help="Generate GNU Modules Modulefile syntax output",
    )

    parser.add_argument(
        "script",
        metavar="<script>",
        help="The name of the script to source",
    )
    # Everything after the script name is collected for the sourced script
    parser.add_argument(
        "args",
        metavar="<arg>",
        nargs=argparse.REMAINDER,
        help="Any arguments to pass to the sourced script",
    )

    parsed = parser.parse_args()
    return parsed
def _capture_sourced_env(script, args, signature="#ENVDIFF_ENVDUMP#"):
    """Source ``script`` in bash and return the resulting environment.

    The script's own output is sent to stderr. Afterwards a ``python``
    subprocess prints ``signature`` followed by the repr of its environment,
    which is parsed back into a dictionary.

    :param script:    Path of the script to source.
    :param args:      Arguments handed to the sourced script.
    :param signature: Marker separating script output from the dump.
    :returns: dict mapping variable names to values after sourcing.

    Exits the process with status 1 if the script fails or no dump is found.
    """
    import ast
    try:
        from shlex import quote  # Python 3.3+
    except ImportError:
        from pipes import quote  # Python 2

    def as_text(raw):
        # check_output gives bytes on Python 3 and str on Python 2
        return raw if isinstance(raw, str) else raw.decode("utf-8", "replace")

    # Quote every word so that paths/arguments with spaces or shell
    # metacharacters reach the sourced script unchanged
    sourced = " ".join(quote(word) for word in [script] + list(args))
    # dict(...) keeps the repr a plain literal; on Python 3 the repr of
    # os.environ itself is "environ({...})", which cannot be parsed back
    dump_code = "import os; print({!r} + repr(dict(os.environ)))".format(signature)
    shell_command = ". {} 1>&2 && python -c {}".format(sourced, quote(dump_code))
    try:
        output = as_text(subprocess.check_output(
            shell_command, shell=True, executable="/bin/bash",
            stderr=subprocess.STDOUT))
    except subprocess.CalledProcessError as ex:
        # Diagnostics go to stderr so they never end up in a generated script
        print("Error loading script: Returned non-zero status code.", file=sys.stderr)
        if ex.output:
            print("Output from failed process:", file=sys.stderr)
            print("\n".join(" " + x for x in as_text(ex.output).splitlines()),
                  file=sys.stderr)
        sys.exit(1)

    _, found, env_repr = output.partition(signature)
    if not found:
        # Explicit check rather than assert, which disappears under -O
        print("Error loading script: no environment dump found in output.",
              file=sys.stderr)
        sys.exit(1)
    # literal_eval only accepts Python literals, unlike eval
    return ast.literal_eval(env_repr.strip())


def main():
    """Print the environment changes caused by sourcing the given script.

    Compares the current environment with the one obtained after sourcing the
    script, classifies every difference (added, replaced, removed, list
    prefixed/appended, unhandled) and prints the selected formatter's output.
    """
    # Handle arguments and help
    options = process_argv()
    # The start environment is simple...
    start_env = dict(os.environ)
    # Generate the after-environment by sourcing the script
    sourced_env = _capture_sourced_env(options.script, options.args)

    # Keys to ignore - e.g. things that normally change in any sourced script
    for key in ("SHLVL", "_", "OLDPWD"):
        sourced_env.pop(key, None)
        start_env.pop(key, None)

    # Make useful sets out of the dictionary keys
    start_keys = set(start_env)
    sourced_keys = set(sourced_env)
    added_keys = sourced_keys - start_keys
    changed_keys = {x for x in (start_keys & sourced_keys)
                    if start_env[x] != sourced_env[x]}

    # Choose the formatting class for output
    formatter = options.formatter()

    # Look for added keys that are listlike - pretend these are changes
    listlike_added = {x for x in added_keys if is_bash_listlike(sourced_env[x])}
    changed_keys |= listlike_added
    added_keys -= listlike_added

    # Keys that changed, but are not listlike, are treated separately
    replaced_keys = set()
    for key in list(changed_keys):
        if is_bash_listlike(start_env.get(key, "")) or is_bash_listlike(sourced_env[key]):
            continue
        changed_keys.discard(key)
        # If we changed from nothing, then still count as added
        if start_env.get(key) == "":
            if options.warn_empty:
                print("Warning: variable {} was replaced, but was originally empty. "
                      "Emitting as add operation".format(key), file=sys.stderr)
            added_keys.add(key)
        else:
            replaced_keys.add(key)

    # Removed keys are the easy case: Must have been unset
    for key in start_keys - sourced_keys:
        formatter.remove(key)
    # Firstly, added keys
    for key in added_keys:
        formatter.add(key, sourced_env[key])
    # Handle keys explicitly overwritten separately
    for key in replaced_keys:
        formatter.replace(key, sourced_env[key])

    # Now, changed keys, but we know they are lists or look like one
    for key in changed_keys:
        # Treat an empty (or missing) start as an explicitly empty list
        start = re_list_splitter.split(start_env[key]) if start_env.get(key) else []
        end = re_list_splitter.split(sourced_env[key])
        # If we don't have a start, assume that we added as a prefix
        if not start:
            formatter.expand_list(key, prefix=end, assumed=True)
            continue
        # Look for the start embedded in the end
        ind = index_of_sublist(end, start)
        if ind is None:
            # The original list is not embedded in the new one (e.g. items
            # were removed); emit a plain overwrite in the unhandled group
            formatter.unhandled(key, sourced_env[key])
        else:
            formatter.expand_list(key, end[:ind], end[ind + len(start):])

    print(formatter.dump())
# Only run the tool when executed as a script, not when imported
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from envdiff import contains_sublist, index_of_sublist


def test_sublist():
    """Check sub-list detection and the index reported for each match."""
    assert contains_sublist([1, 2, 3, 4, 5], [5])
    assert index_of_sublist([1, 2, 3, 4, 5], [5]) == 4
    assert contains_sublist([1, 2, 3, 4, 5], [2, 3])
    assert index_of_sublist([1, 2, 3, 4, 5], [2, 3]) == 1
    assert not contains_sublist([1, 2, 3, 4, 5], [6])
    assert not contains_sublist([], [6])
    assert contains_sublist([1, 2, 3, 4, 5], [])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment