RecursionCollectingNodeItemVisitor.java

/*
 * SPDX-FileCopyrightText: none
 * SPDX-License-Identifier: CC0-1.0
 */

package gov.nist.secauto.metaschema.core.metapath.item.node;

import gov.nist.secauto.metaschema.core.model.IAssemblyDefinition;
import gov.nist.secauto.metaschema.core.model.IModule;
import gov.nist.secauto.metaschema.core.util.ObjectUtils;

import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import edu.umd.cs.findbugs.annotations.NonNull;

public class RecursionCollectingNodeItemVisitor
    extends AbstractRecursionPreventingNodeItemVisitor<Void, Void> {

  @SuppressWarnings("PMD.UseConcurrentHashMap")
  @NonNull
  private final Map<IAssemblyDefinition, AssemblyRecord> assemblyAnalysis = new LinkedHashMap<>();

  /**
   * Get the identified assembly definitions that recurse.
   *
   * @return the definitions that recurse
   * @see AssemblyRecord#isRecursive()
   */
  @NonNull
  public Set<AssemblyRecord> getRecursiveAssemblyDefinitions() {
    return ObjectUtils.notNull(assemblyAnalysis.values().stream()
        .filter(AssemblyRecord::isRecursive)
        .collect(Collectors.toSet()));
  }

  /**
   * Visit the provided module.
   *
   * @param module
   *          the Metaschema module to visit
   */
  public void visit(@NonNull IModule module) {
    visitMetaschema(INodeItemFactory.instance().newModuleNodeItem(module), null);
  }

  @Override
  public Void visitAssembly(IAssemblyNodeItem item, Void context) {
    IAssemblyDefinition definition = item.getDefinition();

    // get the assembly record from the cache
    AssemblyRecord record = assemblyAnalysis.get(definition);
    if (record == null) {
      record = new AssemblyRecord(definition);
      assemblyAnalysis.put(definition, record);
    } else if (isDecendant(item, definition)) {
      record.markRecursive();
      record.addLocation(item);
    }
    return super.visitAssembly(item, context);
  }

  @Override
  public Void visitAssembly(IAssemblyInstanceGroupedNodeItem item, Void context) {
    return visitAssembly((IAssemblyNodeItem) item, context);
  }

  @Override
  protected Void defaultResult() {
    return null;
  }

  public static final class AssemblyRecord {
    @NonNull
    private final IAssemblyDefinition definition;
    private boolean recursive; // false
    @NonNull
    private final List<IDefinitionNodeItem<?, ?>> locations = new LinkedList<>();

    private AssemblyRecord(@NonNull IAssemblyDefinition definition) {
      this.definition = definition;
    }

    /**
     * Get the definition associated with the record.
     *
     * @return the definition
     */
    @NonNull
    public IAssemblyDefinition getDefinition() {
      return definition;
    }

    /**
     * Determine if the definition associated with the record is a descendant of
     * itself.
     *
     * @return {@code true} if the definition is a descendant of itself or
     *         {@code false} otherwise
     */
    public boolean isRecursive() {
      return recursive;
    }

    /**
     * Mark the record as recursive.
     *
     * @see #isRecursive()
     */
    private void markRecursive() {
      recursive = true;
    }

    /**
     * Get the node locations where the definition associated with this record is
     * used.
     *
     * @return the node locations
     */
    @NonNull
    public List<IDefinitionNodeItem<?, ?>> getLocations() {
      return locations;
    }

    /**
     * Associate the provided location with the definition associated with the
     * record.
     *
     * @param location
     *          the location to associate
     */
    public void addLocation(@NonNull IDefinitionNodeItem<?, ?> location) {
      this.locations.add(location);
    }
  }

}