Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save agibsonccc/09e794e024d49db99e452fc5be460e76 to your computer and use it in GitHub Desktop.
Save agibsonccc/09e794e024d49db99e452fc5be460e76 to your computer and use it in GitHub Desktop.
// Builds a graph containing a single input variable x, forms h = x * x, checks
// the resulting graph structure, then prints the symbolic derivative dh/dx
// (mathematically expected to be 2 * x) and the graph after differentiation.
Graph<NDArrayInformation, OpState> graph = new Graph<>();
ArrayFactory arrayFactory = new ArrayFactory(graph);
DifferentialFunctionFactory<ArrayField> arrayFieldDifferentialFunctionFactory =
        new DifferentialFunctionFactory<>(graph, arrayFactory);
// Metadata for the single input variable "x" with shape [1, 1].
NDArrayInformation xInfo = NDArrayInformation
        .builder()
        .shape(new int[]{1, 1})
        .id("x")
        .build();
// The input vertex is created with index 0; the edge assertions below look up index 0.
NDArrayVertex xVertex = new NDArrayVertex(0, xInfo);
Variable<ArrayField> x =
        arrayFieldDifferentialFunctionFactory.var("x", new ArrayField(xVertex, graph));
// h = x * x (not 2 * x); its derivative with respect to x should be 2 * x.
DifferentialFunction<ArrayField> h = x.mul(x);
// Two vertices are expected: the input x and the result of the multiplication.
assertEquals(2, graph.numVertices());
// For x * x, only one vertex has an entry in the edge map.
assertEquals(1, graph.getEdges().size());
// The vertex with index 0 (x) has two edges. Presumably there is one per operand
// of x * x. TODO confirm against Graph.getEdges() semantics.
assertEquals(2, graph.getEdges().get(0).size());
System.out.println("Pre graph " + graph);
// Print the derivative. The graph is printed again afterwards so it can be compared
// with the "Pre graph" output. NOTE(review): diff() may add gradient vertices; verify.
System.out.println(h.diff(x).getValue());
System.out.println(graph);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment