Created
November 4, 2011 14:21
-
-
Save draftcode/1339423 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import threading
class Segment(object):
    """A node in the tree of revision segments.

    Each segment gets a unique ``version`` number, keeps a reference count,
    and records in ``written`` the Versioned objects that were set while it
    was a revision's current segment.

    NOTE(review): only version allocation is synchronized; ``refcount``
    updates are plain read-modify-write operations.
    """

    # Next version number to hand out; shared by all segments.
    version_count = 0
    # Segments are created from whichever thread calls fork(), so reading and
    # bumping version_count must not interleave (duplicate versions would
    # make two segments share the same slot in Versioned.versions).
    _version_lock = threading.Lock()

    def __init__(self, parent):
        """
        Args:
            parent(Segment): the segment this one extends, or None for a root.
        """
        self.parent = parent
        if parent is not None:
            # This segment holds a reference to its parent.
            parent.refcount += 1
        self.written = []
        self.refcount = 1
        with Segment._version_lock:
            self.version = Segment.version_count
            Segment.version_count += 1

    def release(self):
        """Decrease the refcount.

        When it reaches zero, every Versioned written in this segment forgets
        this segment's value, and the reference held on the parent is
        released in turn.
        """
        self.refcount -= 1
        if self.refcount == 0:
            for v in self.written:
                v.release(self)
            if self.parent is not None:
                self.parent.release()

    def collapse(self, main):
        """Merge ancestor segments that only this segment still references.

        Each such ancestor (up to, but excluding, ``main.root``) is spliced
        out of the chain after its writes are handed to ``main``.

        Args:
            main(Revision): the revision whose current segment is ``self``.
        """
        assert main.current is self
        while self.parent is not main.root and self.parent.refcount == 1:
            for v in self.parent.written:
                # Versioned.collapse takes the revision and the segment being
                # removed (previously called as v.collapse(main.parent), which
                # fails because a Revision has no ``parent``).
                v.collapse(main, self.parent)
            # Bypass the parent; our reference now points at its parent, so
            # that segment's refcount is unchanged.
            self.parent = self.parent.parent
class Revision(object):
    """A revision: a task forked from another revision.

    ``root`` is the segment the revision was forked from and ``current`` is
    the segment its writes go into.
    """

    # Thread-local slot; its ``current_revision`` attribute is the Revision
    # that code running on the current thread belongs to.
    local_store = threading.local()

    def __init__(self, root, current):
        """
        Args:
            root(Segment): segment this revision was forked from.
            current(Segment): segment this revision writes into.
        """
        self.root = root
        self.current = current
        # Exception raised by the forked function, if any; re-raised by join().
        self.error = None

    def fork(self, f, *args, **kwargs):
        """Run ``f(*args, **kwargs)`` on a new thread in a child revision.

        Returns:
            Revision: the child; pass it to join() to merge its writes.
        """
        r = Revision(self.current, Segment(self.current))
        # The child's segment (above) and our new current segment (below)
        # each take a reference to the old current segment, so drop this
        # revision's own reference to it before replacing it.
        self.current.release()
        self.current = Segment(self.current)

        def local_func():
            # The paper's version saves something here (presumably the
            # previously current revision, to restore it afterwards). Each
            # fork runs on a fresh Thread, so there is nothing to restore.
            # NOTE(review): confirm against the paper.
            Revision.local_store.current_revision = r
            try:
                f(*args, **kwargs)
            except Exception as exc:
                # Remember the failure so join() can re-raise it, then let
                # the thread report it as it did before.
                r.error = exc
                raise

        # Keep the thread on the child revision so join() can wait for it.
        r.task = threading.Thread(target=local_func)
        r.task.start()
        return r

    def join(self, join):
        """Wait for revision ``join`` to finish and merge its writes.

        Walks the joined revision's segments from its current segment up to
        (excluding) its root and merges each written variable into this
        revision. Afterwards the joined revision's segment is released and
        this revision's segment chain is collapsed, even if merging fails.

        Raises:
            Exception: whatever the joined revision's function raised; in
                that case none of its writes are merged.
        """
        try:
            join.task.join()
            if join.error is not None:
                raise join.error
            s = join.current
            while s is not join.root:
                for v in s.written:
                    v.merge(self, join, s)
                s = s.parent
        finally:
            join.current.release()
            self.current.collapse(self)
class Versioned(object):
    """A variable whose value is tracked per segment.

    ``versions`` maps a segment's ``version`` number to the value written
    while that segment was a revision's current segment.
    """

    def __init__(self):
        self.versions = dict()

    def get(self, revision=None):
        """Return the value visible to ``revision``.

        The value comes from the nearest segment, starting at
        ``revision.current`` and walking up through parents, that wrote it.
        If no segment in the chain wrote it, the walk reaches a ``None``
        parent and fails with AttributeError.

        Args:
            revision(Revision): defaults to the calling thread's revision.
        """
        if revision is None:
            revision = Revision.local_store.current_revision
        s = revision.current
        while s.version not in self.versions:
            s = s.parent
        return self.versions[s.version]

    def set(self, value, revision=None):
        """Store ``value`` in ``revision``'s current segment.

        Args:
            value: the new value.
            revision(Revision): defaults to the calling thread's revision.
        """
        if revision is None:
            revision = Revision.local_store.current_revision
        segment = revision.current
        if segment.version not in self.versions:
            # First write in this segment: register so the segment can
            # release / collapse / merge this variable later.
            segment.written.append(self)
        self.versions[segment.version] = value

    def release(self, release):
        """Forget the value written in segment ``release``, if any."""
        self.versions.pop(release.version, None)

    def collapse(self, main, parent):
        """Fold the value of segment ``parent`` into ``main``'s current segment.

        The parent's value is copied only if ``main.current`` has not
        written its own value. The parent's entry is always dropped, because
        the parent is being spliced out of the segment chain.
        """
        if main.current.version not in self.versions \
                and parent.version in self.versions:
            self.set(self.versions[parent.version], main)
        self.versions.pop(parent.version, None)

    def merge(self, main, join_rev, join):
        """Copy segment ``join``'s value into ``main`` if it is still visible.

        ``join`` is one segment of the joined revision ``join_rev``. Its
        value is merged only when it is the value ``join_rev`` itself sees,
        i.e. no later segment of ``join_rev`` overwrote it.
        """
        if join.version in self.versions:
            s = join_rev.current
            while s.version not in self.versions:
                s = s.parent
            if s is join:
                self.set(self.versions[join.version], main)
def fork(f, *args, **kwargs):
    """Fork a child of the calling thread's current revision.

    The child runs ``f(*args, **kwargs)`` on its own thread.

    Returns:
        Revision: the child revision, to be handed to join().
    """
    return Revision.local_store.current_revision.fork(f, *args, **kwargs)
def join(r):
    """Join revision ``r`` into the calling thread's current revision.

    Delegates to ``Revision.join`` on the thread's current revision.
    """
    Revision.local_store.current_revision.join(r)
# Demo: the main revision writes 1, a forked revision writes 2, and only
# after join() does the main revision see the forked write.
# Single-argument print(...) prints the same text on Python 2 and 3.
root_segment = Segment(None)
Revision.local_store.current_revision = Revision(root_segment, root_segment)
obj = Versioned()
obj.set(1)
print("thread1:" + str(obj.get()))


def f():
    obj.set(2)
    print("thread2:" + str(obj.get()))


new_r = fork(f)
# The child's write goes to its own segment, so this still reads 1.
print("thread1:" + str(obj.get()))
join(new_r)
print("thread1:" + str(obj.get()))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment