Rewrite postgres test container

This commit is contained in:
2025-08-12 13:01:15 +02:00
committed by swordsteel
parent 4bc14b5340
commit 2328e7ebe2
10 changed files with 147 additions and 170 deletions

View File

@@ -1,12 +0,0 @@
package ltd.hlaeja.test.container
import org.springframework.test.context.ContextConfiguration
import org.springframework.test.context.TestExecutionListeners
import org.springframework.test.context.TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS
@Suppress("unused")
@Retention(AnnotationRetention.RUNTIME)
@Target(AnnotationTarget.CLASS)
@ContextConfiguration(initializers = [PostgresInitializer::class])
@TestExecutionListeners(listeners = [PostgresInitializer::class], mergeMode = MERGE_WITH_DEFAULTS)
annotation class PostgresContainer

View File

@@ -1,55 +0,0 @@
package ltd.hlaeja.test.container
import io.github.oshai.kotlinlogging.KotlinLogging
import java.io.BufferedReader
import java.io.InputStream
import java.io.InputStreamReader
import kotlinx.coroutines.runBlocking
import ltd.hlaeja.test.util.isResourceFile
import org.springframework.r2dbc.core.DatabaseClient
import org.springframework.r2dbc.core.await
import org.springframework.stereotype.Component
private val log = KotlinLogging.logger {}
@Component
class PostgresExecutor(
private val databaseClient: DatabaseClient,
) {
fun executeSqlFile(
sqlFile: String,
) = runBlocking {
sqlFile.isResourceFile()
?.inputStream
?.use {
log.debug { "Executing SQL file: $sqlFile" }
executeSqlStatements(makeSqlStatements(it))
}
?: log.debug { "SQL file not found or not readable: $sqlFile" }
}
@Suppress("TooGenericExceptionThrown")
private suspend fun executeSqlStatements(
statements: List<String>,
) = try {
statements.forEach { statement ->
log.debug { "Running statement: $statement" }
databaseClient.sql(statement).await()
}
} catch (e: Exception) {
throw RuntimeException("Failed to execute SQL statements", e)
}
private suspend fun makeSqlStatements(
inputStream: InputStream,
): List<String> = BufferedReader(InputStreamReader(inputStream))
.lines()
.filter { it.isNotEmpty() && !it.startsWith("--") }
.map { it.trim() }
.toList()
.joinToString(" ")
.split(';')
.filter { it.isNotBlank() }
.map { "${it.trim()};" }
}

View File

@@ -1,67 +0,0 @@
package ltd.hlaeja.test.container
import io.github.oshai.kotlinlogging.KotlinLogging
import ltd.hlaeja.test.util.getProperty
import ltd.hlaeja.test.util.isResourceFile
import org.springframework.boot.test.util.TestPropertyValues
import org.springframework.context.ApplicationContextInitializer
import org.springframework.context.ConfigurableApplicationContext
import org.springframework.test.context.TestContext
import org.springframework.test.context.TestExecutionListener
import org.testcontainers.containers.PostgreSQLContainer
private val log = KotlinLogging.logger {}
@Suppress("unused")
class PostgresInitializer : ApplicationContextInitializer<ConfigurableApplicationContext>, TestExecutionListener {
companion object {
const val SCRIPT_INIT = "container.postgres.init"
const val SCRIPT_BEFORE = "container.postgres.before"
const val SCRIPT_AFTER = "container.postgres.after"
const val POSTGRES_VERSION = "container.postgres.version"
const val POSTGRES_LATEST = "postgres:latest"
}
override fun initialize(
context: ConfigurableApplicationContext,
) {
postgres(context).apply {
TestPropertyValues.of(
"spring.r2dbc.url=r2dbc:pool:postgresql://$host:$firstMappedPort/$databaseName",
"spring.r2dbc.username=$username",
"spring.r2dbc.password=$password",
).applyTo(context)
}
}
override fun beforeTestClass(
context: TestContext,
) {
context.testClass
.also { log.debug { "Starting execution before class: ${it.simpleName}" } }
.getAnnotation(PostgresContainer::class.java) ?: return
context.getProperty(SCRIPT_BEFORE)
?.let { context.applicationContext.getBean(PostgresExecutor::class.java).executeSqlFile(it) }
}
override fun afterTestClass(
context: TestContext,
) {
context.testClass
.also { log.debug { "Starting execution after class: ${it.simpleName}" } }
.getAnnotation(PostgresContainer::class.java) ?: return
context.getProperty(SCRIPT_AFTER)
?.let { context.applicationContext.getBean(PostgresExecutor::class.java).executeSqlFile(it) }
}
private fun postgres(
context: ConfigurableApplicationContext,
): PostgreSQLContainer<*> = PostgreSQLContainer(context.getProperty(POSTGRES_VERSION, POSTGRES_LATEST)).apply {
context.getProperty(SCRIPT_INIT)
?.isResourceFile()
?.let { lala -> withInitScript(lala.path) }
?: log.error { "Postgres init script not found" }
start()
}
}

View File

@@ -0,0 +1,15 @@
package ltd.hlaeja.test.container
import ltd.hlaeja.test.container.extension.PostgresTestExtension
import ltd.hlaeja.test.container.postgres.PostgresTestListener
import org.junit.jupiter.api.extension.ExtendWith
import org.springframework.test.context.ContextConfiguration
import org.springframework.test.context.TestExecutionListeners
import org.springframework.test.context.TestExecutionListeners.MergeMode
@Target(AnnotationTarget.CLASS)
@Retention(AnnotationRetention.RUNTIME)
@ExtendWith(PostgresTestExtension::class)
@ContextConfiguration(initializers = [PostgresTestExtension::class])
@TestExecutionListeners(listeners = [PostgresTestListener::class], mergeMode = MergeMode.MERGE_WITH_DEFAULTS)
annotation class PostgresTestContainer

View File

@@ -0,0 +1,21 @@
package ltd.hlaeja.test.container.extension
import ltd.hlaeja.test.container.postgres.TestContainerPostgres
import org.junit.jupiter.api.extension.BeforeAllCallback
import org.junit.jupiter.api.extension.ExtensionContext
import org.springframework.boot.test.util.TestPropertyValues
import org.springframework.context.ApplicationContextInitializer
import org.springframework.context.ConfigurableApplicationContext
class PostgresTestExtension : BeforeAllCallback, ApplicationContextInitializer<ConfigurableApplicationContext> {
override fun initialize(applicationContext: ConfigurableApplicationContext) {
TestPropertyValues.of(TestContainerPostgres.props()).applyTo(applicationContext.environment)
}
override fun beforeAll(context: ExtensionContext) {
if (!TestContainerPostgres.postgres.isRunning) {
TestContainerPostgres.postgres.start()
}
}
}

View File

@@ -0,0 +1,24 @@
package ltd.hlaeja.test.container.postgres
import ltd.hlaeja.test.container.postgres.TestContainerPostgres.sqlFile
import ltd.hlaeja.test.container.util.hasAnnotation
import org.junit.jupiter.api.Nested
import org.springframework.test.context.TestContext
import org.springframework.test.context.TestExecutionListener
class PostgresTestListener : TestExecutionListener {
override fun beforeTestClass(
context: TestContext,
) {
if (context.testClass.hasAnnotation<Nested>()) return
sqlFile("postgres/data.sql", context)
}
override fun afterTestClass(
context: TestContext,
) {
if (context.testClass.hasAnnotation<Nested>()) return
sqlFile("postgres/reset.sql", context)
}
}

View File

@@ -0,0 +1,68 @@
package ltd.hlaeja.test.container.postgres
import java.io.BufferedReader
import java.io.InputStream
import java.io.InputStreamReader
import kotlinx.coroutines.runBlocking
import ltd.hlaeja.test.container.util.isResourceFile
import org.springframework.r2dbc.core.DatabaseClient
import org.springframework.r2dbc.core.await
import org.springframework.test.context.TestContext
import org.testcontainers.containers.PostgreSQLContainer
import org.testcontainers.utility.DockerImageName
object TestContainerPostgres {
val postgres = PostgreSQLContainer(DockerImageName.parse("postgres:17"))
.withReuse(true)
.apply {
withDatabaseName("testdb")
withUsername("test")
withPassword("test")
"postgres/schema.sql".isResourceFile()?.let { withInitScript(it.path) }
}
fun props(): Map<String, String> = postgres.let {
mapOf(
"spring.r2dbc.url" to "r2dbc:postgresql://${it.host}:${it.firstMappedPort}/${it.databaseName}",
"spring.r2dbc.username" to it.username,
"spring.r2dbc.password" to it.password,
)
}
fun sqlFile(
sqlFile: String,
context: TestContext,
): Unit = runBlocking {
sqlFile.isResourceFile()
?.inputStream
?.use {
executeSqlStatements(
makeSqlStatements(it),
context.applicationContext.getBean(DatabaseClient::class.java),
)
}
}
@Suppress("TooGenericExceptionThrown", "SqlSourceToSinkFlow")
private suspend fun executeSqlStatements(
statements: List<String>,
databaseClient: DatabaseClient,
) = try {
statements.forEach { databaseClient.sql(it).await() }
} catch (e: Exception) {
throw RuntimeException("Failed to execute SQL statements", e)
}
private fun makeSqlStatements(
inputStream: InputStream,
): List<String> = BufferedReader(InputStreamReader(inputStream))
.lines()
.filter { it.isNotEmpty() && !it.startsWith("--") }
.map { it.trim() }
.toList()
.joinToString(" ")
.split(';')
.filter { it.isNotBlank() }
.map { "${it.trim()};" }
}

View File

@@ -0,0 +1,12 @@
package ltd.hlaeja.test.container.util
import org.springframework.core.io.ClassPathResource
fun String.isResourceFile(): ClassPathResource? = ClassPathResource(this).let { resource ->
when {
resource.exists() && resource.isReadable -> resource
else -> null
}
}
inline fun <reified T : Annotation> Class<*>.hasAnnotation(): Boolean = getAnnotation(T::class.java) != null

View File

@@ -1,26 +0,0 @@
package ltd.hlaeja.test.util
import org.springframework.context.ConfigurableApplicationContext
import org.springframework.core.io.ClassPathResource
import org.springframework.test.context.TestContext
fun ConfigurableApplicationContext.getProperty(
property: String,
): String? = this.environment.getProperty(property)
fun ConfigurableApplicationContext.getProperty(
property: String,
default: String,
): String = this.environment.getProperty(property, default)
fun TestContext.getProperty(
property: String,
): String? = this.applicationContext.environment.getProperty(property)
fun String.isResourceFile(): ClassPathResource? {
val resource = ClassPathResource(this)
return when {
resource.exists() && resource.isReadable -> resource
else -> null
}
}