Skip to content

Instantly share code, notes, and snippets.

@eastlondoner
Last active February 2, 2021 23:37
Show Gist options
Save eastlondoner/63de56bc9b40a3c271c68386759e2547 to your computer and use it in GitHub Desktop.
test
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.junit.platform.commons.logging.Logger;
import org.junit.platform.commons.logging.LoggerFactory;
import java.net.URI;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.UnaryOperator;
import java.util.logging.Level;
import org.neo4j.driver.AuthToken;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Config;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Logging;
import org.neo4j.driver.Record;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.TransactionWork;
import org.neo4j.driver.async.AsyncSession;
import org.neo4j.driver.async.AsyncTransactionWork;
import org.neo4j.driver.async.ResultCursor;
import static java.util.concurrent.TimeUnit.MINUTES;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Concurrency tests for the Neo4j Java driver: sync sessions run sequentially, sync sessions run on
 * executor threads, and async sessions either chained in one session or run in parallel sessions.
 * <p>
 * Every test deliberately creates a new {@link Driver} for each entry in {@link #databases} and asserts
 * that no transaction function is invoked more than once, i.e. nothing is retried. The queries return
 * -2 (slow read), -1 (slow write) and 0 (node creation), so a successful run collects exactly
 * {@value #EXPECTED_DISTINCT_RESULTS} distinct values.
 * <p>
 * The slow queries call {@code apoc.util.sleep}, so the target server needs the APOC procedures installed.
 */
@TestInstance( TestInstance.Lifecycle.PER_CLASS )
public class driversConcurrencyTest
{
    // NOTE(review): YOUR_NEO4J_URL and YOUR_NEO4J_PASSWORD are placeholders that are not defined in this
    // file; replace them (e.g. with string literals) before compiling.
    private static final URI databaseURI = URI.create( YOUR_NEO4J_URL );
    private static final AuthToken authToken = AuthTokens.basic( "neo4j", YOUR_NEO4J_PASSWORD );
    // Each entry gets its own driver and workload. Listing "neo4j" twice runs the workload twice against
    // the same database; presumably a stand-in for two distinct databases - confirm.
    private static final List<String> databases = List.of( "neo4j", "neo4j" );
    private static final String SIMULATED_SLOW_READ = "CALL apoc.util.sleep(200) RETURN -2 AS value";
    private static final String SIMULATED_SLOW_WRITE = "CALL apoc.util.sleep(500) RETURN -1 AS value";
    // Multiplying the id by 0 makes every created node report the same value (0).
    private static final String CREATE_NODE = "CREATE (n) RETURN id(n) * 0";
    /** Number of distinct values a successful run collects: -2, -1 and 0. */
    private static final int EXPECTED_DISTINCT_RESULTS = 3;

    private final Logger log = LoggerFactory.getLogger( this.getClass() );
    /** Shared pool for the threaded and async tests; created in setUp, shut down in tearDown. */
    private ExecutorService executor;

    /**
     * Creates a driver for {@code uri} with the shared credentials and WARNING-level console logging.
     *
     * @param uri the server to connect to
     * @param cb  applies any further settings to the config builder before it is built
     * @return a new driver, owned (and to be closed) by the caller
     */
    private static Driver graphDatabaseDriverWithConfig( URI uri, UnaryOperator<Config.ConfigBuilder> cb )
    {
        Config.ConfigBuilder builder = Config.builder().withLogging( Logging.console( Level.WARNING ) );
        return GraphDatabase.driver( uri, authToken, cb.apply( builder ).build() );
    }

    /** Creates a new driver for {@link #databaseURI} with {@code withEncryption()} applied; the caller must close it. */
    private Driver getNewDriverWithAssertions()
    {
        return graphDatabaseDriverWithConfig( databaseURI, Config.ConfigBuilder::withEncryption );
    }

    /**
     * Verifies the server is reachable, runs one warm-up pass of the sequential workload, then creates
     * the shared work-stealing executor with parallelism twice the number of database entries.
     */
    @BeforeAll
    void setUp()
    {
        try ( var checkDriver = getNewDriverWithAssertions() )
        {
            checkDriver.verifyConnectivity();
            // Run once during setup to warm caches etc.
            syncSessionsNoParalellismTest( -1 );
        }
        executor = Executors.newWorkStealingPool( databases.size() * 2 );
    }

    /** Stops the shared executor and fails if it does not terminate within one minute. */
    @AfterAll
    void tearDown() throws InterruptedException
    {
        if ( executor == null )
        {
            // setUp failed before the executor was created; don't mask that failure with an NPE.
            return;
        }
        executor.shutdownNow();
        assertThat( executor.awaitTermination( 1, MINUTES ) ).isTrue();
    }

    /** Runs the four-query workload sequentially on the calling thread, one driver and session per database entry. */
    @ParameterizedTest
    @ValueSource( ints = {1, 2, 3, 4, 5} )
    void syncSessionsNoParalellismTest( int run )
    {
        log.info( () -> "run " + run );
        Set<Long> results = new HashSet<>();
        for ( var db : databases )
        {
            // This creates a new driver on each pass deliberately
            try ( var driverToUse = getNewDriverWithAssertions() )
            {
                driverToUse.verifyConnectivity();
                try ( var session = driverToUse.session( SessionConfig.forDatabase( db ) ) )
                {
                    runSequentialWorkload( session, results );
                }
            }
        }
        assertThat( results ).hasSize( EXPECTED_DISTINCT_RESULTS );
    }

    /**
     * Per database entry, submits a task that creates a driver and then runs each of the four queries as
     * its own executor task with its own session, so the queries share one driver concurrently.
     */
    @ParameterizedTest
    @ValueSource( ints = {1, 2, 3, 4, 5} )
    void syncSessionPerTransactionUsingThreadsTest( int run )
    {
        log.info( () -> "run " + run );
        Set<Long> results = new ConcurrentSkipListSet<>();
        List<Future<?>> futures = new LinkedList<>();
        for ( var db : databases )
        {
            futures.add( executor.submit( () ->
            {
                // This creates a new driver on each pass deliberately
                try ( var driverToUse = getNewDriverWithAssertions() )
                {
                    driverToUse.verifyConnectivity();
                    List<Future<?>> innerFutures = List.of(
                            submitInOwnSession( driverToUse, db, driversConcurrencyTest::runSimulatedSlowQueryUsingReadTx, results ),
                            submitInOwnSession( driverToUse, db, driversConcurrencyTest::runSimulatedSlowQueryUsingWriteTx, results ),
                            submitInOwnSession( driverToUse, db, driversConcurrencyTest::runSimulatedSlowQueryUsingReadTx, results ),
                            submitInOwnSession( driverToUse, db, driversConcurrencyTest::createNodeWithoutRetries, results ) );
                    // The driver must stay open until every query using it has finished.
                    awaitAllOf( innerFutures );
                }
            } ) );
        }
        awaitAllOf( futures );
        assertThat( results ).hasSize( EXPECTED_DISTINCT_RESULTS );
    }

    /** Runs the sequential four-query workload for each database entry on its own executor task and driver. */
    @ParameterizedTest
    @ValueSource( ints = {1, 2, 3, 4, 5} )
    void syncSessionPerDatabaseParallelismUsingThreadsTest( int run )
    {
        log.info( () -> "run " + run );
        Set<Long> results = new ConcurrentSkipListSet<>();
        List<Future<?>> futures = new LinkedList<>();
        for ( var db : databases )
        {
            futures.add( executor.submit( () ->
            {
                // This creates a new driver on each pass deliberately
                try ( var driverToUse = getNewDriverWithAssertions() )
                {
                    driverToUse.verifyConnectivity();
                    try ( var session = driverToUse.session( SessionConfig.forDatabase( db ) ) )
                    {
                        runSequentialWorkload( session, results );
                    }
                }
            } ) );
        }
        awaitAllOf( futures );
        assertThat( results ).hasSize( EXPECTED_DISTINCT_RESULTS );
    }

    /**
     * Runs slow read, slow write, slow read and node creation one after another in {@code session},
     * adding each returned value to {@code results}.
     */
    private static void runSequentialWorkload( Session session, Set<Long> results )
    {
        results.add( runSimulatedSlowQueryUsingReadTx( session ) );
        results.add( runSimulatedSlowQueryUsingWriteTx( session ) );
        results.add( runSimulatedSlowQueryUsingReadTx( session ) );
        results.add( createNodeWithoutRetries( session ) );
    }

    /**
     * Submits a task that opens its own session on {@code driver} for {@code db}, applies {@code work}
     * to it, and adds the value to {@code results}. The session is closed when the task ends.
     */
    private Future<?> submitInOwnSession( Driver driver, String db, Function<Session,Long> work, Set<Long> results )
    {
        return executor.submit( () ->
        {
            try ( var session = driver.session( SessionConfig.forDatabase( db ) ) )
            {
                results.add( work.apply( session ) );
            }
        } );
    }

    /** Runs {@link #SIMULATED_SLOW_WRITE} in a write transaction that must not be retried; returns its "value" column. */
    private static Long runSimulatedSlowQueryUsingWriteTx( Session session )
    {
        return writeWithoutRetries( session, tx -> tx.run( SIMULATED_SLOW_WRITE ).single().get( "value" ).asLong() );
    }

    /** Runs {@link #SIMULATED_SLOW_READ} in a read transaction that must not be retried; returns its "value" column. */
    private static Long runSimulatedSlowQueryUsingReadTx( Session session )
    {
        return readWithoutRetries( session, tx -> tx.run( SIMULATED_SLOW_READ ).single().get( "value" ).asLong() );
    }

    /** Runs {@link #CREATE_NODE} in a write transaction that must not be retried; returns its first column. */
    private static Long createNodeWithoutRetries( Session session )
    {
        return writeWithoutRetries( session, tx -> tx.run( CREATE_NODE ).single().get( 0 ).asLong() );
    }

    /** Executes {@code work} via {@code session.readTransaction}, failing if the driver invokes it more than once. */
    private static <T> T readWithoutRetries( Session session, TransactionWork<T> work )
    {
        return runOnceWithoutRetries( session::readTransaction, work );
    }

    /** Executes {@code work} via {@code session.writeTransaction}, failing if the driver invokes it more than once. */
    private static <T> T writeWithoutRetries( Session session, TransactionWork<T> work )
    {
        return runOnceWithoutRetries( session::writeTransaction, work );
    }

    /**
     * Passes {@code work} to {@code transactionFunction} and asserts it was invoked exactly once.
     *
     * @param transactionFunction a session's read/write transaction-function entry point
     * @param work                the transaction body
     * @return whatever {@code work} returned
     */
    private static <T> T runOnceWithoutRetries( Function<TransactionWork<T>,T> transactionFunction, TransactionWork<T> work )
    {
        var attempts = new AtomicInteger();
        T result = transactionFunction.apply( tx ->
        {
            // A second invocation means the driver retried the work, which these tests treat as a failure.
            assertThat( attempts.incrementAndGet() ).isEqualTo( 1 );
            return work.execute( tx );
        } );
        assertThat( attempts.get() ).isEqualTo( 1 );
        return result;
    }

    /**
     * Per database entry, asynchronously creates a driver and runs the four queries chained one after
     * another in a single async session. The session and then the driver are closed on every outcome.
     */
    @ParameterizedTest
    @ValueSource( ints = {1, 2, 3, 4, 5} )
    void asyncSessionPerDatabaseChainedTransactionsTest( int run )
    {
        log.info( () -> "run " + run );
        Set<Long> results = new ConcurrentSkipListSet<>();
        List<CompletableFuture<Record>> promises = new LinkedList<>();
        for ( var db : databases )
        {
            final SessionConfig sessionConfig = SessionConfig.forDatabase( db );
            // This creates a new driver on each pass deliberately
            promises.add(
                    CompletableFuture
                            .supplyAsync( this::getNewDriverWithAssertions, executor )
                            .thenCompose( driverToUse ->
                                    driverToUse.verifyConnectivityAsync()
                                            .thenCompose( ignored -> runChainedTransactions( driverToUse.asyncSession( sessionConfig ), results ) )
                                            // Attached to the whole chain so the driver is closed even if
                                            // connectivity verification or session creation fails.
                                            .handle( closeDriverAndRethrowExceptions( driverToUse ) )
                                            .thenCompose( i -> i ) ) );
        }
        CompletableFuture<?>[] cfs = promises.toArray( CompletableFuture[]::new );
        assertThat( CompletableFuture.allOf( cfs ) )
                .succeedsWithin( 5, TimeUnit.MINUTES );
        assertThat( results ).hasSize( EXPECTED_DISTINCT_RESULTS );
    }

    /**
     * Runs slow read, slow write, slow read and node creation one after another in {@code session},
     * adding each value to {@code results}. Then closes the session whether or not the chain failed.
     *
     * @return the record produced by the node-creation query
     */
    private static CompletionStage<Record> runChainedTransactions( AsyncSession session, Set<Long> results )
    {
        CompletionStage<Record> lastRecord =
                session.readTransactionAsync( singleRecordAndCommit( SIMULATED_SLOW_READ ) )
                        .thenApply( row -> results.add( row.get( "value" ).asLong() ) )
                        .thenCompose( ignored -> session.writeTransactionAsync( singleRecordAndCommit( SIMULATED_SLOW_WRITE ) ) )
                        .thenApply( row -> results.add( row.get( "value" ).asLong() ) )
                        .thenCompose( ignored -> session.readTransactionAsync( singleRecordAndCommit( SIMULATED_SLOW_READ ) ) )
                        .thenApply( row -> results.add( row.get( "value" ).asLong() ) )
                        .thenCompose( ignored -> session.writeTransactionAsync( singleRecordAndCommit( CREATE_NODE ) ) )
                        .thenApply( row ->
                        {
                            results.add( row.get( 0 ).asLong() );
                            return row;
                        } );
        return lastRecord.handle( closeSessionAndRethrowExceptions( session ) ).thenCompose( i -> i );
    }

    /**
     * Builds an async transaction body that runs {@code query}, takes its single record, explicitly
     * commits, and then yields that record.
     */
    private static AsyncTransactionWork<CompletionStage<Record>> singleRecordAndCommit( String query )
    {
        return tx -> tx.runAsync( query )
                .thenCompose( ResultCursor::singleAsync )
                .thenCompose( row -> tx.commitAsync().thenApply( ignored -> row ) );
    }

    /**
     * Per database entry, asynchronously creates a driver and runs the four queries concurrently, each
     * in its own async session. The driver is closed, and its close awaited, once they all finish or any fails.
     */
    @ParameterizedTest
    @ValueSource( ints = {1, 2, 3, 4, 5} )
    void asyncSessionPerTransactionTest( int run )
    {
        log.info( () -> "run " + run );
        Set<Long> results = new ConcurrentSkipListSet<>();
        List<CompletableFuture<?>> promises = new LinkedList<>();
        for ( var db : databases )
        {
            final SessionConfig sessionConfig = SessionConfig.forDatabase( db );
            // This creates a new driver on each pass deliberately
            promises.add(
                    CompletableFuture
                            .supplyAsync( this::getNewDriverWithAssertions, executor )
                            .thenCompose( driverToUse ->
                                    driverToUse.verifyConnectivityAsync()
                                            .thenCompose( ignored -> runEachQueryInOwnSession( driverToUse, sessionConfig, results ) )
                                            .handle( closeDriverAndRethrowExceptions( driverToUse ) )
                                            // Unwrap the nested stage: without this the promise would complete
                                            // before the driver closes and any query failure would be lost.
                                            .thenCompose( i -> i ) ) );
        }
        assertThat( CompletableFuture.allOf( promises.toArray( CompletableFuture[]::new ) ) )
                .succeedsWithin( 5, TimeUnit.MINUTES );
        assertThat( results ).hasSize( EXPECTED_DISTINCT_RESULTS );
    }

    /**
     * Starts slow read, slow write, slow read and node creation concurrently on {@code driver}. Each runs
     * in its own session and adds its value to {@code results}. Completes when all four have completed.
     */
    private CompletableFuture<Void> runEachQueryInOwnSession( Driver driver, SessionConfig sessionConfig, Set<Long> results )
    {
        var r1 = runInSingleSessionAsReadWithoutRetries( driver, sessionConfig, SIMULATED_SLOW_READ, ResultCursor::singleAsync )
                .thenApply( row -> results.add( row.get( "value" ).asLong() ) );
        var w1 = runInSingleSessionAsWriteWithoutRetries( driver, sessionConfig, SIMULATED_SLOW_WRITE, ResultCursor::singleAsync )
                .thenApply( row -> results.add( row.get( "value" ).asLong() ) );
        var r2 = runInSingleSessionAsReadWithoutRetries( driver, sessionConfig, SIMULATED_SLOW_READ, ResultCursor::singleAsync )
                .thenApply( row -> results.add( row.get( "value" ).asLong() ) );
        var w2 = runInSingleSessionAsWriteWithoutRetries( driver, sessionConfig, CREATE_NODE, ResultCursor::singleAsync )
                .thenApply( row -> results.add( row.get( 0 ).asLong() ) );
        return CompletableFuture.allOf( r1.toCompletableFuture(), r2.toCompletableFuture(),
                w1.toCompletableFuture(), w2.toCompletableFuture() );
    }

    /** Runs {@code query} in a new async session's write transaction; see {@link #runWithoutRetries}. */
    private <U> CompletionStage<U> runInSingleSessionAsWriteWithoutRetries(
            Driver driver, SessionConfig sessionConfig, String query, Function<ResultCursor,? extends CompletionStage<U>> cursorFunction )
    {
        return runWithoutRetries( driver, sessionConfig, true, query, cursorFunction );
    }

    /** Runs {@code query} in a new async session's read transaction; see {@link #runWithoutRetries}. */
    private <U> CompletionStage<U> runInSingleSessionAsReadWithoutRetries(
            Driver driver, SessionConfig sessionConfig, String query, Function<ResultCursor,? extends CompletionStage<U>> cursorFunction )
    {
        return runWithoutRetries( driver, sessionConfig, false, query, cursorFunction );
    }

    /**
     * Opens a new async session on {@code driver} and runs {@code query} in one transaction function.
     * The body maps the cursor through {@code cursorFunction} and commits explicitly. It fails if the
     * driver invokes the body more than once. The session is closed on every outcome.
     *
     * @param isWrite {@code true} for a write transaction, {@code false} for a read transaction
     * @return the value produced by {@code cursorFunction}
     */
    private <U> CompletionStage<U> runWithoutRetries(
            Driver driver, SessionConfig sessionConfig, boolean isWrite,
            String query, Function<ResultCursor,? extends CompletionStage<U>> cursorFunction )
    {
        final var retryCounter = new AtomicInteger();
        final var asyncSession = driver.asyncSession( sessionConfig );
        final Function<AsyncTransactionWork<CompletionStage<U>>,CompletionStage<U>> fn =
                isWrite ? asyncSession::writeTransactionAsync : asyncSession::readTransactionAsync;
        return fn.apply( tx ->
        {
            // A second invocation means the driver retried the work.
            assertThat( retryCounter.incrementAndGet() ).isEqualTo( 1 );
            return tx.runAsync( query )
                    .thenCompose( cursorFunction )
                    .thenCompose( r -> tx.commitAsync().thenApply( ignored -> r ) );
        } ).handle( closeSessionAndRethrowExceptions( asyncSession ) ).thenCompose( i -> i );
    }

    /**
     * Returns a {@code handle} callback that closes {@code driver} once the preceding stage finishes. The
     * returned stage then completes with that stage's result or fails with its exception. Compose it
     * with {@code thenCompose(i -> i)} so callers wait for the close.
     */
    private static <U> BiFunction<U,Throwable,CompletionStage<U>> closeDriverAndRethrowExceptions( Driver driver )
    {
        return ( originalResult, originalException ) ->
                driver.closeAsync()
                        .handle( ( ignored, closeException ) -> resultOrRethrow( originalResult, originalException, closeException ) );
    }

    /**
     * Returns a {@code handle} callback that closes {@code session} once the preceding stage finishes. The
     * returned stage then completes with that stage's result or fails with its exception. Compose it
     * with {@code thenCompose(i -> i)} so callers wait for the close.
     */
    private static <U> BiFunction<U,Throwable,CompletionStage<U>> closeSessionAndRethrowExceptions( AsyncSession session )
    {
        return ( originalResult, originalException ) ->
                session.closeAsync()
                        .handle( ( ignored, closeException ) -> resultOrRethrow( originalResult, originalException, closeException ) );
    }

    /**
     * Returns {@code originalResult} if neither failure is present; otherwise throws, wrapped in a
     * {@link RuntimeException}. As with try-with-resources, the original failure stays primary and a
     * close failure is attached to it as suppressed.
     */
    private static <U> U resultOrRethrow( U originalResult, Throwable originalException, Throwable closeException )
    {
        if ( originalException != null )
        {
            // Guard against self-suppression, which addSuppressed rejects with IllegalArgumentException.
            if ( closeException != null && closeException != originalException )
            {
                originalException.addSuppressed( closeException );
            }
            throw new RuntimeException( originalException );
        }
        if ( closeException != null )
        {
            throw new RuntimeException( closeException );
        }
        return originalResult;
    }

    /**
     * Waits up to one minute for each future in turn. If any wait fails, throws one exception carrying
     * every failure as suppressed. On interruption, restores the interrupt flag and stops waiting.
     */
    private static void awaitAllOf( List<Future<?>> futures )
    {
        RuntimeException collectedException = new RuntimeException( "error in async commands" );
        for ( var future : futures )
        {
            try
            {
                future.get( 1, TimeUnit.MINUTES );
            }
            catch ( InterruptedException e )
            {
                // Preserve the interrupt for callers; further get() calls would only fail the same way.
                Thread.currentThread().interrupt();
                collectedException.addSuppressed( e );
                break;
            }
            catch ( Throwable t )
            {
                collectedException.addSuppressed( t );
            }
        }
        if ( collectedException.getSuppressed().length > 0 )
        {
            throw collectedException;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment