Skip to content

Instantly share code, notes, and snippets.

@plowsec
Created May 16, 2024 08:25
Show Gist options
  • Select an option

  • Save plowsec/5a2f85d920cea25d78ef57ae03bfa6f5 to your computer and use it in GitHub Desktop.

Select an option

Save plowsec/5a2f85d920cea25d78ef57ae03bfa6f5 to your computer and use it in GitHub Desktop.
Use IDA Pro and several worker threads to analyze all binaries in a folder, run a script on each one, and then decompile them.
import shutil
import subprocess
import tempfile
import traceback
import json
import os
import typing
import logging
import sys
import pathlib
import threading
import concurrent.futures
import time
from tqdm import tqdm
from typing import List
# Log line layout: timestamp | level | [file:line] function() | message
fmt = '%(asctime)s | %(levelname)3s | [%(filename)s:%(lineno)3d] %(funcName)s() | %(message)s'
datefmt = '%Y-%m-%d %H:%M:%S' # Date format without milliseconds
class CustomFormatter(logging.Formatter):
    """Logging formatter that wraps each record's message in an ANSI color
    chosen by its level name."""

    # ANSI escape sequences per level; 'RESET' restores the terminal default
    # and doubles as the fallback for unknown level names.
    COLOR_CODES = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[35m',      # Magenta (the old comment said "Green", but 35 is magenta)
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[41m',  # Red background
        'RESET': '\033[0m'       # Reset to default
    }

    def format(self, record):
        """Return *record* formatted with its message colored by level.

        The previous implementation assigned the colored text back to
        ``record.msg`` permanently, so a record passing through several
        handlers (console + file in this script) was wrapped in color codes
        once per handler. Restore ``record.msg`` after formatting so every
        handler sees the pristine message.
        """
        color_code = self.COLOR_CODES.get(record.levelname, self.COLOR_CODES['RESET'])
        original_msg = record.msg
        record.msg = f"{color_code}{original_msg}{self.COLOR_CODES['RESET']}"
        try:
            return super().format(record)
        finally:
            record.msg = original_msg
# Module-wide logger: colored console output mirrored into 'logfile.log'.
logger = logging.getLogger(__name__)

handler = logging.StreamHandler()
file_handler = logging.FileHandler('logfile.log')
for sink in (handler, file_handler):
    sink.setFormatter(CustomFormatter(fmt, datefmt))
    logger.addHandler(sink)

logger.setLevel(logging.INFO)

# Chooses idat64 vs idat; may be overridden by the command line in __main__.
is_64_bit = True
def decompile(binary: str) -> typing.Tuple[str, str]:
    """
    Decompiles the given binary with IDA Pro and returns the decompiled code.

    Three idat invocations are chained: (1) batch auto-analysis (-B) that
    produces the .i64 database, (2) re-opening the database and running
    VDR/ida_ioctl_propagate.py against it, (3) dumping every function through
    Hex-Rays into a .c file, which is then moved next to the binary.

    :param binary: absolute path to the binary to decompile
    :return: (decompiled code, path to the .c file), or ("", "") on failure
    """
    try:
        idat_path = "idat64" if is_64_bit else "idat"
        if os.name == "nt":
            idat_path = 'C:\\Program Files\\IDA Pro 8.4\\' + idat_path + '.exe'
        # NOTE(review): the doubled quotes look deliberate — the shell strips
        # one pair and idat still receives a quoted path (spaces survive).
        idb_path = f'""{binary}""'
        logger.info("IDA default analysis...")
        # Join the components portably (was the literal "VDR\ida_ioctl_propagate.py",
        # which contains the invalid escape '\i' and a hard-coded Windows separator).
        script_path = os.path.join(os.getcwd(), "VDR", "ida_ioctl_propagate.py")
        if not os.path.exists(script_path):
            logger.critical(f"You must have VDR/ida_ioctl_propagate at {script_path}")
            exit(-1)
        dst = os.path.splitext(binary)[0] + '.log'
        script_path = f'""{script_path}""'
        dst = f'""{dst}""'
        idat_path = f'"{idat_path}"'
        # Pass 1: batch-mode auto-analysis, logging to dst.
        cmd = f'{idat_path} -B -P -L"{dst}" {idb_path}'
        logger.info(f"Running command (standard analysis): {cmd}")
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        print(result.stdout)
        print(result.stderr)
        # Pass 2: re-open the generated database and run the typing script.
        idb_path = f'""{binary}.i64""'
        output_file = os.path.splitext(binary)[0] + '.c'
        # -Ohexrays chokes on absolute Windows paths, so the .c file is
        # produced in the CWD and relocated next to the binary afterwards.
        output_file = os.path.basename(output_file)
        cmd = f'{idat_path} -A -P -L"{dst}" -S"{script_path} {dst}" {idb_path}'
        logger.info(f"Running command (Applying types): {cmd}")
        result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
        print(result.stdout)
        print(result.stderr)
        logger.info(f"Decompiling {binary}...")
        time.sleep(1)
        # Pass 3: dump every function through Hex-Rays into output_file.
        command = f'{idat_path} -A -P -Ohexrays:{output_file}:ALL {idb_path}'
        logger.info(f"Running command (decompiling): {command}")
        result = subprocess.run(command, shell=True, capture_output=True, text=True)
        if not os.path.exists(output_file):
            print("Output")
            print(result.stdout)
            print(result.stderr)
            print("End output")
            return "", ""
        # Relocate the .c file next to the binary it came from. (Uses the
        # module logger; previously this one line went to the root logger.)
        new_output_file = os.path.join(os.path.dirname(binary), output_file)
        logger.info(f"Moving file to {new_output_file}")
        shutil.move(output_file, new_output_file)
        output_file = new_output_file
        # Read and return the decompiled output.
        with open(output_file, "r") as f:
            return f.read(), output_file
    except Exception:
        # Was a bare 'except:', which also swallowed the SystemExit raised by
        # exit(-1) above; catching Exception lets that abort propagate.
        logger.error(f"Exception while decompiling {binary}: {traceback.format_exc()}")
        return "", ""
def read_cache(cache_file: str) -> dict:
    """Load previously cached decompilation results.

    :param cache_file: path to the JSON cache file
    :return: the parsed cache dictionary, or an empty dict when the file
        does not exist
    """
    if not os.path.exists(cache_file):
        return {}
    with open(cache_file, "r") as fh:
        return json.load(fh)
def write_cache(cache_file: str, cache_data: dict) -> None:
    """Persist *cache_data* as JSON at *cache_file*, overwriting it.

    :param cache_file: destination path for the JSON cache file
    :param cache_data: mapping of binary path to decompiled-output path
    """
    with open(cache_file, "w") as out:
        json.dump(cache_data, out)
def batch_decompile(binaries: list, cache_file="decompile_cache.json") -> typing.Dict[str, str]:
    """Decompile every binary in *binaries* in parallel and cache the results.

    Previously the existing cache was never consulted (``read_cache`` existed
    but was unused), so every run redid all the work. Now entries already in
    *cache_file* are reused and only the missing binaries are decompiled.

    :param binaries: list of paths to the binaries
    :param cache_file: path to the JSON cache file (binary -> .c output path)
    :return: mapping of binary path to its decompiled-output path
    """
    cache_data = read_cache(cache_file)
    lock = threading.Lock()

    def decompile_and_cache(binary):
        # Worker: decompile one binary; record its output path only on success.
        decompiled_code, decompiled_path = decompile(binary)
        if len(decompiled_code) > 0:
            with lock:
                cache_data[binary] = decompiled_path

    todo = [b for b in binaries if b not in cache_data]
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # tqdm wraps the lazy map so the bar advances as results complete;
        # list() drains the iterator (map yields None per worker).
        list(tqdm(executor.map(decompile_and_cache, todo), total=len(todo), desc="Decompiling"))
    write_cache(cache_file, cache_data)
    return cache_data
def enumerate_binaries(path: str, extension="") -> List[str]:
    """Recursively collect files under *path* whose paths end with *extension*.

    :param path: the file system path to walk through
    :param extension: required suffix; the default empty string matches all
    :return: list of matching file paths
    """
    logger.info(f"Enumerating binaries in {path}...")
    candidates = (
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk(path)
        for filename in filenames
    )
    matches = [candidate for candidate in candidates if candidate.endswith(extension)]
    logger.info(f"Found: {matches}")
    return matches
def filter_binaries(binaries: List[str], whitelist: List) -> List[str]:
    """Keep only the binaries whose base name appears in *whitelist*.

    The comparison is now case-insensitive on both sides: previously only the
    binary's basename was lowercased, so any whitelist entry containing an
    uppercase letter could never match. Entries are also stripped, making
    trailing whitespace/newlines from a whitelist file harmless. Existing
    all-lowercase whitelists behave exactly as before.

    :param binaries: candidate file paths
    :param whitelist: allowed base names (any case)
    :return: the matching subset of *binaries*, original order preserved
    """
    # Set lookup is O(1) per binary instead of a linear scan of the whitelist.
    allowed = {entry.strip().lower() for entry in whitelist}
    return [binary for binary in binaries if os.path.basename(binary).lower() in allowed]
if __name__ == '__main__':
    # Usage: script.py <folder> [x64|x86] [whitelist_file]
    if len(sys.argv) < 2:
        # Previously an IndexError with no arguments; fail with a clear message.
        print("Usage: python script.py <folder> [x64|x86] [whitelist_file]")
        sys.exit(1)
    path = sys.argv[1]
    if len(sys.argv) > 2:
        is_64_bit = sys.argv[2] == "x64"
    binaries = enumerate_binaries(path, extension=".sys")
    if len(sys.argv) > 3:
        with open(sys.argv[3], "r") as f:
            # splitlines() + strip avoids the empty trailing entry that
            # split('\n') produced and drops blank lines from the file.
            whitelist = [line.strip() for line in f.read().splitlines() if line.strip()]
        logger.info(whitelist)
        filtered_binaries = filter_binaries(binaries, whitelist)
        logger.info(f"Filtered {len(binaries)-len(filtered_binaries)} entries, {len(filtered_binaries)} to analyze")
        binaries = filtered_binaries
    print(f"Len argv {len(sys.argv)}")
    batch_decompile(binaries)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment