1   /*
2    * SPDX-FileCopyrightText: none
3    * SPDX-License-Identifier: CC0-1.0
4    */
5   
6   package gov.nist.secauto.metaschema.model.testing;
7   
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestWatcher;

import java.lang.management.LockInfo;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
18  
19  /**
20   * A JUnit 5 extension that detects deadlocks and dumps thread information when
21   * tests fail or are aborted (e.g., due to timeout).
22   * <p>
23   * To use this extension, annotate your test class with:
24   *
25   * <pre>
26   * {@literal @}ExtendWith(DeadlockDetectionExtension.class)
27   * </pre>
28   * <p>
29   * Or register it globally via
30   * {@code META-INF/services/org.junit.jupiter.api.extension.Extension}.
31   */
32  public class DeadlockDetectionExtension implements TestWatcher {
33    private static final Logger LOGGER = LogManager.getLogger(DeadlockDetectionExtension.class);
34  
35    @Override
36    public void testAborted(ExtensionContext context, Throwable cause) {
37      if (LOGGER.isErrorEnabled()) {
38        LOGGER.error("Test aborted: {} - {}", context.getDisplayName(), cause.getMessage());
39        dumpThreadInfo(context, "ABORTED");
40      }
41    }
42  
43    @Override
44    public void testFailed(ExtensionContext context, Throwable cause) {
45      // Check if this looks like a timeout
46      if (isTimeoutException(cause)) {
47        if (LOGGER.isErrorEnabled()) {
48          LOGGER.error("Test timed out (possible deadlock): {}", context.getDisplayName());
49        }
50        dumpThreadInfo(context, "TIMEOUT");
51        detectDeadlocks();
52      }
53    }
54  
55    private boolean isTimeoutException(Throwable cause) {
56      if (cause == null) {
57        return false;
58      }
59      String message = cause.getMessage();
60      String className = cause.getClass().getName();
61      return className.contains("Timeout")
62          || (message != null && message.toLowerCase().contains("timed out"))
63          || cause instanceof java.util.concurrent.TimeoutException
64          || isTimeoutException(cause.getCause());
65    }
66  
67    private void dumpThreadInfo(ExtensionContext context, String reason) {
68      StringBuilder sb = new StringBuilder();
69      sb.append("\n");
70      sb.append("=".repeat(80)).append("\n");
71      sb.append("THREAD DUMP - Test ").append(reason).append(": ").append(context.getDisplayName()).append("\n");
72      sb.append("=".repeat(80)).append("\n\n");
73  
74      // Get all thread stack traces
75      Map<Thread, StackTraceElement[]> allStackTraces = Thread.getAllStackTraces();
76  
77      for (Map.Entry<Thread, StackTraceElement[]> entry : allStackTraces.entrySet()) {
78        Thread thread = entry.getKey();
79        StackTraceElement[] stackTrace = entry.getValue();
80  
81        sb.append("Thread: \"").append(thread.getName()).append("\"")
82            .append(" (id=").append(thread.getId()).append(")")
83            .append(" state=").append(thread.getState())
84            .append(" daemon=").append(thread.isDaemon())
85            .append("\n");
86  
87        for (StackTraceElement element : stackTrace) {
88          sb.append("\tat ").append(element).append("\n");
89        }
90        sb.append("\n");
91      }
92  
93      sb.append("=".repeat(80)).append("\n");
94  
95      if (LOGGER.isErrorEnabled()) {
96        LOGGER.error(sb.toString());
97      }
98    }
99  
100   private void detectDeadlocks() {
101     ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
102     long[] deadlockedThreadIds = threadMXBean.findDeadlockedThreads();
103 
104     if (deadlockedThreadIds != null && deadlockedThreadIds.length > 0) {
105       StringBuilder sb = new StringBuilder();
106       sb.append("\n");
107       sb.append("!".repeat(80)).append("\n");
108       sb.append("DEADLOCK DETECTED!\n");
109       sb.append("!".repeat(80)).append("\n\n");
110 
111       ThreadInfo[] threadInfos = threadMXBean.getThreadInfo(deadlockedThreadIds, true, true);
112       for (ThreadInfo threadInfo : threadInfos) {
113         if (threadInfo != null) {
114           sb.append("Deadlocked thread: \"").append(threadInfo.getThreadName()).append("\"\n");
115           sb.append("  State: ").append(threadInfo.getThreadState()).append("\n");
116           sb.append("  Blocked on: ").append(threadInfo.getLockName()).append("\n");
117           sb.append("  Blocked by: ").append(threadInfo.getLockOwnerName()).append("\n");
118           sb.append("  Stack trace:\n");
119           for (StackTraceElement element : threadInfo.getStackTrace()) {
120             sb.append("\t\tat ").append(element).append("\n");
121           }
122           sb.append("\n");
123         }
124       }
125 
126       sb.append("!".repeat(80)).append("\n");
127 
128       if (LOGGER.isErrorEnabled()) {
129         LOGGER.error(sb.toString());
130       }
131     }
132   }
133 
134   @Override
135   public void testDisabled(ExtensionContext context, Optional<String> reason) {
136     // No action needed for disabled tests
137   }
138 
139   @Override
140   public void testSuccessful(ExtensionContext context) {
141     // No action needed for successful tests
142   }
143 }