Last active
January 2, 2022 17:25
-
-
Save viniciusrplima/890b01be943e4bf7733b3ab47d2914ac to your computer and use it in GitHub Desktop.
Componente para limpar as tabelas e resetar as sequências de um banco de dados PostgreSQL (mantém os dados do Flyway; só executa em bancos cujo nome termina com "test").
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Baseado em: https://gist.github.com/thiagofa/cff61c709277f48a241c145116b92ec1 | |
import java.sql.*; | |
import java.util.ArrayList; | |
import java.util.List; | |
import javax.sql.DataSource; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import org.springframework.beans.factory.annotation.Autowired; | |
import org.springframework.stereotype.Component; | |
@Component | |
public class DatabaseCleaner { | |
private final Logger logger = LoggerFactory.getLogger(getClass()); | |
@Autowired | |
private DataSource dataSource; | |
private Connection connection; | |
public void clearTablesAndResetSequences() { | |
try (Connection connection = dataSource.getConnection()) { | |
this.connection = connection; | |
checkTestDatabase(); | |
tryToClearTables(); | |
tryToResetSequences(); | |
} catch (SQLException e) { | |
throw new RuntimeException(e); | |
} finally { | |
this.connection = null; | |
} | |
} | |
private void checkTestDatabase() throws SQLException { | |
String catalog = connection.getCatalog(); | |
if (catalog == null || !catalog.endsWith("test")) { | |
throw new RuntimeException( | |
"Cannot clear database tables because '" + catalog + "' is not a test database (suffix 'test' not found)."); | |
} | |
} | |
private void tryToClearTables() throws SQLException { | |
List<String> tableNames = getTableNames(); | |
clear(tableNames); | |
} | |
private void tryToResetSequences() throws SQLException { | |
List<String> sequenceNames = getSequenceNames(); | |
reset(sequenceNames); | |
} | |
private List<String> getTableNames() throws SQLException { | |
List<String> tableNames = new ArrayList<>(); | |
DatabaseMetaData metaData = connection.getMetaData(); | |
ResultSet rs = metaData.getTables(connection.getCatalog(), null, null, new String[] { "TABLE" }); | |
while (rs.next()) { | |
tableNames.add(rs.getString("TABLE_NAME")); | |
} | |
tableNames.remove("flyway_schema_history"); | |
return tableNames; | |
} | |
private List<String> getSequenceNames() throws SQLException { | |
List<String> sequenceNames = new ArrayList<>(); | |
String sql = "SELECT sequencename FROM pg_sequences;"; | |
PreparedStatement statement = connection.prepareStatement(sql); | |
ResultSet rs = statement.executeQuery(); | |
while (rs.next()) { | |
sequenceNames.add(rs.getString("sequencename")); | |
} | |
return sequenceNames; | |
} | |
private void clear(List<String> tableNames) throws SQLException { | |
Statement statement = buildSqlClearTables(tableNames); | |
logger.debug("Executing SQL"); | |
statement.executeBatch(); | |
} | |
private void reset(List<String> sequenceNames) throws SQLException { | |
Statement statement = buildSqlResetIndexes(sequenceNames); | |
logger.debug("Executing SQL"); | |
statement.executeBatch(); | |
} | |
private Statement buildSqlResetIndexes(List<String> sequenceNames) throws SQLException { | |
Statement statement = connection.createStatement(); | |
addRestartStatements(sequenceNames, statement); | |
return statement; | |
} | |
private void addRestartStatements(List<String> sequenceNames, Statement statement) { | |
sequenceNames.forEach(sequenceName -> { | |
try { | |
statement.addBatch(sql("ALTER SEQUENCE " + sequenceName + " RESTART")); | |
} catch (SQLException e) { | |
throw new RuntimeException(e); | |
} | |
}); | |
} | |
private Statement buildSqlClearTables(List<String> tableNames) throws SQLException { | |
Statement statement = connection.createStatement(); | |
statement.addBatch(sql("SET session_replication_role = replica")); | |
addTruncateSatements(tableNames, statement); | |
statement.addBatch(sql("SET session_replication_role = DEFAULT")); | |
return statement; | |
} | |
private void addTruncateSatements(List<String> tableNames, Statement statement) { | |
tableNames.forEach(tableName -> { | |
try { | |
statement.addBatch(sql("TRUNCATE TABLE " + tableName + " CASCADE")); | |
} catch (SQLException e) { | |
throw new RuntimeException(e); | |
} | |
}); | |
} | |
private String sql(String sql) { | |
logger.debug("Adding SQL: {}", sql); | |
return sql; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment