Skip to content

Instantly share code, notes, and snippets.

@segfo
Last active March 15, 2026 23:02
Show Gist options
  • Select an option

  • Save segfo/f85b7e56b86ca056d208f81f0ab86c7c to your computer and use it in GitHub Desktop.

Select an option

Save segfo/f85b7e56b86ca056d208f81f0ab86c7c to your computer and use it in GitHub Desktop.
チェックポイントに含まれるpickleのRCE(疑い)コードの検知スクリプト
import dis
import importlib
import inspect
import io
import pickletools
import sys
import zipfile
from pathlib import Path
# ---------------------------------------
# dangerous calls
# ---------------------------------------
# (module, name) pairs that give code execution when a pickle calls them via
# REDUCE / INST / OBJ, or when a class's pickle hook calls them.  The module is
# the one pickle records for the callable (its ``__module__``), which is why
# os.system also appears under the platform modules "posix" and "nt".
DANGEROUS_CALLS = {
    # shell / process execution
    ("os", "system"),
    ("nt", "system"),
    ("posix", "system"),
    ("os", "popen"),
    ("subprocess", "Popen"),
    ("subprocess", "run"),
    ("subprocess", "call"),
    ("subprocess", "check_call"),
    ("subprocess", "check_output"),
    ("subprocess", "getoutput"),
    ("subprocess", "getstatusoutput"),
    ("pty", "spawn"),
    # arbitrary Python execution / import
    ("builtins", "eval"),
    ("builtins", "exec"),
    ("builtins", "compile"),
    ("builtins", "__import__"),
    # Protocol 0-2 pickles written by Python 3 record "builtins" as
    # "__builtin__" for Python 2 compatibility.
    ("__builtin__", "eval"),
    ("__builtin__", "exec"),
    ("__builtin__", "compile"),
    ("__builtin__", "__import__"),
    ("importlib", "import_module"),
    ("runpy", "run_path"),
    ("runpy", "run_module"),
    ("runpy", "_run_code"),
}
# Callables that legitimately rebuild tensors/arrays; reported for context only.
TENSOR_REBUILD = {
    ("torch._utils", "_rebuild_tensor_v2"),
    ("numpy.core.multiarray", "_reconstruct"),
    ("numpy._core.multiarray", "_reconstruct"),  # NumPy >= 2.0 module path
    ("copyreg", "_reconstructor"),
}
# Case-insensitive substrings that suggest a shell command or downloader is
# embedded in a pickled string.
DANGEROUS_STRINGS = [
    "cmd.exe",
    "powershell",
    "curl ",
    "wget ",
    "bash ",
    "/bin/sh",
    "nc ",
    "netcat",
    "Invoke-WebRequest",
    "certutil",
    "bitsadmin",
]
# ---------------------------------------
# bytecode analyzer
# ---------------------------------------
# Opcodes that push the value bound to a name (some exist only in newer CPython).
_NAME_LOAD_OPS = frozenset({
    "LOAD_GLOBAL",
    "LOAD_NAME",
    "LOAD_DEREF",
    "LOAD_CLASSDEREF",
    "LOAD_FROM_DICT_OR_DEREF",
    "LOAD_FROM_DICT_OR_GLOBALS",
})
# Name loads that may resolve to a builtin (i.e. not a function-local variable).
_GLOBAL_LOAD_OPS = frozenset({"LOAD_GLOBAL", "LOAD_NAME", "LOAD_FROM_DICT_OR_GLOBALS"})
# Opcodes that bind a plain name (target of `import x` / `from m import x`).
_NAME_STORE_OPS = frozenset({"STORE_FAST", "STORE_NAME", "STORE_GLOBAL", "STORE_DEREF"})
# Instructions that do not touch the value stack and must not break a pattern.
_TRANSPARENT_OPS = frozenset({"EXTENDED_ARG", "NOP", "CACHE"})
# Newer CPython fuses adjacent local loads/stores into one instruction whose
# argval is a pair of names; these are split back into their two halves.
_FUSED_FAST_OPS = {
    "LOAD_FAST_LOAD_FAST": ("LOAD_FAST", "LOAD_FAST"),
    "LOAD_FAST_BORROW_LOAD_FAST_BORROW": ("LOAD_FAST", "LOAD_FAST"),
    "STORE_FAST_LOAD_FAST": ("STORE_FAST", "LOAD_FAST"),
    "STORE_FAST_STORE_FAST": ("STORE_FAST", "STORE_FAST"),
}


def _bytecode_steps(instructions):
    """Yield ``(opname, argval)`` pairs, splitting fused FAST superinstructions."""
    for ins in instructions:
        halves = _FUSED_FAST_OPS.get(ins.opname)
        if halves and isinstance(ins.argval, tuple) and len(ins.argval) == 2:
            yield halves[0], ins.argval[0]
            yield halves[1], ins.argval[1]
        else:
            yield ins.opname, ins.argval


def analyze_method_bytecode(func, dangerous_calls=None):
    """Disassemble *func* and report calls that match the dangerous-call list.

    Recognised shapes:

    * ``mod.attr`` where ``mod`` is a global, local or closure variable,
      including names bound inside *func* by ``import mod`` / ``import mod as x``;
    * ``attr`` after ``from mod import attr`` inside *func*;
    * a bare builtin such as ``eval`` when ``("builtins", name)`` is listed.

    This is a pattern match over the instruction stream, not data-flow
    analysis: aliases created any other way (``getattr``, ``__import__``,
    assignment from another variable) are not followed.

    Args:
        func: function or bound method to disassemble.
        dangerous_calls: set of ``(module, attribute)`` pairs to flag; defaults
            to the module-level ``DANGEROUS_CALLS``.

    Returns:
        The matching ``(module, attribute)`` pairs in bytecode order; empty if
        *func* has no code object to disassemble.
    """
    if dangerous_calls is None:
        dangerous_calls = DANGEROUS_CALLS
    print(f"\n[Bytecode analysis] {func}")
    try:
        instructions = list(dis.get_instructions(func))
    except TypeError:
        # dis found no code object (e.g. a builtin); nothing to inspect.
        return []

    found = []
    module_aliases = {}   # local name -> module, from `import mod [as name]`
    from_imports = {}     # local name -> (module, attr), from `from mod import attr`
    importing = None      # module of the IMPORT_NAME still being bound
    imported_attr = None  # attribute of a pending IMPORT_FROM
    base = None           # module whose attribute the next LOAD_ATTR reads

    def report(key):
        print("⚠ Dangerous bytecode call detected")
        print(" ", key[0] + "." + key[1])
        found.append(key)

    for op, arg in _bytecode_steps(instructions):
        if op in _TRANSPARENT_OPS:
            continue
        named = isinstance(arg, str)

        if op == "IMPORT_NAME" and named:
            importing, imported_attr, base = arg, None, None
            continue
        if op == "IMPORT_FROM" and importing is not None:
            imported_attr, base = arg, None
            continue
        if op in _NAME_STORE_OPS and named:
            if importing is not None and imported_attr is not None:
                from_imports[arg] = (importing, imported_attr)
                module_aliases.pop(arg, None)
                imported_attr = None  # `from m import a, b` stores several names
            elif importing is not None:
                # `import a.b` binds "a"; `import a.b as x` binds the submodule.
                top = importing.partition(".")[0]
                module_aliases[arg] = top if arg == top else importing
                from_imports.pop(arg, None)
                importing = None
            else:
                # A plain assignment shadows any earlier import binding.
                module_aliases.pop(arg, None)
                from_imports.pop(arg, None)
            base = None
            continue

        # Any other instruction ends a pending import statement.
        importing = imported_attr = None

        if named and (op in _NAME_LOAD_OPS or op.startswith("LOAD_FAST")):
            base = None
            if arg in from_imports:
                if from_imports[arg] in dangerous_calls:
                    report(from_imports[arg])
            elif op in _GLOBAL_LOAD_OPS and ("builtins", arg) in dangerous_calls:
                report(("builtins", arg))
            else:
                base = module_aliases.get(arg, arg)
            continue

        if op in ("LOAD_ATTR", "LOAD_METHOD") and named and base is not None:
            if (base, arg) in dangerous_calls:
                report((base, arg))
        # Only an attribute read directly on the loaded name counts, so
        # `os.path.join` never pairs "os" with "join".
        base = None
    return found
def analyze_class_methods(module_name, class_name, *, allow_import=False):
    """Disassemble the pickle hooks of the object a pickle global refers to.

    Looks up ``module_name.class_name`` (dotted qualnames are followed) and
    runs :func:`analyze_method_bytecode` on each of ``__setstate__``,
    ``__reduce__``, ``__reduce_ex__`` and ``__getstate__`` that is a Python
    function or method.

    Security: both names come from untrusted pickle data, and importing a
    module executes its top-level code.  So by default only modules already
    present in ``sys.modules`` are inspected.  Pass ``allow_import=True`` to
    import on demand, which is the previous behaviour.

    Note: ``"__main__"`` resolves to this scanner's own main module, not to the
    program that wrote the pickle, so such classes are usually not found.

    Returns:
        Names of the hooks whose bytecode was analysed.
    """
    module = sys.modules.get(module_name)
    if module is None:
        if not allow_import:
            print("skip (module not loaded, not importing untrusted name):", module_name)
            return []
        try:
            module = importlib.import_module(module_name)
        except Exception as e:  # arbitrary module code may raise anything
            print("class load failed:", e)
            return []

    cls = module
    for part in class_name.split("."):
        cls = getattr(cls, part, None)
        if cls is None:
            break
    if cls is None:
        print("⚠ class not found:", module_name, class_name)
        return []

    analyzed = []
    for name in ("__setstate__", "__reduce__", "__reduce_ex__", "__getstate__"):
        func = getattr(cls, name, None)
        # C-level slot wrappers/descriptors have no bytecode to inspect.
        if inspect.isfunction(func) or inspect.ismethod(func):
            analyze_method_bytecode(func)
            analyzed.append(name)
    return analyzed
# ---------------------------------------
# pickle analyzer
# ---------------------------------------
class PickleAnalyzer:
    """Replay a pickle's opcodes symbolically and score RCE indicators.

    Nothing is ever unpickled.  Each opcode yielded by ``pickletools.genops``
    is applied to a shadow stack of tagged tuples, such as ``("str", s)``,
    ``("int", n)``, ``("callable", module, name)``, ``("list", items)``,
    ``("dict", {str_key: value})`` or ``("result", None)``.  This lets the
    opcodes that call or mutate objects (REDUCE, INST, OBJ, NEWOBJ(_EX),
    BUILD) see which callable and which arguments they would receive.

    Findings are printed as they are found and accumulated in ``score``.
    Every ``"module name"`` resolved by GLOBAL/STACK_GLOBAL/INST is collected
    in ``globals_found``.  Both persist across :meth:`analyze` calls, so one
    analyzer can score every pickle in a checkpoint.
    """

    def __init__(self):
        self.stack = []             # shadow of the unpickler's value stack
        self.markers = []           # stack depths recorded by MARK
        self.memo = {}              # memo index -> tagged value (PUT/MEMOIZE, read by GET)
        self.globals_found = set()  # "module name" strings
        self.score = 0
        self.errors = 0             # handler exceptions tolerated by analyze()
        # opcode name -> handler(op, arg, pos); names not listed are ignored
        handlers = {
            "FRAME": self._handle_FRAME,
            "GLOBAL": self._handle_GLOBAL,
            "STACK_GLOBAL": self._handle_STACK_GLOBAL,
            "INST": self._handle_INST,
            "OBJ": self._handle_OBJ,
            "NONE": self._handle_NONE,
            "NEWTRUE": self._handle_BOOL,
            "NEWFALSE": self._handle_BOOL,
            "MARK": self._handle_MARK,
            "POP": self._handle_POP,
            "POP_MARK": self._handle_POP_MARK,
            "DUP": self._handle_DUP,
            "EMPTY_TUPLE": self._handle_EMPTY_TUPLE,
            "TUPLE": self._handle_TUPLE,
            "TUPLE1": self._handle_TUPLE1,
            "TUPLE2": self._handle_TUPLE2,
            "TUPLE3": self._handle_TUPLE3,
            "EMPTY_LIST": self._handle_EMPTY_LIST,
            "LIST": self._handle_LIST,
            "APPEND": self._handle_APPEND,
            "APPENDS": self._handle_APPENDS,
            "EMPTY_DICT": self._handle_EMPTY_DICT,
            "DICT": self._handle_DICT,
            "SETITEM": self._handle_SETITEM,
            "SETITEMS": self._handle_SETITEMS,
            "EMPTY_SET": self._handle_EMPTY_SET,
            "ADDITEMS": self._handle_ADDITEMS,
            "FROZENSET": self._handle_FROZENSET,
            "MEMOIZE": self._handle_MEMOIZE,
            "PERSID": self._handle_PERSID,
            "BINPERSID": self._handle_BINPERSID,
            "NEXT_BUFFER": self._handle_NEXT_BUFFER,
            "NEWOBJ": self._handle_NEWOBJ,
            "NEWOBJ_EX": self._handle_NEWOBJ_EX,
            "BUILD": self._handle_BUILD,
            "REDUCE": self._handle_REDUCE,
        }
        groups = (
            (self._handle_STR, ("BINUNICODE", "SHORT_BINUNICODE", "BINUNICODE8",
                                "UNICODE", "STRING", "BINSTRING", "SHORT_BINSTRING")),
            (self._handle_BYTES, ("BINBYTES", "SHORT_BINBYTES", "BINBYTES8", "BYTEARRAY8")),
            (self._handle_INT, ("INT", "BININT", "BININT1", "BININT2", "LONG", "LONG1", "LONG4")),
            (self._handle_FLOAT, ("FLOAT", "BINFLOAT")),
            (self._handle_PUT, ("PUT", "BINPUT", "LONG_BINPUT")),
            (self._handle_GET, ("GET", "BINGET", "LONG_BINGET")),
            (self._handle_EXT, ("EXT1", "EXT2", "EXT4")),
        )
        for handler, names in groups:
            for name in names:
                handlers[name] = handler
        self._handlers = handlers

    # -----------------------------
    # shadow-stack primitives
    # -----------------------------
    def push(self, obj):
        self.stack.append(obj)

    def pop(self):
        """Pop the top value, or return None if the stack is empty."""
        if not self.stack:
            return None
        return self.stack.pop()

    def mark(self):
        self.markers.append(len(self.stack))

    def pop_mark(self):
        """Remove and return everything pushed since the most recent MARK."""
        if not self.markers:
            return []
        mark = self.markers.pop()
        items = self.stack[mark:]
        self.stack = self.stack[:mark]
        return items

    def _top(self, tag):
        """Return the top of the stack if it is tagged *tag*, else None."""
        if self.stack and isinstance(self.stack[-1], tuple) and self.stack[-1][:1] == (tag,):
            return self.stack[-1]
        return None

    def check_string(self, val):
        """Add 3 to the score for each DANGEROUS_STRINGS entry found in *val*."""
        if not isinstance(val, str):
            return
        for bad in DANGEROUS_STRINGS:
            if bad.lower() in val.lower():
                print("⚠ suspicious string:", val)
                self.score += 3

    # -----------------------------
    # callable resolution / scoring
    # -----------------------------
    def _register_global(self, module, name):
        """Record a resolved global, inspect its pickle hooks, return its stack tag."""
        self.globals_found.add(module + " " + name)
        print("GLOBAL:", module, name)
        analyze_class_methods(module, name)
        return ("callable", module, name)

    def _check_call(self, func):
        """Score a call of the tagged value *func* (REDUCE / INST / OBJ)."""
        if not func or func[0] != "callable":
            return
        module, fname = func[1], func[2]
        if (module, fname) in DANGEROUS_CALLS:
            print("\n⚠ DANGEROUS CALL DETECTED")
            print(" function:", module + "." + fname)
            self.score += 10
        if (module, fname) in TENSOR_REBUILD:
            print("Tensor rebuild:", module + "." + fname)

    # -----------------------------
    # handlers for pickle opcodes
    # -----------------------------
    def _handle_FRAME(self, op, arg, pos):
        print("FRAME detected size:", arg)

    def _handle_GLOBAL(self, op, arg, pos):
        # genops joins GLOBAL's two newline-terminated fields with a space.
        module, func = arg.split(" ", 1)
        self.push(self._register_global(module, func))

    def _handle_STACK_GLOBAL(self, op, arg, pos):
        func_name = self.pop()
        module_name = self.pop()
        if module_name and module_name[0] == "str" and func_name and func_name[0] == "str":
            self.push(self._register_global(module_name[1], func_name[1]))
            return
        # The stdlib pickler always feeds STACK_GLOBAL string literals (possibly
        # via the memo); anything else means the name is computed at load time.
        print("⚠ STACK_GLOBAL with non-literal operands:", module_name, func_name)
        self.score += 3
        self.push(("unknown", None))

    def _handle_STR(self, op, arg, pos):
        self.push(("str", arg))
        self.check_string(arg)

    def _handle_BYTES(self, op, arg, pos):
        self.push(("bytes", arg))

    def _handle_INT(self, op, arg, pos):
        self.push(("int", arg))

    def _handle_FLOAT(self, op, arg, pos):
        self.push(("float", arg))

    def _handle_BOOL(self, op, arg, pos):
        self.push(("bool", op.name == "NEWTRUE"))

    def _handle_NONE(self, op, arg, pos):
        self.push(("none", None))

    def _handle_MARK(self, op, arg, pos):
        self.mark()

    def _handle_POP(self, op, arg, pos):
        # Like the unpickler: POP right after MARK discards the mark itself.
        if self.markers and self.markers[-1] == len(self.stack):
            self.markers.pop()
        else:
            self.pop()

    def _handle_POP_MARK(self, op, arg, pos):
        self.pop_mark()

    def _handle_DUP(self, op, arg, pos):
        if self.stack:
            self.push(self.stack[-1])

    def _handle_EMPTY_TUPLE(self, op, arg, pos):
        self.push(("tuple", []))

    def _handle_TUPLE(self, op, arg, pos):
        self.push(("tuple", self.pop_mark()))

    def _handle_TUPLE1(self, op, arg, pos):
        a = self.pop()
        self.push(("tuple", [a]))

    def _handle_TUPLE2(self, op, arg, pos):
        b = self.pop()
        a = self.pop()
        self.push(("tuple", [a, b]))

    def _handle_TUPLE3(self, op, arg, pos):
        c = self.pop()
        b = self.pop()
        a = self.pop()
        self.push(("tuple", [a, b, c]))

    def _handle_EMPTY_LIST(self, op, arg, pos):
        self.push(("list", []))

    def _handle_LIST(self, op, arg, pos):
        self.push(("list", self.pop_mark()))

    def _handle_APPEND(self, op, arg, pos):
        value = self.pop()
        target = self._top("list")
        if target is not None:
            target[1].append(value)

    def _handle_APPENDS(self, op, arg, pos):
        items = self.pop_mark()
        target = self._top("list")
        if target is not None:
            target[1].extend(items)

    def _handle_EMPTY_DICT(self, op, arg, pos):
        self.push(("dict", {}))

    def _handle_DICT(self, op, arg, pos):
        items = self.pop_mark()
        self.push(("dict", self._build_dict_from_items(items)))

    def _build_dict_from_items(self, items):
        """Pair up alternating key/value items; only string keys are kept."""
        d = {}
        for i in range(0, len(items), 2):
            key = items[i]
            val = items[i + 1] if i + 1 < len(items) else None
            if key and key[0] == "str":
                d[key[1]] = val
        return d

    def _handle_SETITEM(self, op, arg, pos):
        value = self.pop()
        key = self.pop()
        target = self._top("dict")
        if target is not None and key and key[0] == "str":
            target[1][key[1]] = value

    def _handle_SETITEMS(self, op, arg, pos):
        items = self.pop_mark()
        target = self._top("dict")
        if target is not None:
            target[1].update(self._build_dict_from_items(items))

    def _handle_EMPTY_SET(self, op, arg, pos):
        self.push(("set", []))

    def _handle_ADDITEMS(self, op, arg, pos):
        items = self.pop_mark()
        target = self._top("set")
        if target is not None:
            target[1].extend(items)

    def _handle_FROZENSET(self, op, arg, pos):
        self.push(("frozenset", self.pop_mark()))

    def _handle_PUT(self, op, arg, pos):
        # Memoize the value on top of the stack; it stays on the stack.
        self.memo[arg] = self.stack[-1] if self.stack else ("unknown", None)

    def _handle_MEMOIZE(self, op, arg, pos):
        # Protocol 4+: the memo index is implicit, the current memo size.
        self.memo[len(self.memo)] = self.stack[-1] if self.stack else ("unknown", None)

    def _handle_GET(self, op, arg, pos):
        # Push a placeholder on a miss so the stack depth stays correct.
        self.push(self.memo.get(arg, ("unknown", arg)))

    def _handle_PERSID(self, op, arg, pos):
        self.push(("persid", arg))

    def _handle_BINPERSID(self, op, arg, pos):
        self.push(("persid", self.pop()))

    def _handle_EXT(self, op, arg, pos):
        self.push(("ext", arg))

    def _handle_NEXT_BUFFER(self, op, arg, pos):
        self.push(("buffer", None))

    def _handle_NEWOBJ(self, op, arg, pos):
        args = self.pop()
        cls = self.pop()
        print("⚠ NEWOBJ object:", cls)
        self.score += 1
        self.push(("object", (cls, args)))

    def _handle_NEWOBJ_EX(self, op, arg, pos):
        kwargs = self.pop()
        args = self.pop()
        cls = self.pop()
        print("⚠ NEWOBJ_EX object:", cls)
        self.score += 1
        self.push(("object", (cls, args, kwargs)))

    def _handle_BUILD(self, op, arg, pos):
        state = self.pop()
        obj = self.pop()
        print("⚠ BUILD state injection")
        print(" object:", safe_print_object(obj))
        print(" state:", state)
        self.score += 2
        self.push(("object", obj))

    def _handle_REDUCE(self, op, arg, pos):
        self.pop()  # argument tuple
        self._check_call(self.pop())
        self.push(("result", None))

    def _handle_INST(self, op, arg, pos):
        # Protocol 0/1: MARK args... INST "module name" calls module.name(*args).
        module, name = arg.split(" ", 1)
        self.pop_mark()
        self._check_call(self._register_global(module, name))
        self.push(("result", None))

    def _handle_OBJ(self, op, arg, pos):
        # Protocol 1: MARK callable args... OBJ calls callable(*args).
        items = self.pop_mark()
        if items:
            self._check_call(items[0])
        self.push(("result", None))

    def analyze(self, data):
        """Replay every opcode of one pickle stream.

        *data* is anything ``pickletools.genops`` accepts: ``bytes`` or a
        binary file object, which is read up to and including STOP.  The
        shadow stack, marks and memo are reset first because memo indices are
        local to one stream.  ``score``, ``globals_found`` and ``errors``
        accumulate across calls.

        A malformed or truncated stream is reported, not raised.  Every opcode
        before the bad one has already been analysed, just as an unpickler
        would already have executed it.
        """
        self.stack = []
        self.markers = []
        self.memo = {}
        try:
            for op, arg, pos in pickletools.genops(data):
                handler = self._handlers.get(op.name, self._handle_default)
                try:
                    handler(op, arg, pos)
                except Exception as exc:  # best effort: keep scanning the rest
                    self.errors += 1
                    print(f"⚠ analyzer error at {pos} ({op.name}): {exc!r}")
        except ValueError as exc:
            print("⚠ malformed pickle stream:", exc)

    def _handle_default(self, op, arg, pos):
        # Opcodes with no stack effect worth modelling (PROTO, STOP, ...).
        return
def safe_print_object(obj, max_items=10):
    """Return a printable preview of *obj*; despite the name, nothing is printed.

    Lists, tuples and dicts longer than *max_items* are cut to their first
    *max_items* elements.  The preview then ends with ``... (N items total)``
    and the container's own closing bracket.  Anything else is returned as
    ``repr(obj)``.
    """
    if isinstance(obj, (list, tuple)):
        shown = obj[:max_items]
    elif isinstance(obj, dict):
        shown = dict(list(obj.items())[:max_items])
    else:
        return repr(obj)
    text = str(shown)
    if len(obj) <= max_items:
        return text
    head = text[:-1]  # drop the closing bracket; it is re-added below
    if isinstance(shown, tuple) and len(shown) == 1:
        head = head[:-1]  # a 1-tuple repr ends "x,)": drop its trailing comma
    sep = ", " if shown else ""
    return f"{head}{sep}... ({len(obj)} items total){text[-1]}"
# ---------------------------------------
# checkpoint scanning
# ---------------------------------------
def _iter_pickles(data):
    """Yield ``(stream, complete)`` for each pickle stored back to back in *data*.

    ``stream`` is a binary file positioned at the start of that pickle.
    ``complete`` is True if the pickle parsed up to its STOP opcode.  The
    first pickle that fails to parse is still yielded once, with
    ``complete=False``, because an unpickler executes opcodes as it reads
    them, so a malformed tail does not neutralise what came before it.
    Iteration then stops.
    """
    stream = io.BytesIO(data)
    start = 0
    while start < len(data):
        stream.seek(start)
        try:
            for _ in pickletools.genops(stream):
                pass
        except ValueError:
            stream.seek(start)
            yield stream, False
            return
        end = stream.tell()
        stream.seek(start)
        yield stream, True
        start = end


def _scan_pickles(analyzer, data):
    """Run *analyzer* over every pickle in *data*; return how many were complete."""
    complete = 0
    for stream, ok in _iter_pickles(data):
        try:
            analyzer.analyze(stream)
        except ValueError as exc:
            print("⚠ malformed pickle stream:", exc)
        if ok:
            complete += 1
    return complete


def analyze_checkpoint(path):
    """Scan a checkpoint for pickle-based RCE indicators and print a report.

    A ``.pkl`` file, or any file that is not a zip archive (e.g. a legacy
    non-zip ``torch.save`` file), is scanned as one or more back-to-back
    pickles.  For a zip archive, every member ending in ``.pkl`` or
    ``.pickle`` is scanned.  A single analyzer scores all of them.  If no
    complete pickle could be parsed, the report says so instead of "OK".
    """
    analyzer = PickleAnalyzer()
    path = Path(path)
    complete = 0
    if path.suffix.lower() != ".pkl" and zipfile.is_zipfile(path):
        with zipfile.ZipFile(path) as z:
            for name in z.namelist():
                if name.endswith((".pkl", ".pickle")):
                    print("\n--- scanning", name, "---")
                    complete += _scan_pickles(analyzer, z.read(name))
    else:
        print("\n--- scanning", path.name, "---")
        complete += _scan_pickles(analyzer, path.read_bytes())

    print("\nGLOBAL entries:")
    for g in sorted(analyzer.globals_found):
        print(" ", g)
    print("\nRISK SCORE:", analyzer.score)
    if analyzer.score >= 10:
        print("⚠ HIGH RISK PICKLE")
    elif analyzer.score >= 3:
        print("⚠ SUSPICIOUS PICKLE")
    elif complete == 0:
        print("⚠ NO COMPLETE PICKLE FOUND - nothing could be verified")
    else:
        print("OK")
# ---------------------------------------
# command-line entry point: exactly one checkpoint path is expected
# ---------------------------------------
if __name__ == "__main__":
    if len(sys.argv) == 2:
        analyze_checkpoint(sys.argv[1])
    else:
        print("Usage: python analyze_ckpt.py model.pt")
        sys.exit(1)
@segfo
Copy link
Copy Markdown
Author

segfo commented Mar 15, 2026

テストコード

import os
import pickle
import subprocess


class Evil:
    """Test payload: unpickling this object runs ``os.system("calc.exe")``."""

    def __reduce__(self):
        # pickle stores (callable, args); the unpickler calls callable(*args).
        command = ("calc.exe",)
        return os.system, command


# subprocess variant:
# class Evil:
#     def __reduce__(self):
#         return (subprocess.Popen, (["calc.exe"],))

# __setstate__ variant (probably the kind of payload used when the goal is to
# make a particular program run a particular piece of its own code):
# class Evil:
#     def __getstate__(self):
#         return {"x":1}
#
#     def __setstate__(self,state):
#         import os
#         os.system("calc.exe")

payload = pickle.dumps(Evil())

with open("test.pkl", "wb") as fp:
    fp.write(payload)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment