Created January 20, 2013 15:07
-
-
Save wence-/4579273 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Two-rank PyOP2 halo-exchange smoke test.

Must be run on exactly two MPI processes.  Each rank hard-codes its local
piece of a small mesh (a node set, an element set and a 4-nodes-per-element
map) together with the halo lists that tie it to the other rank.  The script
then:

1. counts nodes and elements with a ``Global`` INC reduction and prints the
   result on every rank, and
2. increments every node by the value of each element touching it -- an
   operation that reads element data across the partition boundary, i.e.
   needs a halo exchange of ``ele_dat`` -- and prints the owned node values
   on each rank.
"""
from __future__ import print_function

import sys

from mpi4py import MPI
from pyop2 import op2
import numpy as np

op2.init(backend='sequential')

c = MPI.COMM_WORLD
if c.size != 2:
    # The mesh partition below is written for exactly two ranks.
    # Comm.Abort() defaults to errorcode 0, which can let the launcher report
    # success; abort with a non-zero code and say why.
    sys.stderr.write('This script must be run on exactly 2 MPI processes '
                     '(got %d)\n' % c.size)
    c.Abort(1)

# NOTE(review): the Set size tuples look like cumulative
# (core, owned, exec-halo, non-exec-halo) sizes, and the Halo sends/receives
# look like per-peer-rank index lists (slot 0 = rank 0, slot 1 = rank 1);
# confirm against the PyOP2 version in use.
if c.rank == 0:
    # Nodes 4-7 are sent to rank 1; rank 1's data lands in local nodes 8-11.
    node_halo = op2.Halo(sends=([], [4, 5, 6, 7]),
                         receives=([], [8, 9, 10, 11]),
                         comm=c)
    node_set = op2.Set((8, 8, 8, 12), 'nodes', halo=node_halo)
    # Elements 3 and 4 are sent to rank 1; rank 1's data lands in element 5.
    ele_halo = op2.Halo(sends=([], [3, 4]),
                        receives=([], [5]),
                        comm=c)
    ele_set = op2.Set((3, 5, 6, 6), 'elements', halo=ele_halo)
    # One row of 4 local node indices per local element (6 rows).
    ele_node_map = op2.Map(ele_set, node_set, 4,
                           [0, 1, 4, 5,
                            1, 2, 5, 6,
                            2, 3, 6, 7,
                            5, 6, 9, 10,
                            6, 7, 10, 11,
                            4, 5, 8, 9], 'elements_to_nodes')
elif c.rank == 1:
    # Nodes 0-3 are sent to rank 0; rank 0's data lands in local nodes 8-11.
    node_halo = op2.Halo(sends=([0, 1, 2, 3], []),
                         receives=([8, 9, 10, 11], []),
                         comm=c)
    node_set = op2.Set((8, 8, 8, 12), 'nodes', halo=node_halo)
    # Element 3 is sent to rank 0; rank 0's data lands in elements 4 and 5.
    ele_halo = op2.Halo(sends=([3], []),
                        receives=([4, 5], []),
                        comm=c)
    ele_set = op2.Set((3, 4, 6, 6), 'elements', halo=ele_halo)
    # One row of 4 local node indices per local element (6 rows).
    ele_node_map = op2.Map(ele_set, node_set, 4,
                           [0, 1, 4, 5,
                            1, 2, 5, 6,
                            2, 3, 6, 7,
                            0, 1, 8, 9,
                            1, 2, 9, 10,
                            2, 3, 10, 11], 'elements_to_nodes')

# Increment the single int the argument points to.  Used both to count
# iterations into a Global and to bump each entry of a Dat.
inc = op2.Kernel('void inc(int *x) { ++*x; }', 'inc')
ele_count = op2.Global(1, 0, dtype=np.int32, name='ele_count')
node_count = op2.Global(1, 0, dtype=np.int32, name='node_count')

# Count global number of nodes.  Format the scalar entry rather than the
# whole 1-element array (array-to-int conversion is deprecated in NumPy).
op2.par_loop(inc, node_set, node_count(op2.INC))
print('On rank %s, we think mesh has %d global nodes'
      % (c.rank, node_count.data[0]))

# Count global number of elements.
op2.par_loop(inc, ele_set, ele_count(op2.INC))
print('On rank %s, we think mesh has %d global elements'
      % (c.rank, ele_count.data[0]))

node_dat = op2.Dat(node_set, 1, data=[0] * node_set.total_size,
                   dtype=np.int32, name='node_data')
ele_dat = op2.Dat(ele_set, 1, data=[0] * ele_set.total_size,
                  dtype=np.int32, name='ele_data')

# For each of an element's 4 nodes, add the element's value (*y) to the
# node's value.
ecount = op2.Kernel("""void count(int *x[4], int *y) {
for ( int i = 0; i < 4; i++ ) x[i][0] += *y; }""",
                    'count')

# Force dirtiness of ele_dat by incrementing each value (presumably so its
# halo copies are considered stale before the next read -- confirm).
op2.par_loop(inc, ele_set, ele_dat(op2.IdentityMap, op2.INC))

# Increment each node by the value of the elements touching it.
# This will necessitate a halo exchange of ele_dat.
op2.par_loop(ecount, ele_set, node_dat(ele_node_map, op2.INC),
             ele_dat(op2.IdentityMap, op2.READ))

# Print only the owned node values on each rank.
print(c.rank, node_dat.data_ro[:node_set.size])

op2.exit()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment