Last active
March 15, 2026 23:02
-
-
Save segfo/f85b7e56b86ca056d208f81f0ab86c7c to your computer and use it in GitHub Desktop.
チェックポイントに含まれるpickleのRCE(疑い)コードの検知スクリプト
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import pickletools | |
| import zipfile | |
| from pathlib import Path | |
| import sys | |
| import dis | |
| import inspect | |
| import importlib | |
# ---------------------------------------
# Detection tables
# ---------------------------------------

# (module, attribute) pairs treated as dangerous calls.
DANGEROUS_CALLS = (
    {(shell_module, "system") for shell_module in ("os", "nt", "posix")}
    | {("subprocess", entry_point) for entry_point in ("Popen", "run", "call")}
    | {("builtins", evaluator) for evaluator in ("eval", "exec")}
)

# (module, attribute) pairs recognised as tensor/array rebuild helpers.
TENSOR_REBUILD = set(
    (
        ("torch._utils", "_rebuild_tensor_v2"),
        ("numpy.core.multiarray", "_reconstruct"),
        ("copyreg", "_reconstructor"),
    )
)

# Substrings that make a string constant look like a shell command or a
# download tool invocation.
DANGEROUS_STRINGS = [
    "cmd.exe", "powershell",
    "curl ", "wget ", "bash ",
    "nc ", "netcat",
    "Invoke-WebRequest", "certutil", "bitsadmin",
]
# ---------------------------------------
# Bytecode analyzer
# ---------------------------------------
def analyze_method_bytecode(func):
    """Disassemble *func* and report attribute loads listed in DANGEROUS_CALLS.

    Each ``LOAD_ATTR``/``LOAD_METHOD`` instruction is paired with the name of
    the most recently seen ``LOAD_GLOBAL``; the pair is reported when it is a
    member of ``DANGEROUS_CALLS``. The remembered global is not cleared between
    instructions, so a later, unrelated attribute load is still paired with it.

    Args:
        func: object accepted by ``dis.get_instructions`` (function, method,
            code object, ...).

    Returns:
        list of ``(global_name, attribute)`` tuples that were reported; empty
        when nothing matched or when *func* could not be disassembled.
    """
    print(f"\n[Bytecode analysis] {func}")
    findings = []
    try:
        instructions = list(dis.get_instructions(func))
    except (TypeError, ValueError, SyntaxError) as exc:
        # Only disassembly failures are expected here; a bare ``except`` would
        # also hide KeyboardInterrupt / SystemExit.
        print("  disassembly failed:", exc)
        return findings

    last_global = None
    for ins in instructions:
        if ins.opname == "LOAD_GLOBAL":
            last_global = ins.argval
        elif ins.opname in ("LOAD_ATTR", "LOAD_METHOD") and last_global:
            key = (last_global, ins.argval)
            if key in DANGEROUS_CALLS:
                print("⚠ Dangerous bytecode call detected")
                print(" ", last_global + "." + ins.argval)
                findings.append(key)
    return findings
def analyze_class_methods(module_name, class_name):
    """Locate ``module_name.class_name`` and bytecode-scan its pickle hooks.

    ``class_name`` may be a dotted qualified name (``Outer.Inner``); it is
    resolved one attribute at a time. For every hook in ``__setstate__``,
    ``__reduce__``, ``__reduce_ex__`` and ``__getstate__`` that is a Python
    function or method, ``analyze_method_bytecode`` is called.

    Args:
        module_name: module to look in; ``"__main__"`` is taken from
            ``sys.modules`` instead of being imported.
        class_name: attribute name or dotted attribute path inside the module.

    Returns:
        None. Failures are printed, not raised.
    """
    parts = class_name.split(".")
    try:
        if module_name == "__main__":
            cls = sys.modules.get("__main__")
            for part in parts:
                cls = getattr(cls, part, None)
                if cls is None:
                    break
        else:
            # NOTE(security): module_name originates from the pickle being
            # scanned. Importing it executes that module's top-level code, so
            # a crafted pickle can make the scanner import any importable
            # module. Run the scanner in an environment where that is
            # acceptable.
            cls = importlib.import_module(module_name)
            for part in parts:
                cls = getattr(cls, part)
    except Exception as e:
        # Importing arbitrary modules can fail in arbitrary ways; report the
        # failure and keep scanning.
        print("class load failed:", e)
        return

    if cls is None:
        print("⚠ class not found:", module_name, class_name)
        return

    for hook_name in ("__setstate__", "__reduce__", "__reduce_ex__", "__getstate__"):
        if not hasattr(cls, hook_name):
            continue
        hook = getattr(cls, hook_name)
        # Built-in slot wrappers carry no Python bytecode; skip them.
        if inspect.isfunction(hook) or inspect.ismethod(hook):
            analyze_method_bytecode(hook)
# ---------------------------------------
# Pickle analyzer
# ---------------------------------------
class PickleAnalyzer:
    """Score a pickle opcode stream for risky constructs without loading it.

    The opcodes yielded by ``pickletools.genops`` are replayed on a simulated
    stack. Stack entries are tagged tuples such as ``("str", value)``,
    ``("callable", module, name)``, ``("tuple", [entries])`` or
    ``("unknown", opcode_name)`` for values the analyzer does not model.

    Attributes:
        stack: simulated value stack of the pickle currently being analyzed.
        markers: stack depths recorded by MARK opcodes.
        memo: memo index -> stack entry.
        globals_found: ``"module name"`` strings of every global referenced;
            accumulated across all analyzed pickles.
        score: accumulated risk score across all analyzed pickles.
    """

    def __init__(self):
        self.stack = []
        self.markers = []
        self.memo = {}
        self.globals_found = set()
        self.score = 0
        # Dispatch table: opcode name -> handler(op, arg, pos).
        # Opcodes missing here go through _handle_default.
        self._handlers = {
            "FRAME": self._handle_FRAME,
            "STOP": self._handle_STOP,
            "GLOBAL": self._handle_GLOBAL,
            "STACK_GLOBAL": self._handle_STACK_GLOBAL,
            "INST": self._handle_INST,
            "OBJ": self._handle_OBJ,
            "BINUNICODE": self._handle_STR,
            "SHORT_BINUNICODE": self._handle_STR,
            "BINUNICODE8": self._handle_STR,
            "UNICODE": self._handle_STR,
            "STRING": self._handle_STR,
            "BINSTRING": self._handle_STR,
            "SHORT_BINSTRING": self._handle_STR,
            "INT": self._handle_INT,
            "BININT": self._handle_INT,
            "BININT1": self._handle_INT,
            "BININT2": self._handle_INT,
            "LONG": self._handle_INT,
            "LONG1": self._handle_INT,
            "LONG4": self._handle_INT,
            "NONE": self._handle_NONE,
            "MARK": self._handle_MARK,
            "POP": self._handle_POP,
            "DUP": self._handle_DUP,
            "TUPLE": self._handle_TUPLE,
            "TUPLE1": self._handle_TUPLE1,
            "TUPLE2": self._handle_TUPLE2,
            "TUPLE3": self._handle_TUPLE3,
            "EMPTY_TUPLE": self._handle_EMPTY_TUPLE,
            "LIST": self._handle_LIST,
            "EMPTY_LIST": self._handle_EMPTY_LIST,
            "DICT": self._handle_DICT,
            "EMPTY_DICT": self._handle_EMPTY_DICT,
            "PUT": self._handle_PUT,
            "BINPUT": self._handle_PUT,
            "LONG_BINPUT": self._handle_PUT,
            "MEMOIZE": self._handle_MEMOIZE,
            "GET": self._handle_GET,
            "BINGET": self._handle_GET,
            "LONG_BINGET": self._handle_GET,
            "NEWOBJ": self._handle_NEWOBJ,
            "NEWOBJ_EX": self._handle_NEWOBJ_EX,
            "BUILD": self._handle_BUILD,
            "REDUCE": self._handle_REDUCE,
        }

    # -----------------------------
    # Stack primitives
    # -----------------------------
    def push(self, obj):
        """Push one entry onto the simulated stack."""
        self.stack.append(obj)

    def pop(self):
        """Pop and return the top entry, or None when the stack is empty."""
        if not self.stack:
            return None
        return self.stack.pop()

    def mark(self):
        """Record the current stack depth as a MARK position."""
        self.markers.append(len(self.stack))

    def pop_mark(self):
        """Remove and return every entry above the most recent MARK.

        Returns an empty list (and leaves the stack alone) when no MARK is
        pending.
        """
        if not self.markers:
            return []
        depth = self.markers.pop()
        items = self.stack[depth:]
        del self.stack[depth:]
        return items

    def check_string(self, val):
        """Add 3 to the score for every DANGEROUS_STRINGS entry found in *val*.

        The comparison is case-insensitive; non-string values are ignored.
        """
        if not isinstance(val, str):
            return
        lowered = val.lower()
        for bad in DANGEROUS_STRINGS:
            if bad.lower() in lowered:
                print("⚠ suspicious string:", val)
                self.score += 3

    def _flag_call(self, module, name):
        """Report and score a call target found in the detection tables."""
        target = (module, name)
        if target in DANGEROUS_CALLS:
            print("\n⚠ DANGEROUS CALL DETECTED")
            print(" function:", module + "." + name)
            self.score += 10
        if target in TENSOR_REBUILD:
            print("Tensor rebuild:", module + "." + name)

    # -----------------------------
    # Handlers, one per pickle opcode
    # -----------------------------
    def _handle_FRAME(self, op, arg, pos):
        print("FRAME detected size:", arg)

    def _handle_STOP(self, op, arg, pos):
        # Leave the final object on the stack so callers can inspect it.
        return

    def _handle_GLOBAL(self, op, arg, pos):
        # pickletools reports the argument as "module name".
        module, _, func = arg.partition(" ")
        self.globals_found.add(arg)
        print("GLOBAL:", module, func)
        # Push before the class analysis so a failure there cannot leave the
        # simulated stack one entry short.
        self.push(("callable", module, func))
        analyze_class_methods(module, func)

    def _handle_STACK_GLOBAL(self, op, arg, pos):
        # NOTE(review): unlike GLOBAL, no class bytecode analysis is run for
        # STACK_GLOBAL targets; this matches the previous behavior - confirm
        # whether that is intentional.
        name_entry = self.pop()
        module_entry = self.pop()
        both_strings = all(
            entry and entry[0] == "str" and isinstance(entry[1], str)
            for entry in (name_entry, module_entry)
        )
        if not both_strings:
            # Two entries were consumed; push a placeholder so later opcodes
            # still find their operands at the right depth.
            self.push(("unknown", "STACK_GLOBAL"))
            return
        module, func = module_entry[1], name_entry[1]
        self.globals_found.add(module + " " + func)
        self.push(("callable", module, func))

    def _handle_INST(self, op, arg, pos):
        # INST names its target inline ("module name") and calls it with the
        # entries above the last MARK.
        self.pop_mark()
        module, _, name = arg.partition(" ")
        self.globals_found.add(arg)
        print("INST:", module, name)
        self._flag_call(module, name)
        self.push(("result", None))

    def _handle_OBJ(self, op, arg, pos):
        # OBJ calls the first entry above the last MARK with the remaining ones.
        items = self.pop_mark()
        target = items[0] if items else None
        if target and target[0] == "callable":
            self._flag_call(target[1], target[2])
        self.push(("result", None))

    def _handle_STR(self, op, arg, pos):
        self.push(("str", arg))
        self.check_string(arg)

    def _handle_INT(self, op, arg, pos):
        self.push(("int", arg))

    def _handle_NONE(self, op, arg, pos):
        self.push(("none", None))

    def _handle_MARK(self, op, arg, pos):
        self.mark()

    def _handle_POP(self, op, arg, pos):
        # When nothing sits above the latest MARK, POP discards the MARK itself.
        if self.markers and self.markers[-1] >= len(self.stack):
            self.markers.pop()
        else:
            self.pop()

    def _handle_DUP(self, op, arg, pos):
        if self.stack:
            self.push(self.stack[-1])

    def _handle_TUPLE(self, op, arg, pos):
        self.push(("tuple", self.pop_mark()))

    def _handle_TUPLE1(self, op, arg, pos):
        a = self.pop()
        self.push(("tuple", [a]))

    def _handle_TUPLE2(self, op, arg, pos):
        b = self.pop()
        a = self.pop()
        self.push(("tuple", [a, b]))

    def _handle_TUPLE3(self, op, arg, pos):
        c = self.pop()
        b = self.pop()
        a = self.pop()
        self.push(("tuple", [a, b, c]))

    def _handle_EMPTY_TUPLE(self, op, arg, pos):
        self.push(("tuple", []))

    def _handle_LIST(self, op, arg, pos):
        self.push(("list", self.pop_mark()))

    def _handle_EMPTY_LIST(self, op, arg, pos):
        self.push(("list", []))

    def _handle_DICT(self, op, arg, pos):
        self.push(("dict", self._build_dict_from_items(self.pop_mark())))

    def _handle_EMPTY_DICT(self, op, arg, pos):
        self.push(("dict", {}))

    def _build_dict_from_items(self, items):
        """Pair up ``[key, value, key, value, ...]`` entries into a dict.

        Only keys tagged ``"str"`` are kept; a trailing key without a value
        maps to None.
        """
        d = {}
        for i in range(0, len(items), 2):
            key = items[i]
            val = items[i + 1] if i + 1 < len(items) else None
            if key and key[0] == "str":
                d[key[1]] = val
        return d

    def _handle_PUT(self, op, arg, pos):
        # The memo stores the entry on top of the stack (not the stack itself).
        if self.stack:
            self.memo[arg] = self.stack[-1]

    def _handle_MEMOIZE(self, op, arg, pos):
        # MEMOIZE carries no argument; the index is the current memo size.
        if self.stack:
            self.memo[len(self.memo)] = self.stack[-1]

    def _handle_GET(self, op, arg, pos):
        # Always push something so the stack depth stays correct even when
        # the memo entry was never recorded.
        self.push(self.memo.get(arg, ("unknown", "GET")))

    def _handle_NEWOBJ(self, op, arg, pos):
        args = self.pop()
        cls = self.pop()
        print("⚠ NEWOBJ object:", cls)
        self.score += 1
        self.push(("object", (cls, args)))

    def _handle_NEWOBJ_EX(self, op, arg, pos):
        kwargs = self.pop()
        args = self.pop()
        cls = self.pop()
        print("⚠ NEWOBJ_EX object:", cls)
        self.score += 1
        self.push(("object", (cls, args, kwargs)))

    def _handle_BUILD(self, op, arg, pos):
        state = self.pop()
        obj = self.pop()
        print("⚠ BUILD state injection")
        # safe_print_object returns the preview text; it does not print.
        print(" object:", safe_print_object(obj))
        print(" state:", state)
        self.score += 2
        self.push(("object", obj))

    def _handle_REDUCE(self, op, arg, pos):
        self.pop()  # call arguments; not inspected
        func = self.pop()
        if func and func[0] == "callable":
            self._flag_call(func[1], func[2])
        self.push(("result", None))

    def _handle_default(self, op, arg, pos):
        """Keep the stack aligned for opcodes without a dedicated handler.

        Uses the ``stack_before`` / ``stack_after`` descriptions attached to
        the pickletools opcode: consumed entries are popped and every produced
        entry is pushed as ``("unknown", opcode_name)``.
        """
        for slot in reversed(op.stack_before):
            if slot is pickletools.stackslice:
                # A slice always sits on top of a mark; pop_mark removes both.
                self.pop_mark()
            elif slot is not pickletools.markobject:
                self.pop()
        for slot in op.stack_after:
            if slot is pickletools.markobject:
                self.mark()
            else:
                self.push(("unknown", op.name))

    def analyze(self, data):
        """Replay one pickle stream and update ``score`` / ``globals_found``.

        Args:
            data: bytes or a binary file object, passed straight to
                ``pickletools.genops``.

        Raises:
            Whatever ``pickletools.genops`` raises for a malformed stream.
            Errors raised inside a single opcode handler are reported and the
            scan continues with the next opcode.
        """
        # Stack and memo belong to one pickle; do not let them leak into the
        # next stream analyzed with the same instance.
        self.stack = []
        self.markers = []
        self.memo = {}
        for op, arg, pos in pickletools.genops(data):
            handler = self._handlers.get(op.name, self._handle_default)
            try:
                handler(op, arg, pos)
            except Exception as exc:
                # Best effort: keep scanning, but do not hide the failure.
                print("⚠ handler error:", op.name, "at", pos, "-", exc)
def safe_print_object(obj, max_items=10):
    """Return a text preview of *obj* showing at most ``max_items`` elements.

    Lists, tuples and dicts longer than ``max_items`` are cut after the first
    ``max_items`` elements and end with ``... (N items total)`` inside the
    container's own brackets. Anything else is returned as ``repr(obj)``.
    Nothing is printed; the caller decides what to do with the text.

    Args:
        obj: value to preview.
        max_items: number of leading elements to show.

    Returns:
        str: the preview text.
    """
    if isinstance(obj, (list, tuple)):
        preview = obj[:max_items]
        if len(obj) <= max_items:
            return f"{preview}"
        # Close with the bracket that matches the container type.
        opener, closer = ("(", ")") if isinstance(obj, tuple) else ("[", "]")
        shown = [repr(item) for item in preview]
    elif isinstance(obj, dict):
        head = list(obj.items())[:max_items]
        if len(obj) <= max_items:
            return f"{dict(head)}"
        opener, closer = "{", "}"
        shown = [f"{key!r}: {value!r}" for key, value in head]
    else:
        return repr(obj)

    # Joining the parts (instead of patching the repr text) also keeps the
    # output well-formed when no element is shown at all.
    shown.append(f"... ({len(obj)} items total)")
    return opener + ", ".join(shown) + closer
# ---------------------------------------
# Checkpoint scan
# ---------------------------------------
def analyze_checkpoint(path):
    """Scan a checkpoint file and print its globals, risk score and verdict.

    A zip archive has every member whose name ends in ``.pkl`` analyzed. A
    file with a ``.pkl`` suffix, or any file that is not a zip archive, is
    analyzed as a raw pickle stream instead of failing with ``BadZipFile``.

    Args:
        path: path (str or Path) of the checkpoint to scan.

    Returns:
        None. Results are printed.
    """
    import io  # function-scope import: not part of the file's import header

    analyzer = PickleAnalyzer()
    path = Path(path)
    if path.suffix.lower() == ".pkl" or not zipfile.is_zipfile(path):
        print("\n--- scanning", path.name, "---")
        data = path.read_bytes()
        stream = io.BytesIO(data)
        analyzer.analyze(stream)
        # The opcode walk ends at the first STOP. Say so when more bytes
        # follow, instead of letting the verdict below imply the whole file
        # was covered.
        leftover = len(data) - stream.tell()
        if leftover > 0:
            print("⚠", leftover, "trailing bytes after the first pickle were NOT scanned")
    else:
        with zipfile.ZipFile(path) as z:
            for name in z.namelist():
                if name.endswith(".pkl"):
                    print("\n--- scanning", name, "---")
                    analyzer.analyze(z.read(name))

    print("\nGLOBAL entries:")
    for g in sorted(analyzer.globals_found):
        print(" ", g)
    print("\nRISK SCORE:", analyzer.score)
    if analyzer.score >= 10:
        print("⚠ HIGH RISK PICKLE")
    elif analyzer.score >= 3:
        print("⚠ SUSPICIOUS PICKLE")
    else:
        print("OK")
# ---------------------------------------
# Command-line entry point: exactly one argument, the checkpoint path.
# ---------------------------------------
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("Usage: python analyze_ckpt.py model.pt")
        sys.exit(1)
    analyze_checkpoint(cli_args[0])
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
テストコード