2222import org .hibernate .tool .schema .spi .SchemaManagementToolCoordinator ;
2323import org .hibernate .tool .schema .spi .SchemaManagementToolCoordinator .ActionGrouping ;
2424import org .jboss .logging .Logger ;
25+ import org .junit .jupiter .api .extension .AfterAllCallback ;
26+ import org .junit .jupiter .api .extension .AfterEachCallback ;
27+ import org .junit .jupiter .api .extension .BeforeAllCallback ;
2528import org .junit .jupiter .api .extension .BeforeEachCallback ;
2629import org .junit .jupiter .api .extension .ExtensionContext ;
2730import org .junit .jupiter .api .extension .TestExecutionExceptionHandler ;
4346 * @see DomainModelExtension
4447 *
4548 * @author Steve Ebersole
49+ * @author inpink
4650 */
4751public class SessionFactoryExtension
48- implements TestInstancePostProcessor , BeforeEachCallback , TestExecutionExceptionHandler {
52+ implements TestInstancePostProcessor , BeforeAllCallback , BeforeEachCallback ,
53+ AfterEachCallback , AfterAllCallback , TestExecutionExceptionHandler {
4954
5055 private static final Logger log = Logger .getLogger ( SessionFactoryExtension .class );
5156 private static final String SESSION_FACTORY_KEY = SessionFactoryScope .class .getName ();
@@ -74,7 +79,7 @@ public void postProcessTestInstance(Object testInstance, ExtensionContext contex
7479 );
7580
7681 if ( sfAnnRef .isPresent ()
77- || SessionFactoryProducer .class .isAssignableFrom ( context .getRequiredTestClass () ) ) {
82+ || SessionFactoryProducer .class .isAssignableFrom ( context .getRequiredTestClass () ) ) {
7883 final DomainModelScope domainModelScope = DomainModelExtension .getOrCreateDomainModelScope ( testInstance , context );
7984 final SessionFactoryScope created = createSessionFactoryScope ( testInstance , sfAnnRef , domainModelScope , context );
8085 locateExtensionStore ( testInstance , context ).put ( SESSION_FACTORY_KEY , created );
@@ -83,6 +88,9 @@ public void postProcessTestInstance(Object testInstance, ExtensionContext contex
8388
8489 @ Override
8590 public void beforeEach (ExtensionContext context ) {
91+ // Handle BEFORE_EACH drop data timing
92+ handleDropData (context , DropDataTiming .BEFORE_EACH );
93+
8694 final Optional <SessionFactory > sfAnnRef = AnnotationSupport .findAnnotation (
8795 context .getRequiredTestMethod (),
8896 SessionFactory .class
@@ -231,6 +239,42 @@ public void handleTestExecutionException(ExtensionContext context, Throwable thr
231239 throw throwable ;
232240 }
233241
242+ @ Override
243+ public void beforeAll (ExtensionContext context ) throws Exception {
244+ handleDropData (context , DropDataTiming .BEFORE_ALL );
245+ }
246+
247+ @ Override
248+ public void afterEach (ExtensionContext context ) throws Exception {
249+ handleDropData (context , DropDataTiming .AFTER_EACH );
250+ }
251+
252+ @ Override
253+ public void afterAll (ExtensionContext context ) throws Exception {
254+ handleDropData (context , DropDataTiming .AFTER_ALL );
255+ }
256+
257+ private void handleDropData (ExtensionContext context , DropDataTiming timing ) {
258+ try {
259+ final Object testInstance = context .getRequiredTestInstance ();
260+ final SessionFactoryScope scope = findSessionFactoryScope (testInstance , context );
261+
262+ final Optional <SessionFactory > sfAnnRef = AnnotationSupport .findAnnotation (
263+ context .getRequiredTestClass (),
264+ SessionFactory .class
265+ );
266+
267+ if (sfAnnRef .isPresent ()) {
268+ DropDataTiming configuredTiming = sfAnnRef .get ().dropTestData ();
269+ if (configuredTiming == timing ) {
270+ scope .dropData ();
271+ }
272+ }
273+ } catch (Exception e ) {
274+ log .debugf ("Failed to drop data at timing %s: %s" , timing , e .getMessage ());
275+ }
276+ }
277+
234278 private static class SessionFactoryScopeImpl implements SessionFactoryScope , AutoCloseable {
235279 private final DomainModelScope modelScope ;
236280 private final SessionFactoryProducer producer ;
0 commit comments