Created
May 16, 2020 03:14
-
-
Save Derfies/11d9c6550352f6974fa67ef700c77ed6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import abc | |
import random | |
import logging | |
import pyglet | |
from pglib import utils | |
from pglib.region import Region | |
from pglib.samplers.range import Range | |
from pglib.geometry.point import Point2d | |
from pglib.draw.pyglet.drawables import Rect | |
# Scale factor from Region coordinates to window pixels, used when drawing
# regions in the demo window.
GRID_SPACING = 10

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Factory(object):
    """Deferred constructor: remember a class and its leading arguments.

    ``create_instance(*more)`` calls ``cls(*(args + more), **kwargs)``. The
    arguments captured at construction time come first, and the call-time
    positional arguments are appended after them.
    """

    def __init__(self, cls, *args, **kwargs):
        self.cls, self.args, self.kwargs = cls, args, kwargs

    def create_instance(self, *args):
        """Instantiate ``cls`` with the stored args followed by *args*."""
        return self.cls(*(self.args + args), **self.kwargs)
class DataWrapper(object):
    """Pair a piece of data with the outputs later derived from it.

    ``outputs`` starts as a fresh empty list per instance. ``Node.evaluate``
    overwrites it with the wrapped results of running a generator on
    ``data``.
    """

    def __init__(self, data):
        self.data, self.outputs = data, []
class Node(object):
    """A named node that runs one generator over each of its inputs.

    Inputs are wrapper objects exposing ``.data`` and ``.outputs``
    (``add_raw_input`` wraps plain data in a ``DataWrapper``). ``evaluate``
    builds a fresh generator per input via the configured ``Factory``. It
    then stores that generator's outputs back on the input, each wrapped in
    a ``DataWrapper``.
    """

    def __init__(self, name):
        self.name = name
        self.inputs = []
        # Not read anywhere in the visible code yet.
        self.children = []
        # Set by set_generator(); evaluate() needs it to be non-None.
        self.gen_factory = None

    def add_input(self, input_):
        """Append an already-wrapped input."""
        self.inputs.append(input_)

    def add_inputs(self, inputs):
        """Append every item of *inputs* (any iterable) via add_input."""
        # Explicit loop rather than map(): map() is lazy on Python 3, so
        # calling it only for its side effects would never call add_input.
        for input_ in inputs:
            self.add_input(input_)

    def add_raw_input(self, input_):
        """Wrap plain *input_* in a DataWrapper and append it."""
        self.add_input(DataWrapper(input_))

    def add_raw_inputs(self, inputs):
        """Wrap and append every item of *inputs* (any iterable)."""
        for input_ in inputs:
            self.add_raw_input(input_)

    def set_generator(self, cls, *args, **kwargs):
        """Configure the generator class and its leading constructor args.

        Each input's ``data`` is appended after *args* when the generator is
        instantiated in evaluate().
        """
        self.gen_factory = Factory(cls, *args, **kwargs)

    def evaluate(self):
        """Run a fresh generator on every input and store its outputs.

        Raises:
            AttributeError: if there are inputs but set_generator() has not
                been called (``gen_factory`` is still None).
        """
        for input_ in self.inputs:
            generator = self.gen_factory.create_instance(input_.data)
            generator.run()
            # A real list (not a lazy map) so callers can len() it and
            # iterate it repeatedly, e.g. on every redraw.
            input_.outputs = [DataWrapper(output) for output in generator.outputs]
# A ``__metaclass__`` class attribute is honoured only by Python 2. Creating
# the base class through ABCMeta directly gives it that metaclass on both
# Python 2 and 3, so @abc.abstractmethod is actually enforced.
_AbstractBase = abc.ABCMeta('_AbstractBase', (object,), {})


class GeneratorBase(_AbstractBase):
    """Base class for iterative generators.

    ``run`` repeatedly calls ``create_data`` and keeps each result that
    passes ``is_data_ok``. It stops after ``max_iters`` attempts or as soon
    as ``is_iteration_done`` is true. Subclasses must implement
    ``create_data``; instantiating a class without it raises TypeError.
    """

    def __init__(self, input_, max_iters=100):
        """
        Args:
            input_: the data this generator works on (kept as ``self.input``).
            max_iters: maximum number of create_data attempts made by run().
        """
        self.input = input_
        self.max_iters = max_iters
        self.outputs = []
        # Number of create_data attempts so far, accepted or rejected.
        self.iter = 0

    @property
    def is_iteration_done(self):
        """Early-stop hook checked by run(); the base never stops early."""
        return False

    # TODO: consider making this a Node method?
    @staticmethod
    def is_recursion_done(output):
        """Hook reporting whether *output* needs no further processing.

        Not called anywhere in the visible code; the base always says True.
        """
        return True

    def is_data_ok(self, data):
        """Return whether *data* should be kept; the base accepts everything."""
        return True

    @abc.abstractmethod
    def create_data(self):
        """Produce and return one new candidate output."""

    def run(self):
        """Generate candidates until max_iters attempts or is_iteration_done."""
        while self.iter < self.max_iters and not self.is_iteration_done:
            data = self.create_data()
            if self.is_data_ok(data):
                self.outputs.append(data)
            # Count every attempt (not only accepted ones) so the loop is
            # guaranteed to terminate even if every candidate is rejected.
            self.iter += 1
        logger.info('Iteration ended after %s iterations', self.iter)
class RectGenerator(GeneratorBase):
    """Scatter non-overlapping rectangular Regions inside the input Region.

    Each candidate's width and height come from the ``width``/``height``
    samplers (objects with a ``run()`` method). The candidate is placed at a
    random integer offset from the input's (x1, y1) corner so that it fits
    inside the input. This assumes Region.width/height are x2 - x1 and
    y2 - y1; confirm against pglib. Candidates that intersect an
    already-accepted output are rejected by is_data_ok.
    """

    def __init__(self, width, height, *args, **kwargs):
        """
        Args:
            width: sampler whose ``run()`` returns the next rect width.
            height: sampler whose ``run()`` returns the next rect height.
            *args, **kwargs: forwarded to GeneratorBase (input_, max_iters).
        """
        super(RectGenerator, self).__init__(*args, **kwargs)
        self.width = width
        self.height = height

    def is_data_ok(self, region):
        """Return True if *region* intersects none of the accepted outputs."""
        # Equivalent to the for/else form, but unambiguous: True for an empty
        # output list, False as soon as any output intersects.
        return not any(region.intersects(output) for output in self.outputs)

    def create_data(self):
        """Sample one candidate Region positioned inside ``self.input``.

        Raises:
            ValueError: if the sampled size is larger than the input region.
        """
        w = self.width.run()
        h = self.height.run()
        # Slack along each axis; the rect's offset is drawn from [0, slack].
        rx = self.input.width - w
        ry = self.input.height - h
        if rx < 0 or ry < 0:
            # randint would raise an opaque "empty range" ValueError here.
            raise ValueError(
                'Sampled rect {}x{} does not fit in input region {}x{}'.format(
                    w, h, self.input.width, self.input.height
                )
            )
        x = random.randint(0, rx)
        y = random.randint(0, ry)
        return Region(
            self.input.x1 + x,
            self.input.y1 + y,
            self.input.x1 + x + w,
            self.input.y1 + y + h,
        )
def _outline_rect(region, colour):
    """Build a Rect outlining *region*, scaled to pixels by GRID_SPACING.

    ``colour=None`` is passed for the fill (presumably no fill; confirm with
    pglib). The outline uses *colour* and is 4 px wide.
    """
    return Rect(
        Point2d(region.x1 * GRID_SPACING, region.y1 * GRID_SPACING),
        Point2d(region.x2 * GRID_SPACING, region.y2 * GRID_SPACING),
        colour=None,
        line_colour=colour,
        line_width=4,
    )


def main():
    """Run the demo: fill two seed regions with rects and display them.

    The root node runs RectGenerator (up to 1000 placement attempts per
    input) on each seed region. Each seed region and its generated rects are
    then drawn as outlines in a 1000x1000 pyglet window.
    """
    logging.basicConfig(level=logging.INFO)

    node = Node('root')
    node.add_raw_input(Region(0, 0, 50, 50))
    node.add_raw_input(Region(50, 50, 100, 100))
    node.set_generator(RectGenerator, Range(10), Range(10), max_iters=1000)
    node.evaluate()

    logger.info('Num inputs: %s', len(node.inputs))

    # (region, colour) pairs, built once. Picking colours here, rather than
    # inside on_draw, keeps them stable instead of re-randomising them on
    # every redraw.
    shapes = []
    for i, input_ in enumerate(node.inputs):
        # Materialise once so len() and drawing work even if outputs is a
        # lazy iterator.
        outputs = list(input_.outputs)
        logger.info('Input: %s has %s outputs', i, len(outputs))
        shapes.append((input_.data, utils.get_random_colour(1)))
        shapes.extend(
            (output.data, utils.get_random_colour(1)) for output in outputs
        )

    window = pyglet.window.Window(1000, 1000)

    @window.event
    def on_draw():
        window.clear()
        for region, colour in shapes:
            _outline_rect(region, colour).draw()

    pyglet.app.run()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment