NvidiaGPUPluginForRuntimeV2.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * <p>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.com.nvidia;

import org.apache.hadoop.classification.VisibleForTesting;
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableSet;
import org.apache.hadoop.util.Shell;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;
import java.util.TreeSet;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
 * non-Docker container.
 * It has topology aware as well as simple scheduling ability.
 * */
/**
 * Nvidia GPU plugin supporting both Nvidia container runtime v2 for Docker and
 * non-Docker containers.
 * It provides topology-aware scheduling (driven by the output of
 * {@code nvidia-smi topo -m}) as well as a simple scheduling fallback.
 * */
public class NvidiaGPUPluginForRuntimeV2 implements DevicePlugin,
    DevicePluginScheduler {
  public static final Logger LOG = LoggerFactory.getLogger(
      NvidiaGPUPluginForRuntimeV2.class);

  public static final String NV_RESOURCE_NAME = "nvidia.com/gpu";

  private NvidiaCommandExecutor shellExecutor = new NvidiaCommandExecutor();

  // Environment passed to the nvidia-smi invocations.
  private final Map<String, String> environment = new HashMap<>();

  // If this environment variable is set, its value is used as the binary path.
  private static final String ENV_BINARY_PATH = "NVIDIA_SMI_PATH";

  private static final String DEFAULT_BINARY_NAME = "nvidia-smi";

  private static final String DEV_NAME_PREFIX = "nvidia";

  private String pathOfGpuBinary = null;

  // command should not run more than 10 sec.
  private static final int MAX_EXEC_TIMEOUT_MS = 10 * 1000;

  // When the executable path is not set, search these default dirs:
  // /usr/bin, /bin, and /usr/local/nvidia/bin (the latter when launched by
  // nvidia-docker).
  private static final Set<String> DEFAULT_BINARY_SEARCH_DIRS = ImmutableSet.of(
      "/usr/bin", "/bin", "/usr/local/nvidia/bin");

  // Extracts the GPU index from a topology row label such as "GPU3".
  // find() (rather than matches()) tolerates extra characters around it.
  private static final Pattern GPU_ROW_LABEL = Pattern.compile("GPU(\\d+)");

  // Matches an NVLink topology cell such as "NV2" (a bonded set of 2 links).
  private static final Pattern NVLINK_CELL = Pattern.compile("NV(\\d{1,3})");

  // NVLink link types indexed by (bonded link count - 1).
  private static final DeviceLinkType[] NVLINK_TYPES = {
      DeviceLinkType.P2PLinkNVLink1,
      DeviceLinkType.P2PLinkNVLink2,
      DeviceLinkType.P2PLinkNVLink3,
      DeviceLinkType.P2PLinkNVLink4,
      DeviceLinkType.P2PLinkNVLink5,
      DeviceLinkType.P2PLinkNVLink6,
      DeviceLinkType.P2PLinkNVLink7,
      DeviceLinkType.P2PLinkNVLink8,
      DeviceLinkType.P2PLinkNVLink9
  };

  private boolean topoInitialized = false;

  // Devices returned by the last successful getDevices() call. Used to build
  // the cost table for topology scheduling.
  private Set<Device> lastTimeFoundDevices;

  /**
   * It caches the combination of different devices and the communication cost.
   * The key is device count.
   * The value is an ordered list of map entries whose key is a device
   * combination and whose value is its cost. The list is sorted by cost in
   * ascending order.
   * For instance:
   * { 2 => [[device1,device2]=>0, [device1,device3]=>10],
   *   3 => [[device1,device2,device3]=>10, [device2,device3,device5]=>20],
   * }
   * */
  private final Map<Integer, List<Map.Entry<Set<Device>, Integer>>> costTable
      = new HashMap<>();

  /**
   * The key is a pair of minors. For instance, "0-1" indicates 0 to 1.
   * The value is the weight between the two devices.
   * */
  private final Map<String, Integer> devicePairToWeight = new HashMap<>();

  /**
   * The container can set this environment variable to tell the scheduler
   * which policy to use when scheduling.
   * */
  public static final String TOPOLOGY_POLICY_ENV_KEY = "NVIDIA_TOPO_POLICY";

  /**
   * Schedule policy that prefers faster GPU-GPU communication.
   * Generally suitable for heavy GPU computation workloads.
   * */
  public static final String TOPOLOGY_POLICY_PACK = "PACK";

  /**
   * Schedule policy that prefers faster CPU-GPU communication.
   * Generally suitable for heavy CPU-GPU IO operations.
   * */
  public static final String TOPOLOGY_POLICY_SPREAD = "SPREAD";

  @Override
  public DeviceRegisterRequest getRegisterRequestInfo() throws Exception {
    return DeviceRegisterRequest.Builder.newInstance()
        .setResourceName(NV_RESOURCE_NAME).build();
  }

  /**
   * Discover GPUs by running
   * {@code nvidia-smi --query-gpu=index,pci.bus_id --format=csv,noheader}.
   * Each output line must be "index, busId". The index is used as the minor
   * number, and the major number is read from {@code /dev/nvidia<index>}.
   * Devices whose major number cannot be determined are skipped.
   * The result is cached for topology scheduling.
   *
   * @return the discovered devices
   * @throws YarnException if nvidia-smi could not be executed
   * @throws Exception if no binary is found or the output is malformed
   */
  @Override
  public Set<Device> getDevices() throws Exception {
    shellExecutor.searchBinary();
    TreeSet<Device> r = new TreeSet<>();
    String output;
    try {
      output = shellExecutor.getDeviceInfo();
    } catch (IOException e) {
      LOG.debug("Failed to get output from {}", pathOfGpuBinary, e);
      throw new YarnException(e);
    }
    String[] lines = output.trim().split("\n");
    int id = 0;
    for (String oneLine : lines) {
      String[] tokensEachLine = oneLine.split(",");
      if (tokensEachLine.length != 2) {
        throw new Exception("Cannot parse the output to get device info. "
            + "Unexpected format in it:" + oneLine);
      }
      String minorNumber = tokensEachLine[0].trim();
      String busId = tokensEachLine[1].trim();
      String majorNumber = getMajorNumber(DEV_NAME_PREFIX + minorNumber);
      if (majorNumber != null) {
        r.add(Device.Builder.newInstance()
            .setId(id)
            .setMajorNumber(Integer.parseInt(majorNumber))
            .setMinorNumber(Integer.parseInt(minorNumber))
            .setBusID(busId)
            .setDevPath("/dev/" + DEV_NAME_PREFIX + minorNumber)
            .setHealthy(true)
            .build());
        id++;
      }
    }
    // cache it, which helps topology scheduling
    lastTimeFoundDevices = r;
    return r;
  }

  /**
   * Build the runtime spec for the allocated devices.
   * For the Docker runtime this selects the "nvidia" container runtime and
   * sets NVIDIA_VISIBLE_DEVICES to the comma-separated minor numbers. An
   * empty allocation yields an empty value instead of failing.
   * For other runtimes, returns {@code null}.
   */
  @Override
  public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
      YarnRuntimeType yarnRuntime) throws Exception {
    LOG.debug("Generating runtime spec for allocated devices: {}, {}",
        allocatedDevices, yarnRuntime.getName());
    if (yarnRuntime != YarnRuntimeType.RUNTIME_DOCKER) {
      return null;
    }
    // StringJoiner avoids trimming a trailing separator, which broke on an
    // empty allocation.
    StringJoiner minorNumbers = new StringJoiner(",");
    for (Device device : allocatedDevices) {
      minorNumbers.add(String.valueOf(device.getMinorNumber()));
    }
    LOG.info("Nvidia Docker v2 assigned GPU: {}", minorNumbers);
    return DeviceRuntimeSpec.Builder.newInstance()
        .addEnv("NVIDIA_VISIBLE_DEVICES", minorNumbers.toString())
        .setContainerRuntime("nvidia")
        .build();
  }

  @Override
  public void onDevicesReleased(Set<Device> releasedDevices) throws Exception {
    // do nothing
  }

  /**
   * Get the decimal major number of {@code /dev/<devName>}, which stat
   * reports as "major:minor" in hex.
   *
   * @return the major number as a decimal string, or {@code null} if it
   *         could not be read or parsed
   */
  private String getMajorNumber(String devName) {
    try {
      LOG.debug("Get major numbers from /dev/{}", devName);
      String statOutput = shellExecutor.getMajorMinorInfo(devName);
      LOG.debug("stat output:{}", statOutput);
      String majorHex = statOutput.trim().split(":")[0];
      return Integer.toString(Integer.parseInt(majorHex, 16));
    } catch (IOException e) {
      LOG.warn("Failed to get major number from reading /dev/{}", devName, e);
    } catch (NumberFormatException e) {
      LOG.error("Failed to parse device major number from stat output", e);
    }
    return null;
  }

  /**
   * Choose {@code count} devices out of {@code availableDevices}.
   * Topology-aware scheduling is attempted when it can make a difference.
   * On any failure it falls back to basic scheduling.
   */
  @Override
  public Set<Device> allocateDevices(Set<Device> availableDevices, int count,
      Map<String, String> envs) {
    Set<Device> allocation = new TreeSet<>();
    if (count <= 0) {
      return allocation;
    }
    /*
     * Corner cases where topology scheduling is not needed:
     * allocating 1 device or all devices, or fewer than 3 devices available.
     */
    if (availableDevices.size() < 3
        || count == 1
        || availableDevices.size() == count) {
      basicSchedule(allocation, count, availableDevices);
      return allocation;
    }

    try {
      if (!topoInitialized) {
        initCostTable();
      }
      // topology aware scheduling
      topologyAwareSchedule(allocation, count,
          envs, availableDevices, this.costTable);
      if (allocation.size() == count) {
        return allocation;
      }
      LOG.error("Failed to do topology scheduling. Skip to use basic "
          + "scheduling");
    } catch (IOException | RuntimeException e) {
      // Topology scheduling is best effort: malformed topology output or an
      // incomplete cost table must not fail the allocation.
      LOG.error("Error in getting GPU topology info. "
          + "Skip topology aware scheduling", e);
    }
    // basic scheduling, starting from a clean slate
    allocation.clear();
    basicSchedule(allocation, count, availableDevices);
    return allocation;
  }

  /**
   * Parse the GPU topology and build the cost table of device combinations.
   * If the devices cannot be discovered, the table is left unbuilt and
   * {@link #isTopoInitialized()} stays false.
   *
   * @throws IOException if the topology could not be read
   */
  @VisibleForTesting
  public void initCostTable() throws IOException {
    // get topology
    String topo = shellExecutor.getTopologyInfo();
    // build the graph
    parseTopo(topo, devicePairToWeight);
    // build the cost table of different device combinations
    if (lastTimeFoundDevices == null) {
      try {
        getDevices();
      } catch (Exception e) {
        LOG.error("Failed to get devices!", e);
        return;
      }
    }
    buildCostTable(costTable, lastTimeFoundDevices);
    loggingCostTable(costTable);
    this.topoInitialized = true;
  }

  // Dump the cost table at debug level.
  private void loggingCostTable(
      Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
    if (LOG.isDebugEnabled()) {
      StringBuilder sb = new StringBuilder("The costTable is:");
      sb.append("\n{");
      for (Map.Entry<Integer, List<Map.Entry<Set<Device>, Integer>>> entry
          : cTable.entrySet()) {
        sb.append("\n\t")
            .append(entry.getKey())
            .append(" => [");
        for (Map.Entry<Set<Device>, Integer> e : entry.getValue()) {
          sb.append("\n\t\t").append(e.toString()).append(",\n");
        }
        sb.append("\t\t]\n");
      }
      sb.append("}\n");
      LOG.debug(sb.toString());
    }
  }

  /**
   * Generate all combinations of devices and their cost into {@code cTable}.
   * */
  private void buildCostTable(
      Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable,
      Set<Device> ltfDevices) {
    Device[] deviceList = new Device[ltfDevices.size()];
    ltfDevices.toArray(deviceList);
    generateAllDeviceCombination(cTable, deviceList, deviceList.length);
  }

  /**
   * For every combination size i in [2, n-1], generate a map from each
   * combination to its cost. Store the entries sorted by ascending cost
   * under key i.
   */
  private void generateAllDeviceCombination(
      Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable,
      Device[] allDevices, int n) {
    // Sizes 1 and n are never topology-scheduled, so only 2..n-1 are built.
    for (int i = 2; i < n; i++) {
      Map<Set<Device>, Integer> combinationToCost =
          new HashMap<>();
      buildCombination(combinationToCost, allDevices, n, i);
      // sort the map entries by cost in ascending order (stable sort)
      List<Map.Entry<Set<Device>, Integer>> listSortedByCost =
          new LinkedList<>(combinationToCost.entrySet());
      Collections.sort(listSortedByCost, Map.Entry.comparingByValue());
      cTable.put(i, listSortedByCost);
    }
  }

  // Populate combinationToCost with every r-sized subset of the first n
  // devices.
  private void buildCombination(Map<Set<Device>, Integer> combinationToCost,
      Device[] allDevices, int n, int r) {
    // A temporary array to store each combination one by one
    Device[] subDeviceList = new Device[r];
    combinationRecursive(combinationToCost, allDevices, subDeviceList,
        0, n - 1, 0, r);
  }

  /**
   * Populate combination to cost map recursively.
   *
   * @param cTc           combinationToCost map.
   *                      The key is device set, the value is cost
   * @param allDevices    all devices used to assign value to subDeviceList
   * @param subDeviceList store a subset of devices temporarily
   * @param start         start index in the allDevices
   * @param end           last index in the allDevices
   * @param index         dynamic index in subDeviceList need to be assigned
   * @param r             the length of the subDeviceList
   */
  void combinationRecursive(Map<Set<Device>, Integer> cTc,
      Device[] allDevices, Device[] subDeviceList,
      int start, int end, int index, int r) {
    // sub device list's length is ready to compute the cost
    if (index == r) {
      Set<Device> oneSet = new TreeSet<>(Arrays.asList(subDeviceList));
      int cost = computeCostOfDevices(subDeviceList);
      cTc.put(oneSet, cost);
      return;
    }
    for (int i = start; i <= end; i++) {
      subDeviceList[index] = allDevices[i];
      combinationRecursive(cTc, allDevices, subDeviceList,
          i + 1, end, index + 1, r);
    }
  }

  /**
   * The cost function used to calculate the cost of a subset of devices.
   * It sums the link weight of every unordered pair of the devices.
   *
   * @throws IllegalStateException if the topology has no weight for a pair
   */
  @VisibleForTesting
  public int computeCostOfDevices(Device[] devices) {
    int cost = 0;
    for (int i = 0; i < devices.length; i++) {
      for (int j = i + 1; j < devices.length; j++) {
        cost += linkWeight(devices[i].getMinorNumber(),
            devices[j].getMinorNumber());
      }
    }
    return cost;
  }

  // Look up the link weight between two minors. The reverse pair is tried
  // as well, and a missing pair raises a descriptive error instead of an NPE.
  private int linkWeight(int minorA, int minorB) {
    Integer weight = this.devicePairToWeight.get(minorA + "-" + minorB);
    if (weight == null) {
      weight = this.devicePairToWeight.get(minorB + "-" + minorA);
    }
    if (weight == null) {
      throw new IllegalStateException("No topology link weight between GPU "
          + minorA + " and GPU " + minorB);
    }
    return weight;
  }

  /**
   * Topology Aware schedule algorithm.
   * It doesn't consider CPU affinity, NUMA or bus bandwidth.
   * It supports two policies, "spread" and "pack", which can be set through
   * the container's environment variable {@link #TOPOLOGY_POLICY_ENV_KEY}.
   * Pack is the default and prefers faster GPU-GPU communication. Spread
   * prefers faster CPU-GPU communication.
   * It could potentially be extended to take GPU attributes such as GPU chip
   * memory into consideration.
   *
   * On success, {@code count} devices are added to {@code allocation}.
   * Otherwise {@code allocation} is left untouched and an error is logged.
   * */
  @VisibleForTesting
  public void topologyAwareSchedule(Set<Device> allocation, int count,
      Map<String, String> envs,
      Set<Device> availableDevices,
      Map<Integer, List<Map.Entry<Set<Device>, Integer>>> cTable) {
    String policy = envs == null ? null : envs.get(TOPOLOGY_POLICY_ENV_KEY);
    if (policy == null) {
      policy = TOPOLOGY_POLICY_PACK;
    }

    // Get combinations from the cost table for the requested device count.
    if (cTable == null) {
      LOG.error("No cost table initialized!");
      return;
    }
    List<Map.Entry<Set<Device>, Integer>> combinationsToCost =
        cTable.get(count);
    if (combinationsToCost == null) {
      LOG.error("No device combination of size {} in the cost table", count);
      return;
    }
    // pack: low cost to high cost; spread: high cost to low cost
    Iterator<Map.Entry<Set<Device>, Integer>> iterator =
        policy.equalsIgnoreCase(TOPOLOGY_POLICY_SPREAD)
            ? descendingIterator(combinationsToCost)
            : combinationsToCost.iterator();
    while (iterator.hasNext()) {
      Map.Entry<Set<Device>, Integer> element = iterator.next();
      if (availableDevices.containsAll(element.getKey())) {
        allocation.addAll(element.getKey());
        LOG.info("Topology scheduler allocated: {}", allocation);
        return;
      }
    }
    LOG.error("Unknown error happened in topology scheduler");
  }

  // Iterate any List from its last element to its first, without requiring
  // a particular List implementation.
  private static <T> Iterator<T> descendingIterator(List<T> list) {
    final ListIterator<T> it = list.listIterator(list.size());
    return new Iterator<T>() {
      @Override
      public boolean hasNext() {
        return it.hasPrevious();
      }

      @Override
      public T next() {
        return it.previous();
      }
    };
  }

  /**
   * Add the first {@code count} devices (in iteration order) of
   * {@code availableDevices} to {@code allocation}. A non-positive count
   * adds nothing. Asking for more than is available adds everything.
   */
  @VisibleForTesting
  public void basicSchedule(Set<Device> allocation, int count,
      Set<Device> availableDevices) {
    int number = 0;
    for (Device d : availableDevices) {
      if (number >= count) {
        break;
      }
      allocation.add(d);
      number++;
    }
  }

  /**
   * Parse {@code nvidia-smi topo -m} output into pair weights.
   * A typical sample topo output:
   *     GPU0  GPU1  GPU2  GPU3  CPU Affinity
   * GPU0  X  PHB  SOC  SOC  0-31
   * GPU1 PHB  X   SOC  SOC  0-31
   * GPU2 SOC SOC  X    PHB  0-31
   * GPU3 SOC SOC  PHB   X   0-31
   *
   *
   * Legend:
   *
   *   X   = Self
   *   SOC  = Connection traversing PCIe as well as the SMP link between
   *   CPU sockets(e.g. QPI)
   *   PHB  = Connection traversing PCIe as well as a PCIe Host Bridge
   *   (typically the CPU)
   *   PXB  = Connection traversing multiple PCIe switches
   *   (without traversing the PCIe Host Bridge)
   *   PIX  = Connection traversing a single PCIe switch
   *   NV#  = Connection traversing a bonded set of # NVLinks
   *
   * Parsing stops at the "Legend" line, and the header row (which contains
   * "Affinity") is skipped. Rows whose label has no "GPU&lt;n&gt;" (e.g. NIC
   * rows) are skipped as well. The cell in column i of a GPU row is recorded
   * as the weight of pair "row-(i-1)". Cells that are not a known link type
   * are ignored. NV# with more than 9 links is weighted like NV9.
   * */
  public void parseTopo(String topo,
      Map<String, Integer> deviceLinkToWeight) {
    String[] lines = topo.split("\n");
    for (String rawLine : lines) {
      String oneLine = rawLine.trim();
      if (oneLine.isEmpty()) {
        continue;
      }
      // To the end. No more metrics info
      if (oneLine.startsWith("Legend")) {
        break;
      }
      // Skip header
      if (oneLine.contains("Affinity")) {
        continue;
      }
      String[] tokens = oneLine.split("\\s+");
      Matcher rowLabel = GPU_ROW_LABEL.matcher(tokens[0]);
      if (!rowLabel.find()) {
        LOG.debug("Skip non-GPU topology row: {}", tokens[0]);
        continue;
      }
      int rowMinor = Integer.parseInt(rowLabel.group(1));
      for (int i = 1; i < tokens.length; i++) {
        DeviceLinkType linkType = toDeviceLinkType(tokens[i]);
        // "X" (self), CPU affinity ranges, etc. have no link type
        if (linkType != null) {
          populateGraphEdgeWeight(linkType, rowMinor, i - 1,
              deviceLinkToWeight);
        }
      }
    }
  }

  // Map one topology matrix cell to its link type, or null if it is not a
  // GPU-GPU link (e.g. "X" or a CPU affinity range).
  private static DeviceLinkType toDeviceLinkType(String cell) {
    switch (cell) {
    case "SOC":
    case "SYS":
      return DeviceLinkType.P2PLinkCrossCPUSocket;
    case "PHB":
    case "NODE":
      return DeviceLinkType.P2PLinkSameCPUSocket;
    case "PXB":
      return DeviceLinkType.P2PLinkMultiSwitch;
    case "PIX":
      return DeviceLinkType.P2PLinkSingleSwitch;
    default:
      break;
    }
    Matcher nvlink = NVLINK_CELL.matcher(cell);
    if (nvlink.matches()) {
      int bondedLinks = Integer.parseInt(nvlink.group(1));
      if (bondedLinks >= 1) {
        // More bonded links than modeled: treat as the fastest NVLink type.
        return NVLINK_TYPES[Math.min(bondedLinks, NVLINK_TYPES.length) - 1];
      }
    }
    return null;
  }

  // Record the weight of the directed pair "leftVertex-rightVertex".
  private void populateGraphEdgeWeight(
      DeviceLinkType linkType,
      int leftVertex,
      int rightVertex,
      Map<String, Integer> deviceLinkToWeight) {
    deviceLinkToWeight.put(leftVertex + "-" + rightVertex,
        linkType.getWeight());
  }

  /**
   * Different types of link.
   * The weight of each link is a relative value.
   * The higher the weight, the higher the cost between the GPUs.
   * */
  public enum DeviceLinkType {
    /**
     * For Nvidia GPU NVLink.
     * */
    P2PLinkNVLink9(10),
    P2PLinkNVLink8(20),
    P2PLinkNVLink7(30),
    P2PLinkNVLink6(40),
    P2PLinkNVLink5(50),
    P2PLinkNVLink4(60),
    P2PLinkNVLink3(70),
    P2PLinkNVLink2(80),
    P2PLinkNVLink1(90),

    /**
     * Connected to same CPU (Same NUMA node).
     * */
    P2PLinkSameCPUSocket(200),

    /**
     * Cross CPU through socket-level link (e.g. QPI).
     * Usually cross NUMA node
     * */
    P2PLinkCrossCPUSocket(300),

    /**
     * Just need to traverse one PCIe switch to talk.
     * */
    P2PLinkSingleSwitch(600),

    /**
     * Need to traverse multiple PCIe switch to talk.
     * */
    P2PLinkMultiSwitch(1200);

    // A higher link level means slower communication.
    private final int weight;

    public int getWeight() {
      return weight;
    }

    DeviceLinkType(int w) {
      this.weight = w;
    }
  }

  /**
   * A shell wrapper class, easy to replace in tests.
   * */
  public class NvidiaCommandExecutor {

    // Query "index, pci.bus_id" for every GPU, one per line.
    public String getDeviceInfo() throws IOException {
      return Shell.execCommand(environment,
          new String[]{pathOfGpuBinary, "--query-gpu=index,pci.bus_id",
              "--format=csv,noheader"}, MAX_EXEC_TIMEOUT_MS);
    }

    // Returns "major:minor" in hex for /dev/<devName>.
    public String getMajorMinorInfo(String devName) throws IOException {
      // Bounded by the same timeout as the nvidia-smi calls.
      Shell.ShellCommandExecutor shexec = new Shell.ShellCommandExecutor(
          new String[]{"stat", "-c", "%t:%T", "/dev/" + devName},
          null, null, MAX_EXEC_TIMEOUT_MS);
      shexec.execute();
      return shexec.getOutput();
    }

    // Get the topology metrics info from nvidia-smi
    public String getTopologyInfo() throws IOException {
      return Shell.execCommand(environment,
          new String[]{pathOfGpuBinary, "topo",
              "-m"}, MAX_EXEC_TIMEOUT_MS);
    }

    /**
     * Resolve the nvidia-smi path. An already set path is kept. Otherwise
     * the NVIDIA_SMI_PATH environment variable is used if it points to an
     * existing file, and then the default directories are searched.
     *
     * @throws Exception if no binary can be found
     */
    public void searchBinary() throws Exception {
      if (pathOfGpuBinary != null) {
        LOG.info("Skip searching, the nvidia gpu binary is already set: {}",
            pathOfGpuBinary);
        return;
      }
      // search env for the binary
      String envBinaryPath = System.getenv(ENV_BINARY_PATH);
      if (null != envBinaryPath && new File(envBinaryPath).exists()) {
        pathOfGpuBinary = envBinaryPath;
        LOG.info("Use nvidia gpu binary: {}", pathOfGpuBinary);
        return;
      }
      LOG.info("Search binary..");
      // search if binary exists in default folders
      for (String dir : DEFAULT_BINARY_SEARCH_DIRS) {
        File binaryFile = new File(dir, DEFAULT_BINARY_NAME);
        if (binaryFile.exists()) {
          pathOfGpuBinary = binaryFile.getAbsolutePath();
          LOG.info("Found binary:{}", pathOfGpuBinary);
          return;
        }
      }
      LOG.error("No binary found from env variable: "
          + ENV_BINARY_PATH + " or path "
          + DEFAULT_BINARY_SEARCH_DIRS.toString());
      throw new Exception("No binary found for "
          + NvidiaGPUPluginForRuntimeV2.class);
    }
  }

  @VisibleForTesting
  public void setPathOfGpuBinary(String pOfGpuBinary) {
    this.pathOfGpuBinary = pOfGpuBinary;
  }

  @VisibleForTesting
  public void setShellExecutor(
      NvidiaCommandExecutor shellExecutor) {
    this.shellExecutor = shellExecutor;
  }

  @VisibleForTesting
  public boolean isTopoInitialized() {
    return topoInitialized;
  }

  @VisibleForTesting
  public Map<Integer, List<Map.Entry<Set<Device>, Integer>>> getCostTable() {
    return costTable;
  }

  @VisibleForTesting
  public Map<String, Integer> getDevicePairToWeight() {
    return devicePairToWeight;
  }

}