I would recommend Dask for this sort of work: it makes it easy to build a task graph, and each node in that graph is computed only once, which avoids redundant computation.
import dask
def A(argument):
    print('Running A with', argument)
    return argument

def B(argument_one, argument_two):
    print('Running B with', argument_one, argument_two)
    return argument_one * argument_two
# I produce a mapping between A's input argument and the dask Delayed object
# that will compute its result.
a_results = {
    a_argument: dask.delayed(A)(a_argument)
    for a_argument in range(1, 4)
}
# Build an iterable of the outputs you want and compute them all at once. This ensures
# that no result is computed more than once.
b_results = [
    dask.delayed(B)(a_results[i - 1], a_results[i])
    for i in range(2, 4)
]
print('Results:', dask.compute(*b_results))
Output:
Running A with 1
Running A with 2
Running A with 3
Running B with 1 2
Running B with 2 3
Results: (2, 6)
Result one is A(1) * A(2) = 1 * 2 = 2
Result two is A(2) * A(3) = 2 * 3 = 6
From the printed output, you can see that A was executed only once per input.
Note: because Dask runs the tasks in parallel, your stdout lines may be interleaved or appear in a different order from mine.
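
If you would rather check the "only once per input" claim programmatically than read it off the printed lines, something like the following sketch should work. The `calls` counter is a hypothetical helper I added for illustration, and I am assuming the single-threaded `scheduler="synchronous"` option is acceptable here; it runs every task in the calling thread, so the counter needs no locking and the output is not interleaved.

```python
from collections import Counter

import dask

# Hypothetical helper (not part of Dask): counts how often A runs per argument.
calls = Counter()

def A(argument):
    calls[argument] += 1
    return argument

def B(argument_one, argument_two):
    return argument_one * argument_two

# One Delayed object per input to A, reused by every B that needs it.
a_results = {n: dask.delayed(A)(n) for n in range(1, 4)}
b_results = [
    dask.delayed(B)(a_results[i - 1], a_results[i])
    for i in range(2, 4)
]

# Run everything in the calling thread so the counter is updated safely.
results = dask.compute(*b_results, scheduler="synchronous")

print("Results:", results)
print("Calls to A:", dict(calls))
```

Each argument should show a count of 1, even though A(2) feeds both B tasks. For real workloads you would normally drop the `scheduler` argument again so that Dask can run the tasks in parallel.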