Skip to content

Instantly share code, notes, and snippets.

@draftcode
Created November 4, 2011 14:21
Show Gist options
  • Select an option

  • Save draftcode/1339423 to your computer and use it in GitHub Desktop.

Select an option

Save draftcode/1339423 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import threading
class Segment(object):
    """One node in the chain of segments that makes up a revision's history.

    A segment remembers its parent segment, the versioned variables that
    were written while it was a revision's current segment (``written``),
    a reference count, and a version number taken from a class-wide counter
    that is used as the key into each variable's version table.
    """

    # Next version number to hand out.
    # NOTE(review): the read-increment below takes no lock; confirm that
    # segments are never created from two threads at the same moment.
    version_count = 0

    def __init__(self, parent):
        """Create a segment under ``parent`` and take a reference on it.

        Args:
            parent(Segment): the preceding segment, or None for a root.
        """
        self.parent = parent
        if parent is not None:
            parent.refcount += 1
        self.written = []
        self.refcount = 1
        self.version = Segment.version_count
        Segment.version_count += 1

    def release(self):
        """Decrease the refcount.

        When the count reaches zero, every variable recorded in ``written``
        is told to drop the value stored for this segment, and the reference
        this segment held on its parent is released in turn.
        """
        self.refcount -= 1
        if self.refcount == 0:
            for v in self.written:
                v.release(self)
            if self.parent is not None:
                self.parent.release()

    def collapse(self, main):
        """Merge all variables to main segment.

        Walks up from this segment and, for as long as the parent is not the
        root of ``main`` and is referenced by nothing else (refcount of 1),
        asks every variable written in that parent to collapse its value into
        ``main`` and then unlinks the parent from the chain.

        Args:
            main(Revision): the revision whose current segment is ``self``.
        """
        assert main.current is self
        # The ``is not None`` test stops the walk cleanly if the chain runs
        # out before main.root is met, instead of failing on None.refcount.
        while (self.parent is not None
               and self.parent is not main.root
               and self.parent.refcount == 1):
            for v in self.parent.written:
                # Variables expose collapse(main, parent): pass the revision
                # itself together with the segment that is being removed.
                v.collapse(main, self.parent)
            self.parent = self.parent.parent
class Revision(object):
    """A branch of execution described by a root and a current segment.

    ``root`` is the segment the revision branched from and ``current`` is
    the segment that reads and writes made through this revision go to.
    """

    # Per-thread storage; its ``current_revision`` attribute is the revision
    # used when callers do not pass one explicitly.
    local_store = threading.local()

    def __init__(self, root, current):
        """
        Args:
            root(Segment)
            current(Segment)
        """
        self.root, self.current = root, current

    def fork(self, f, *args, **kwargs):
        """Start ``f(*args, **kwargs)`` in a new thread under a new revision.

        The present current segment becomes the root of the child revision;
        both the child and this revision then continue on fresh segments
        created beneath it.

        Returns:
            Revision: the child revision; its ``task`` attribute is the
            started thread.
        """
        branch_point = self.current
        child = Revision(branch_point, Segment(branch_point))
        # Drop the reference this revision held as "current"; the child's
        # segment created above still references branch_point.
        branch_point.release()
        self.current = Segment(branch_point)

        def run_in_child():
            # (translated) The version in the paper saves something at this
            # point, but I wonder why.
            Revision.local_store.current_revision = child
            f(*args, **kwargs)

        # (translated) The version in the paper stores this in self.task,
        # but I wonder why.
        worker = threading.Thread(target=run_in_child)
        child.task = worker
        worker.start()
        return child

    def join(self, join):
        """Wait for ``join`` to finish and merge its writes into this revision.

        Every segment from the joined revision's current segment up to (but
        excluding) its root is visited, and each variable written there is
        asked to merge. Whether or not that succeeds, the joined revision's
        current segment is released and this revision's current segment is
        collapsed.

        Args:
            join(Revision): a revision previously returned by ``fork``.
        """
        try:
            join.task.join()
            segment = join.current
            while segment is not join.root:
                for variable in segment.written:
                    variable.merge(self, join, segment)
                segment = segment.parent
        finally:
            join.current.release()
            self.current.collapse(self)
class Versioned(object):
    """A variable that keeps one value per segment version.

    ``versions`` maps a segment's version number to the value written while
    that segment was current. When no revision is passed to ``get``/``set``
    the calling thread's ``Revision.local_store.current_revision`` is used.
    """

    def __init__(self):
        self.versions = {}

    def get(self, revision=None):
        """Return the value visible from ``revision``.

        Starts at the revision's current segment and follows parent links
        until a segment with a stored value is found.
        """
        if revision is None:
            revision = Revision.local_store.current_revision
        segment = revision.current
        while segment.version not in self.versions:
            segment = segment.parent
        return self.versions[segment.version]

    def set(self, value, revision=None):
        """Store ``value`` for the current segment of ``revision``.

        The first write in a segment also records this variable in that
        segment's ``written`` list.
        """
        if revision is None:
            revision = Revision.local_store.current_revision
        target = revision.current
        if target.version not in self.versions:
            target.written.append(self)
        self.versions[target.version] = value

    def release(self, release):
        """Forget the value stored for segment ``release``, if there is one."""
        self.versions.pop(release.version, None)

    def collapse(self, main, parent):
        """Move the value stored for ``parent`` into ``main``'s current segment.

        The value is copied only when the current segment has none of its
        own; the entry for ``parent`` is removed afterwards in either case.
        """
        if (main.current.version not in self.versions
                and parent.version in self.versions):
            self.set(self.versions[parent.version], main)
        del self.versions[parent.version]

    def merge(self, main, join_rev, join):
        """Copy the value written in segment ``join`` into ``main``.

        Nothing happens unless ``join`` holds a value and is the first
        segment with a value found when walking up from ``join_rev``'s
        current segment.
        """
        if join.version not in self.versions:
            return
        segment = join_rev.current
        while segment.version not in self.versions:
            segment = segment.parent
        if segment is join:
            self.set(self.versions[join.version], main)
def fork(f, *args, **kwargs):
    """Fork the calling thread's current revision to run ``f``.

    Looks up ``Revision.local_store.current_revision`` and forwards ``f``
    with its arguments to that revision's ``fork`` method, returning
    whatever it returns.
    """
    return Revision.local_store.current_revision.fork(f, *args, **kwargs)
def join(r):
    """Join revision ``r`` into the calling thread's current revision.

    Looks up ``Revision.local_store.current_revision`` and passes ``r`` to
    its ``join`` method. Nothing is returned.
    """
    Revision.local_store.current_revision.join(r)
# Set up the revision for the thread that loads this module: a root segment
# that is both the root and the current segment of the initial revision.
root_segment = Segment(None)
Revision.local_store.current_revision = Revision(root_segment, root_segment)

# Demo: a versioned variable written in a forked revision only becomes
# visible to the forking thread after the join.
obj = Versioned()
obj.set(1)
# print is called with a single parenthesised argument so the line is valid
# on Python 3 and still prints the same text on Python 2.
print("thread1:" + str(obj.get()))


def f():
    """Body of the forked revision: overwrite ``obj`` and show its value."""
    obj.set(2)
    print("thread2:" + str(obj.get()))


new_r = fork(f)
# The write made inside f() went to the forked revision's own segment, so
# this read still finds the value written before the fork.
print("thread1:" + str(obj.get()))
join(new_r)
# The join merged the forked revision's write into this revision.
print("thread1:" + str(obj.get()))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment