001/*
002 * SPDX-FileCopyrightText: none
003 * SPDX-License-Identifier: CC0-1.0
004 */
005
006package gov.nist.secauto.metaschema.model.testing;
007
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.api.extension.TestWatcher;

import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
018
019/**
020 * A JUnit 5 extension that detects deadlocks and dumps thread information when
021 * tests fail or are aborted (e.g., due to timeout).
022 * <p>
023 * To use this extension, annotate your test class with:
024 *
025 * <pre>
026 * {@literal @}ExtendWith(DeadlockDetectionExtension.class)
027 * </pre>
028 * <p>
029 * Or register it globally via
030 * {@code META-INF/services/org.junit.jupiter.api.extension.Extension}.
031 */
032public class DeadlockDetectionExtension implements TestWatcher {
033  private static final Logger LOGGER = LogManager.getLogger(DeadlockDetectionExtension.class);
034
035  @Override
036  public void testAborted(ExtensionContext context, Throwable cause) {
037    if (LOGGER.isErrorEnabled()) {
038      LOGGER.error("Test aborted: {} - {}", context.getDisplayName(), cause.getMessage());
039      dumpThreadInfo(context, "ABORTED");
040    }
041  }
042
043  @Override
044  public void testFailed(ExtensionContext context, Throwable cause) {
045    // Check if this looks like a timeout
046    if (isTimeoutException(cause)) {
047      if (LOGGER.isErrorEnabled()) {
048        LOGGER.error("Test timed out (possible deadlock): {}", context.getDisplayName());
049      }
050      dumpThreadInfo(context, "TIMEOUT");
051      detectDeadlocks();
052    }
053  }
054
055  private boolean isTimeoutException(Throwable cause) {
056    if (cause == null) {
057      return false;
058    }
059    String message = cause.getMessage();
060    String className = cause.getClass().getName();
061    return className.contains("Timeout")
062        || (message != null && message.toLowerCase().contains("timed out"))
063        || cause instanceof java.util.concurrent.TimeoutException
064        || isTimeoutException(cause.getCause());
065  }
066
067  private void dumpThreadInfo(ExtensionContext context, String reason) {
068    StringBuilder sb = new StringBuilder();
069    sb.append("\n");
070    sb.append("=".repeat(80)).append("\n");
071    sb.append("THREAD DUMP - Test ").append(reason).append(": ").append(context.getDisplayName()).append("\n");
072    sb.append("=".repeat(80)).append("\n\n");
073
074    // Get all thread stack traces
075    Map<Thread, StackTraceElement[]> allStackTraces = Thread.getAllStackTraces();
076
077    for (Map.Entry<Thread, StackTraceElement[]> entry : allStackTraces.entrySet()) {
078      Thread thread = entry.getKey();
079      StackTraceElement[] stackTrace = entry.getValue();
080
081      sb.append("Thread: \"").append(thread.getName()).append("\"")
082          .append(" (id=").append(thread.getId()).append(")")
083          .append(" state=").append(thread.getState())
084          .append(" daemon=").append(thread.isDaemon())
085          .append("\n");
086
087      for (StackTraceElement element : stackTrace) {
088        sb.append("\tat ").append(element).append("\n");
089      }
090      sb.append("\n");
091    }
092
093    sb.append("=".repeat(80)).append("\n");
094
095    if (LOGGER.isErrorEnabled()) {
096      LOGGER.error(sb.toString());
097    }
098  }
099
100  private void detectDeadlocks() {
101    ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
102    long[] deadlockedThreadIds = threadMXBean.findDeadlockedThreads();
103
104    if (deadlockedThreadIds != null && deadlockedThreadIds.length > 0) {
105      StringBuilder sb = new StringBuilder();
106      sb.append("\n");
107      sb.append("!".repeat(80)).append("\n");
108      sb.append("DEADLOCK DETECTED!\n");
109      sb.append("!".repeat(80)).append("\n\n");
110
111      ThreadInfo[] threadInfos = threadMXBean.getThreadInfo(deadlockedThreadIds, true, true);
112      for (ThreadInfo threadInfo : threadInfos) {
113        if (threadInfo != null) {
114          sb.append("Deadlocked thread: \"").append(threadInfo.getThreadName()).append("\"\n");
115          sb.append("  State: ").append(threadInfo.getThreadState()).append("\n");
116          sb.append("  Blocked on: ").append(threadInfo.getLockName()).append("\n");
117          sb.append("  Blocked by: ").append(threadInfo.getLockOwnerName()).append("\n");
118          sb.append("  Stack trace:\n");
119          for (StackTraceElement element : threadInfo.getStackTrace()) {
120            sb.append("\t\tat ").append(element).append("\n");
121          }
122          sb.append("\n");
123        }
124      }
125
126      sb.append("!".repeat(80)).append("\n");
127
128      if (LOGGER.isErrorEnabled()) {
129        LOGGER.error(sb.toString());
130      }
131    }
132  }
133
134  @Override
135  public void testDisabled(ExtensionContext context, Optional<String> reason) {
136    // No action needed for disabled tests
137  }
138
139  @Override
140  public void testSuccessful(ExtensionContext context) {
141    // No action needed for successful tests
142  }
143}