|
import redis |
|
|
|
class RedisTF(object):
    """Thin client that issues ``TF.*`` commands over a Redis connection.

    Every method builds an argument list and forwards it to
    ``conn.execute_command``, returning whatever the connection returns.
    No validation or reply parsing is done here.
    """

    def __init__(self, conn):
        """Wrap *conn*.

        Args:
            conn: any object exposing ``execute_command(*args)``; the demo
                below passes a ``redis.StrictRedis`` instance.
        """
        self.__conn = conn

    def SetGraph(self, graph, path):
        """Load a serialized graph from the file at *path* and store it as *graph*.

        Sends ``TF.GRAPH <graph> <payload>`` where *payload* is the raw
        file contents.

        Args:
            graph: key name under which the graph is stored.
            path: filesystem path of the serialized graph file.

        Returns:
            The reply from ``execute_command``.
        """
        # Read in binary mode: a serialized graph (presumably a protobuf
        # ``.pb`` as in the demo) is raw bytes. Text mode would try to decode
        # it on Python 3 (raising or mangling it) and translate newlines on
        # Windows, corrupting the payload.
        with open(path, 'rb') as f:
            payload = f.read()
        return self.__conn.execute_command('TF.GRAPH', graph, payload)

    def SetTensor(self, tensor, dtype, shape, values):
        """Create tensor *tensor* with the given dtype, shape and values.

        Sends ``TF.TENSOR <tensor> <dtype> <dim>... VALUES <value>...``.
        Each dimension and each value is converted with ``str()``.

        Args:
            tensor: key name of the tensor.
            dtype: type name passed through verbatim (the demo uses 'FLOAT').
            shape: iterable of dimensions.
            values: iterable of element values.

        Returns:
            The reply from ``execute_command``.
        """
        args = [tensor, dtype]
        args.extend(str(dim) for dim in shape)
        args.append('VALUES')
        args.extend(str(value) for value in values)
        return self.__conn.execute_command('TF.TENSOR', *args)

    def Run(self, graph, inputs, output):
        """Run *graph* on the given inputs, storing the result.

        Sends ``TF.RUN <graph> <len(inputs)> <a0> <b0> <a1> <b1> ... <output...>``.
        Each input pair is passed through unchanged. Each element of
        *output* is converted with ``str()``.

        Args:
            graph: key name of a graph previously stored with ``SetGraph``.
            inputs: sequence of 2-tuples. Judging by the demo, each tuple is
                (tensor key, graph input name); TODO confirm against the
                server-side command.
            output: iterable describing the output. The demo passes
                (result tensor key, graph output name); TODO confirm.

        Returns:
            The reply from ``execute_command``.
        """
        args = [graph, len(inputs)]
        for pair in inputs:
            # Only the first two items of each tuple are sent.
            args.extend((pair[0], pair[1]))
        args.extend(str(item) for item in output)
        return self.__conn.execute_command('TF.RUN', *args)

    def Values(self, tensor):
        """Fetch the values of tensor *tensor* via ``TF.VALUES <tensor>``.

        Returns:
            The reply from ``execute_command``, unparsed.
        """
        return self.__conn.execute_command('TF.VALUES', tensor)
|
|
|
if __name__ == '__main__':
    # Demo: expects a Redis server reachable with default StrictRedis
    # settings that understands the TF.* commands, and a graph file named
    # 'graph.pb' in the current working directory.
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the old print statements were a SyntaxError on Python 3.
    conn = redis.StrictRedis()
    rtf = RedisTF(conn)
    print('Setting the graph: {}'.format(rtf.SetGraph('graph', 'graph.pb')))
    print('Setting tensor t1: {}'.format(rtf.SetTensor('t1', 'FLOAT', [1, 2], [2, 3])))
    print('Setting tensor t2: {}'.format(rtf.SetTensor('t2', 'FLOAT', [1, 2], [2, 3])))
    print('Running the thing: {}'.format(rtf.Run('graph', [('t1', 'a'), ('t2', 'b')], ('t3', 'c'))))
    print('Resulting values: {}'.format(rtf.Values('t3')))