Skip to content

Instantly share code, notes, and snippets.

@mnunberg
Created July 2, 2012 00:04
Show Gist options
  • Save mnunberg/3030137 to your computer and use it in GitHub Desktop.
Save mnunberg/3030137 to your computer and use it in GitHub Desktop.
import memcacheConstants as MCC
import struct
import mc_bin_client
import logger
import time
import Queue
from threading import Thread, Lock
from membase.api.rest_client import RestConnection, RestHelper
from memcached.helper.data_helper import MemcachedClientHelper
class NodeWaiter(object):
    """
    Base class for waiting on a node's VBuckets.

    Provides check_vbid(), which checks the state of a single VBucket.
    Subclasses are expected to set ``self.node``; it is only used in log
    messages, and check_vbid() tolerates it being absent.
    """

    def check_vbid(self, dc, vb):
        """
        Get the status of a single VBucket.

        @param dc the client returned by MemcachedClientHelper.direct_client;
                  its ``vbucketId`` attribute is set to ``vb`` as a side effect
        @param vb the numeric VBucket ID
        @return True if the VBucket reports VB_STATE_ACTIVE or
                VB_STATE_REPLICA; False on a MemcachedError, a truncated
                response, or any other state
        """
        log = logger.Logger.get_logger()
        dc.vbucketId = vb
        try:
            # The third element of the reply tuple holds the state payload.
            resp = dc.get_vbucket_state(vb)[2]
        except mc_bin_client.MemcachedError as e:
            # getattr: not every subclass sets self.node before polling.
            log.warn("Got error for node %s VB[%d]: %s",
                     getattr(self, 'node', None),
                     vb,
                     e)
            return False

        try:
            # The state is a single big-endian unsigned 32-bit integer.
            rint = struct.unpack_from("!I", resp)[0]
        except struct.error:
            # A short payload would otherwise crash the polling loop; treat
            # it like any other unexpected response and retry later.
            log.warn("Got truncated VB state response (len=%d)", len(resp))
            return False

        if rint in (MCC.VB_STATE_ACTIVE, MCC.VB_STATE_REPLICA):
            return True

        if not rint:
            log.warn("Got unexpected response (%d, len=%d)",
                     rint, len(resp))
        else:
            log.warn("Got response %d for VB state", rint)
        return False
class NodeSingleWaiter(NodeWaiter):
    """
    Waits for a single node to verify its VBuckets, polling them one after
    another over a single direct connection.
    """

    def __init__(self, node, vbs, bucket='default', end_time=0):
        """
        Waits for a single node to verify its VBuckets.

        Every pass re-checks every VBucket, so one that regresses after
        being seen as ready is counted as not ready again.

        @param node TestInputServer
        @param vbs Iterable of VBucket IDs
        @param bucket bucket string
        @param end_time wait until this time (epoch-time); 0 waits forever
        @raise Exception("Timeout") if end_time passes before all VBuckets
               are ready
        """
        # check_vbid() logs self.node when a request fails, so set it here
        # (NodeThreadedWaiter sets the same attributes).
        self.node = node
        self.bucket = bucket

        waiting = set(vbs)
        ready = set()
        dc = MemcachedClientHelper.direct_client(node, bucket)
        while len(ready) < len(waiting):
            if end_time and time.time() > end_time:
                raise Exception("Timeout")
            for vb in waiting:
                if self.check_vbid(dc, vb):
                    ready.add(vb)
                else:
                    ready.discard(vb)
class NodeThreadedWaiter(NodeWaiter):
    """
    Waits for a single node to verify its VBuckets using a pool of worker
    threads, each with its own direct client connection.

    VBucket IDs to check sit in ``qin``. A worker puts each VBucket that
    passes into ``qout`` and puts each one that fails back into ``qin`` to
    be retried.
    """

    def __init__(self, node, vbs, bucket='default', end_time=0,
                 workers=8):
        """
        @param node TestInputServer
        @param vbs Iterable of VBucket IDs
        @param bucket bucket string
        @param end_time wait until this time (epoch-time); 0 waits forever
        @param workers Number of worker threads to launch (at least one)
        @raise Exception on timeout; worker threads are stopped and joined
               before it propagates
        """
        self.qin = Queue.Queue()
        self.qout = Queue.Queue()
        self.lock = Lock()
        self.node = node
        self.bucket = bucket
        self.done = False

        remaining = set(vbs)
        for vb in remaining:
            self.qin.put(vb)

        threads = []
        try:
            # Start exactly `workers` threads; with none, nothing would
            # ever drain qin and the loop below could not finish.
            for _ in range(max(1, workers)):
                thr = Thread(target=self._worker)
                thr.start()
                threads.append(thr)

            while remaining:
                if end_time and time.time() > end_time:
                    raise Exception("Timeout waiting for VB Verification")
                try:
                    vbres = self.qout.get(timeout=0.5)
                except Queue.Empty:
                    continue
                remaining.discard(vbres)
        finally:
            # Always stop the workers, including on timeout; otherwise they
            # keep polling forever and keep the process alive.
            self._set_done()
            for thr in threads:
                thr.join()

    def _is_done(self):
        """Return True once the waiter has asked the workers to stop."""
        with self.lock:
            return self.done

    def _set_done(self):
        """Ask all worker threads to stop after their current iteration."""
        with self.lock:
            self.done = True

    def _cycle_one(self, dc):
        """
        Check one queued VBucket.

        @param dc direct client owned by the calling worker thread
        @return False when the worker should exit, True otherwise
        """
        if self._is_done():
            return False
        try:
            # The timeout lets the worker notice _set_done() even when the
            # input queue stays empty.
            vb = self.qin.get(timeout=0.5)
        except Queue.Empty:
            return True
        if self.check_vbid(dc, vb):
            self.qout.put(vb)
        else:
            self.qin.put(vb)
        return True

    def _worker(self):
        """Thread body: open a dedicated connection and poll until done."""
        dc = MemcachedClientHelper.direct_client(self.node, self.bucket)
        while self._cycle_one(dc):
            pass
class MemcachedWaiter(object):
    """
    Waits until every VBucket in a bucket's VBucket map is verified on each
    node that holds a master or replica copy of it. One thread runs per node.
    """

    def __init__(self, nodes,
                 bucket='default',
                 timeout=300,
                 nodewait_class=NodeSingleWaiter):
        """
        Wait for all VBuckets on a cluster.

        @param nodes Sequence of TestInputServer; nodes[0] is used for the
               REST queries
        @param bucket bucket string
        @param timeout Maximum amount of time to wait (seconds)
        @param nodewait_class Type of node-waiting class,
               can be NodeThreadedWaiter or NodeSingleWaiter
        @raise Exception the first error raised by any per-node waiter
               (e.g. a timeout), re-raised after all waiter threads finish
        """
        self.end_time = time.time() + timeout
        self.log = logger.Logger.get_logger()
        self.threads = []
        self.vb_owners = {}
        self.nodes = nodes
        self.bucket = bucket
        self._load_vb_owners()

        # Exceptions raised inside a Thread never reach the thread that
        # joins it, so each waiter thread records its failure here instead.
        errors = []
        for n, vbs in self.vb_owners.items():
            thr = Thread(
                target=self._run_node_waiter,
                name="VB-Waiter-{0}".format(n),
                args=(nodewait_class, n, vbs, bucket, errors))
            thr.start()
            self.log.info("Started thread for %s", n)
            self.threads.append(thr)
        for thr in self.threads:
            thr.join()

        for node, err in errors:
            self.log.warn("VBucket wait failed for %s: %s", node, err)
        if errors:
            raise errors[0][1]

    def _run_node_waiter(self, nodewait_class, node, vbs, bucket, errors):
        """
        Thread target: run one per-node waiter to completion.

        Any exception is appended to ``errors`` as (node, exception) so that
        __init__ can re-raise it after joining.
        """
        try:
            nodewait_class(node, vbs, bucket, end_time=self.end_time)
        except Exception as e:
            errors.append((node, e))

    def _load_vb_owners(self):
        """
        Fill self.vb_owners as {TestInputServer: [VBucket IDs]} from the
        bucket's VBucket map, counting both master and replica copies.

        @raise Exception if a server in the map matches none of self.nodes
        """
        rest = RestConnection(self.nodes[0])
        # NOTE(review): the return value is ignored; confirm whether a
        # "not ready" result should abort the wait.
        RestHelper(rest).vbucket_map_ready(self.bucket)
        vbraw = rest.get_vbuckets()
        self.log.debug("Got all VBuckets")

        # Group by the "host:port" strings used in the VBucket map first.
        by_addr = {}
        for vb in vbraw:
            by_addr.setdefault(vb.master, []).append(vb.id)
            for replica in vb.replica:
                by_addr.setdefault(replica, []).append(vb.id)

        for addr, vbids in by_addr.items():
            ip = addr.rsplit(':', 1)[0]
            selected = None
            for n in self.nodes:
                if n.ip == ip:
                    selected = n
                    break
            if selected is None:
                msg = "Couldn't find input for {0}".format(addr)
                raise Exception(msg)
            # NOTE(review): matching uses only the IP, so several servers
            # sharing one IP map to the same node. Extend rather than
            # overwrite so none of their VBuckets are dropped.
            self.vb_owners.setdefault(selected, []).extend(vbids)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment