Skip to content

Instantly share code, notes, and snippets.

@erikerlandson
Created April 22, 2021 14:21
Show Gist options
  • Save erikerlandson/ace59702a7b66287f9b8de5d0d3d04d5 to your computer and use it in GitHub Desktop.
Save erikerlandson/ace59702a7b66287f9b8de5d0d3d04d5 to your computer and use it in GitHub Desktop.
SimpleRDD implemented in Ray
def ray_pi(n=1000, k=10):
    """Estimate pi by Monte Carlo sampling on a SimpleRDD.

    Draws ``n * k`` points uniformly from the square [-1, 1] x [-1, 1],
    distributed over (at most) ``k`` partitions, counts the points that fall
    inside the unit circle, and scales the hit ratio by 4.
    """
    total = n * k

    def random_point(_):
        # The element value is ignored; each element just triggers one draw.
        return (random.uniform(-1, 1), random.uniform(-1, 1))

    def inside_unit_circle(point):
        x, y = point
        return x * x + y * y <= 1

    hits = (
        SimpleRDD(range(total), k=k)
        .map(random_point)
        .filter(inside_unit_circle)
        .count()
    )
    return 4 * hits / total
@ray.remote
def simple_rdd_part(data):
    """Ray task that places one partition into the object store.

    Returns its argument unchanged; invoking it via ``.remote`` is what
    yields an object ref for the partition.
    """
    partition = data
    return partition
@ray.remote
def simple_rdd_map(f, data):
    """Ray task: apply ``f`` to every element of one partition.

    Returns a new list of the results, in the partition's original order.
    """
    return list(map(f, data))
@ray.remote
def simple_rdd_filter(f, data):
    """Ray task: keep the elements of one partition for which ``f`` is truthy.

    Returns a new list that preserves the original element order.
    """
    kept = []
    for item in data:
        if f(item):
            kept.append(item)
    return kept
@ray.remote
def simple_rdd_pmap(f, data):
    """Ray task: apply ``f`` to an entire partition at once.

    ``f`` receives the whole partition, and its return value becomes the new
    partition unchanged. It is expected to return another list.
    NOTE(review): the result is not type checked.
    """
    new_partition = f(data)
    return new_partition
@ray.remote
def simple_rdd_reduce(z, f, data):
    """Ray task: left-fold one partition with ``f``, starting from ``z``.

    Computes ``f(...f(f(z, x0), x1)..., xn)``. An empty partition yields
    ``z`` itself.
    """
    acc = z
    for element in data:
        acc = f(acc, element)
    return acc
@ray.remote
def simple_rdd_reduce_2(f, x1, x2):
    """Ray task: combine two partial results; returns ``f(x1, x2)``."""
    combined = f(x1, x2)
    return combined
@ray.remote
def simple_rdd_reduce_z(z):
    """Ray task that wraps the zero value ``z`` in an object ref.

    ``SimpleRDD.reduce`` uses it to seed its chain of pairwise combines.
    """
    zero = z
    return zero
class SimpleRDD:
    """A minimal, Spark-RDD-like distributed collection built on Ray tasks.

    Each partition is held as a Ray object ref. Transformations (``map``,
    ``filter``, ``pmap``) launch one Ray task per partition and return a new
    ``SimpleRDD`` that wraps the resulting refs. Actions (``collect``,
    ``reduce``, ``count``) block on ``ray.get``.
    """

    def __init__(self, data, k=2, partitions=None):
        """Build an RDD from local data or from existing partition refs.

        Args:
            data: a sized, sliceable sequence (e.g. ``list`` or ``range``) to
                split into contiguous partitions. Ignored when ``partitions``
                is given.
            k: requested number of partitions. It is clamped to
                ``[1, len(data) // 10]``, so every partition holds at least 10
                elements whenever ``len(data) >= 10``.
            partitions: a list of Ray object refs, one per partition. When
                given, ``data`` and ``k`` are ignored.
        """
        if partitions is not None:
            self.parts = partitions
        else:
            n = len(data)
            k = max(1, min(k, n // 10))
            # Split into exactly k contiguous slices whose sizes differ by at
            # most one. Fixed-size chunks of n // k would add an extra, tiny
            # trailing partition whenever k does not divide n.
            size, extra = divmod(n, k)
            parts = []
            start = 0
            for i in range(k):
                stop = start + size + (1 if i < extra else 0)
                parts.append(data[start:stop])
                start = stop
            self.parts = [simple_rdd_part.remote(p) for p in parts]
        # Derived from the actual partitions, so it always agrees with them.
        self.nparts = len(self.parts)

    def map(self, f):
        """Return a new RDD with ``f`` applied to every element (lazy, per-partition tasks)."""
        parts = [simple_rdd_map.remote(f, p) for p in self.parts]
        return SimpleRDD(None, partitions=parts)

    def pmap(self, f):
        """Return a new RDD where each partition is replaced by ``f(partition)``.

        ``f`` is expected to return a list.
        """
        parts = [simple_rdd_pmap.remote(f, p) for p in self.parts]
        return SimpleRDD(None, partitions=parts)

    def filter(self, f):
        """Return a new RDD keeping only the elements for which ``f`` is truthy."""
        parts = [simple_rdd_filter.remote(f, p) for p in self.parts]
        return SimpleRDD(None, partitions=parts)

    def collect(self):
        """Fetch all partitions (blocking) and concatenate them into one local list.

        Elements appear in partition order.
        """
        parts = ray.get(self.parts)
        data = []
        for p in parts:
            data.extend(p)
        return data

    def reduce(self, z, f):
        """Fold all elements with ``f`` and return the result (blocking).

        Each partition is folded independently, starting from ``z``. The
        per-partition results are then combined left to right, again starting
        from ``z``. Because ``z`` enters the computation ``nparts + 1`` times,
        it should be an identity for ``f`` (e.g. ``0`` for ``+``). ``f``
        should also be associative, so the result does not depend on how the
        data is partitioned.
        """
        rparts = [simple_rdd_reduce.remote(z, f, p) for p in self.parts]
        r = simple_rdd_reduce_z.remote(z)
        # Ray resolves object-ref arguments before a task runs, so each
        # combine step waits on the previous step and on its partition.
        for p in rparts:
            r = simple_rdd_reduce_2.remote(f, r, p)
        return ray.get(r)

    def count(self):
        """Return the total number of elements (blocking).

        Each partition is replaced by a one-element list holding its length,
        and those lengths are summed.
        """
        return self.pmap(lambda x: [len(x)]).reduce(0, lambda x, y: x + y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment