TestDevicePluginAdapter.java
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePlugin;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRegisterRequest;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DeviceRuntimeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountDeviceSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.MountVolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.VolumeSpec;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.YarnRuntimeType;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.ResourceMappings;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperation;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.privileged.PrivilegedOperationExecutor;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.CGroupsHandler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerRunCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.runtime.docker.DockerVolumeCommand;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.DockerCommandPlugin;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.ResourcePluginManager;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.runtime.ContainerRuntimeConstants;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMMemoryStateStoreService;
import org.apache.hadoop.yarn.server.nodemanager.recovery.NMStateStoreService;
import org.apache.hadoop.yarn.server.nodemanager.webapp.dao.NMDeviceResourceInfo;
import org.apache.hadoop.yarn.util.resource.ResourceUtils;
import org.apache.hadoop.yarn.util.resource.TestResourceUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anySet;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
/**
* Unit tests for DevicePluginAdapter,
* focusing on its interaction with the vendor plugin.
* */
public class TestDevicePluginAdapter {
// Logger for this test class.
protected static final Logger LOG =
LoggerFactory.getLogger(TestDevicePluginAdapter.class);
// Configuration rebuilt in setup() before every test.
private YarnConfiguration conf;
// Path returned by TestResourceUtils.setupResourceTypes in setup();
// the file at this path is deleted again in tearDown().
private String tempResourceTypesFile;
// Mockito mocks recreated in setup() before every test.
private CGroupsHandler mockCGroupsHandler;
private PrivilegedOperationExecutor mockPrivilegedExecutor;
/**
 * Creates a fresh configuration, installs the pluggable-devices
 * resource-types file and recreates the mocks shared by the tests.
 *
 * @throws Exception if the resource types file cannot be set up
 */
@BeforeEach
public void setup() throws Exception {
  conf = new YarnConfiguration();

  // Start from a clean resource-types state, then install the test file.
  ResourceUtils.resetResourceTypes();
  tempResourceTypesFile = TestResourceUtils.setupResourceTypes(
      conf, "resource-types-pluggable-devices.xml");

  mockCGroupsHandler = mock(CGroupsHandler.class);
  mockPrivilegedExecutor = mock(PrivilegedOperationExecutor.class);
}
/**
 * Removes the temporary resource-types file recorded by setup().
 *
 * @throws IOException kept in the signature for compatibility; the current
 *     body does not throw it
 */
@AfterEach
public void tearDown() throws IOException {
  // setup() can fail before the path is recorded. new File(null) would then
  // throw a NullPointerException and hide the original failure.
  if (this.tempResourceTypesFile == null) {
    return;
  }
  // cleanup resource-types.xml
  File dest = new File(this.tempResourceTypesFile);
  if (dest.exists() && !dest.delete()) {
    // Cleanup stays best-effort: if the immediate delete fails, ask the JVM
    // to retry at exit instead of silently leaving the file behind.
    dest.deleteOnExit();
  }
}
/**
 * Drives MyPlugin, which implements {@code DevicePlugin}, through bootstrap
 * and the preStart/postComplete cycle for containers requesting one, three
 * and zero devices, checking bookkeeping and cgroup calls each time.
 * Initialization of the plugin itself is tested in TestResourcePluginManager.
 */
@Test
public void testBasicWorkflow()
    throws YarnException, IOException {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  NMStateStoreService stateStore = mock(NMStateStoreService.class);
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  when(nmContext.getConf()).thenReturn(this.conf);
  doNothing().when(stateStore).storeAssignedResources(
      isA(Container.class), isA(String.class), isA(ArrayList.class));

  // Device bookkeeping used by both the adapter and the resource handler.
  DeviceMappingManager mappingManager = new DeviceMappingManager(nmContext);
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingManager);

  // Spied vendor plugin and the adapter wrapping it.
  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingManager);
  pluginAdapter.initialize(nmContext);

  // Build the resource handler by hand so a mocked shell can be injected.
  ShellWrapper shell = mock(ShellWrapper.class);
  when(shell.existFile(anyString())).thenReturn(true);
  when(shell.getDeviceFileType(anyString())).thenReturn("c");
  DeviceResourceHandlerImpl handler = new DeviceResourceHandlerImpl(
      resName, pluginAdapter, mappingManager, mockCGroupsHandler,
      mockPrivilegedExecutor, nmContext, shell);
  pluginAdapter.setDeviceResourceHandler(handler);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);
  verify(mockCGroupsHandler).initializeCGroupController(
      CGroupsHandler.CGroupController.DEVICES);
  assertEquals(3, mappingManager.getAvailableDevices(resName));

  // ---- Case 1: container c1 asks for a single device ----
  Container c1 = mockContainerWithDeviceRequest(1, resName, 1, false);
  pluginAdapter.getDeviceResourceHandler().preStart(c1);
  // bookkeeping after allocation
  assertEquals(2, mappingManager.getAvailableDevices(resName));
  assertEquals(1, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  assertEquals(1,
      mappingManager.getAllocatedDevices(resName, c1.getContainerId()).size());
  verify(shell, times(2)).getDeviceFileType(anyString());
  // cgroup created; two devices excluded and one allowed
  checkCgroupOperation(c1.getContainerId().toString(), 1,
      "c-256:1-rwm,c-256:2-rwm", "256:0");
  // completion hands the device back
  pluginAdapter.getDeviceResourceHandler().postComplete(getContainerId(1));
  assertEquals(3, mappingManager.getAvailableDevices(resName));
  assertEquals(0, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  verify(mockCGroupsHandler).deleteCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c1.getContainerId().toString());

  // ---- Case 2: container c2 asks for all three devices ----
  Container c2 = mockContainerWithDeviceRequest(2, resName, 3, false);
  // Clear recorded interactions, then restore the shell stubs.
  reset(shell);
  reset(mockCGroupsHandler);
  reset(mockPrivilegedExecutor);
  when(shell.existFile(anyString())).thenReturn(true);
  when(shell.getDeviceFileType(anyString())).thenReturn("c");
  pluginAdapter.getDeviceResourceHandler().preStart(c2);
  // bookkeeping after allocation
  assertEquals(0, mappingManager.getAvailableDevices(resName));
  assertEquals(3, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  assertEquals(3,
      mappingManager.getAllocatedDevices(resName, c2.getContainerId()).size());
  verify(shell, times(0)).getDeviceFileType(anyString());
  verify(mockCGroupsHandler).createCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c2.getContainerId().toString());
  // nothing excluded, all three devices allowed
  checkCgroupOperation(c2.getContainerId().toString(), 1,
      null, "256:0,256:1,256:2");
  // completion hands the devices back
  pluginAdapter.getDeviceResourceHandler().postComplete(getContainerId(2));
  assertEquals(3, mappingManager.getAvailableDevices(resName));
  assertEquals(0, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  verify(mockCGroupsHandler).deleteCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c2.getContainerId().toString());

  // ---- Case 3: container c3 asks for no device at all ----
  Container c3 = mockContainerWithDeviceRequest(3, resName, 0, false);
  reset(shell);
  reset(mockCGroupsHandler);
  reset(mockPrivilegedExecutor);
  when(shell.existFile(anyString())).thenReturn(true);
  when(shell.getDeviceFileType(anyString())).thenReturn("c");
  pluginAdapter.getDeviceResourceHandler().preStart(c3);
  // bookkeeping is untouched
  assertEquals(3, mappingManager.getAvailableDevices(resName));
  assertEquals(0, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  verify(shell, times(3)).getDeviceFileType(anyString());
  verify(mockCGroupsHandler).createCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c3.getContainerId().toString());
  // all three devices excluded, nothing allowed
  checkCgroupOperation(c3.getContainerId().toString(), 1,
      "c-256:0-rwm,c-256:1-rwm,c-256:2-rwm", null);
  pluginAdapter.getDeviceResourceHandler().postComplete(getContainerId(3));
  assertEquals(3, mappingManager.getAvailableDevices(resName));
  assertEquals(0, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  assertEquals(0,
      mappingManager.getAllocatedDevices(resName, c3.getContainerId()).size());
  verify(mockCGroupsHandler).deleteCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c3.getContainerId().toString());
}
/**
 * Verifies that a DEVICES cgroup was created for the container and that the
 * privileged executor was invoked with a DEVICE operation carrying the
 * expected command-line arguments.
 *
 * @param cId container id string used as the cgroup id
 * @param invokeTimesOfPrivilegedExecutor expected number of privileged
 *     operation invocations
 * @param excludedParam expected excluded-devices value; the option is not
 *     expected at all when this is null or empty
 * @param allowedParam expected allowed-devices value; the option is not
 *     expected at all when this is null or empty
 */
private void checkCgroupOperation(String cId,
    int invokeTimesOfPrivilegedExecutor,
    String excludedParam, String allowedParam)
    throws PrivilegedOperationException, ResourceHandlerException {
  verify(mockCGroupsHandler).createCGroup(
      CGroupsHandler.CGroupController.DEVICES, cId);

  // Capture the operation passed to the privileged executor.
  ArgumentCaptor<PrivilegedOperation> opCaptor =
      ArgumentCaptor.forClass(PrivilegedOperation.class);
  verify(mockPrivilegedExecutor, times(invokeTimesOfPrivilegedExecutor))
      .executePrivilegedOperation(opCaptor.capture(), eq(true));
  PrivilegedOperation capturedOp = opCaptor.getValue();
  assertEquals(PrivilegedOperation.OperationType.DEVICE,
      capturedOp.getOperationType());

  // Rebuild the argument list the operation is expected to carry.
  List<String> expectedArgs = new ArrayList<>();
  expectedArgs.add(DeviceResourceHandlerImpl.CONTAINER_ID_CLI_OPTION);
  expectedArgs.add(cId);
  boolean expectExcluded = !(excludedParam == null || excludedParam.isEmpty());
  if (expectExcluded) {
    expectedArgs.add(DeviceResourceHandlerImpl.EXCLUDED_DEVICES_CLI_OPTION);
    expectedArgs.add(excludedParam);
  }
  boolean expectAllowed = !(allowedParam == null || allowedParam.isEmpty());
  if (expectAllowed) {
    expectedArgs.add(DeviceResourceHandlerImpl.ALLOWED_DEVICES_CLI_OPTION);
    expectedArgs.add(allowedParam);
  }
  assertArrayEquals(expectedArgs.toArray(),
      capturedOp.getArguments().toArray());
}
/**
 * Checks that the node resource handler obtained from the adapter asks the
 * plugin for its devices once and sets the resource value to 3.
 */
@Test
public void testDeviceResourceUpdaterImpl() throws YarnException {
  Resource totalResource = mock(Resource.class);
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);

  // Spied plugin, mapping manager and the adapter tying them together.
  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DeviceMappingManager mappingManager = new DeviceMappingManager(nmContext);
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingManager);
  pluginAdapter.initialize(mock(Context.class));

  pluginAdapter.getNodeResourceHandlerInstance()
      .updateConfiguredResource(totalResource);

  verify(pluginSpy, times(1)).getDevices();
  verify(totalResource, times(1)).setResourceValue(resName, 3);
}
/**
 * Checks that preStart hands the allocated device to the NM state store for
 * the requesting container.
 */
@Test
public void testStoreDeviceSchedulerManagerState()
    throws IOException, YarnException {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  // Spy on an in-memory store; the store call itself is stubbed out below.
  NMStateStoreService stateStore = spy(new NMMemoryStateStoreService());
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  when(nmContext.getConf()).thenReturn(this.conf);
  doNothing().when(stateStore).storeAssignedResources(
      isA(Container.class), isA(String.class), isA(ArrayList.class));

  DeviceMappingManager mappingManager = new DeviceMappingManager(nmContext);
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingManager);

  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingManager);

  // Bootstrap registers the plugin's devices.
  pluginAdapter.initialize(nmContext);
  pluginAdapter.createResourceHandler(nmContext,
      mockCGroupsHandler, mockPrivilegedExecutor);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);

  // Container c0 requests one device.
  Container c0 = mockContainerWithDeviceRequest(0, resName, 1, false);
  pluginAdapter.getDeviceResourceHandler().preStart(c0);

  // c0's assignment must have been passed to the state store.
  Device expectedDevice = Device.Builder.newInstance()
      .setId(0)
      .setDevPath("/dev/hdwA0")
      .setMajorNumber(256)
      .setMinorNumber(0)
      .setBusID("0000:80:00.0")
      .setHealthy(true)
      .build();
  verify(stateStore).storeAssignedResources(c0, resName,
      Arrays.asList(expectedDevice));
}
/**
 * Verifies that a container's device assignment is restored when the
 * container is reacquired, and that recovery is rejected when the stored
 * device is already held by another container.
 */
@Test
public void testRecoverDeviceSchedulerManagerState()
    throws IOException, YarnException {
  NodeManager.NMContext context = mock(NodeManager.NMContext.class);
  NMStateStoreService realStoreService = new NMMemoryStateStoreService();
  NMStateStoreService storeService = spy(realStoreService);
  when(context.getNMStateStore()).thenReturn(storeService);
  doNothing().when(storeService).storeAssignedResources(isA(Container.class),
      isA(String.class),
      isA(ArrayList.class));
  // Init scheduler manager
  DeviceMappingManager dmm = new DeviceMappingManager(context);
  ResourcePluginManager rpm = mock(ResourcePluginManager.class);
  when(rpm.getDeviceMappingManager()).thenReturn(dmm);
  // Init a plugin
  MyPlugin plugin = new MyPlugin();
  MyPlugin spyPlugin = spy(plugin);
  String resourceName = MyPlugin.RESOURCE_NAME;
  // Init an adapter for the plugin
  DevicePluginAdapter adapter = new DevicePluginAdapter(
      resourceName,
      spyPlugin, dmm);
  // Bootstrap, adding device
  adapter.initialize(context);
  adapter.createResourceHandler(context,
      mockCGroupsHandler, mockPrivilegedExecutor);
  adapter.getDeviceResourceHandler().bootstrap(conf);
  assertEquals(3,
      dmm.getAllAllowedDevices().get(resourceName).size());

  // The device that the "stored" container state claims to own.
  Device storedDevice = Device.Builder.newInstance()
      .setId(0)
      .setDevPath("/dev/hdwA0")
      .setMajorNumber(256)
      .setMinorNumber(0)
      .setBusID("0000:80:00.0")
      .setHealthy(true)
      .build();
  ConcurrentHashMap<ContainerId, Container> runningContainersMap
      = new ConcurrentHashMap<>();
  Container nmContainer = mock(Container.class);
  ResourceMappings rmap = new ResourceMappings();
  ResourceMappings.AssignedResources ar =
      new ResourceMappings.AssignedResources();
  ar.updateAssignedResources(
      Arrays.asList(storedDevice));
  rmap.addAssignedResources(resourceName, ar);
  when(nmContainer.getResourceMappings()).thenReturn(rmap);
  when(context.getContainers()).thenReturn(runningContainersMap);

  // Test case 1. c0 gets recovered and the scheduler state is restored.
  runningContainersMap.put(getContainerId(0), nmContainer);
  adapter.getDeviceResourceHandler().reacquireContainer(
      getContainerId(0));
  assertEquals(3,
      dmm.getAllAllowedDevices().get(resourceName).size());
  assertEquals(1,
      dmm.getAllUsedDevices().get(resourceName).size());
  assertEquals(2,
      dmm.getAvailableDevices(resourceName));
  Map<Device, ContainerId> used = dmm.getAllUsedDevices().get(resourceName);
  assertTrue(used.containsKey(storedDevice));
  assertEquals(getContainerId(0), used.get(storedDevice));

  // Test case 2. c1 wants to be recovered with the same stored device,
  // which is already held by c0. The conflicting container is registered
  // under c1's own id and its resource mappings are stubbed, so the
  // reacquire call actually sees the clashing device.
  Container conflictingContainer = mock(Container.class);
  ResourceMappings conflictingMappings = new ResourceMappings();
  ResourceMappings.AssignedResources conflictingAssigned =
      new ResourceMappings.AssignedResources();
  conflictingAssigned.updateAssignedResources(
      Arrays.asList(storedDevice));
  conflictingMappings.addAssignedResources(resourceName, conflictingAssigned);
  when(conflictingContainer.getResourceMappings())
      .thenReturn(conflictingMappings);
  runningContainersMap.put(getContainerId(1), conflictingContainer);
  boolean caughtException = false;
  try {
    adapter.getDeviceResourceHandler().reacquireContainer(getContainerId(1));
  } catch (ResourceHandlerException e) {
    caughtException = true;
  }
  assertTrue(caughtException, "Should fail since requested device is assigned already");

  // The failed recovery must not affect c0's allocation state.
  assertEquals(3,
      dmm.getAllAllowedDevices().get(resourceName).size());
  assertEquals(1,
      dmm.getAllUsedDevices().get(resourceName).size());
  assertEquals(2,
      dmm.getAvailableDevices(resourceName));
  used = dmm.getAllUsedDevices().get(resourceName);
  assertTrue(used.containsKey(storedDevice));
  assertEquals(getContainerId(0), used.get(storedDevice));
}
/**
 * Checks that no device stays assigned when persisting the assignment to
 * the state store fails during preStart.
 */
@Test
public void testAssignedDeviceCleanupWhenStoreOpFails()
    throws IOException, YarnException {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  NMStateStoreService stateStore = spy(new NMMemoryStateStoreService());
  when(nmContext.getConf()).thenReturn(this.conf);
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  // Every attempt to persist an assignment fails.
  doThrow(new IOException("Exception ...")).when(stateStore)
      .storeAssignedResources(
          isA(Container.class), isA(String.class), isA(ArrayList.class));

  DeviceMappingManager mappingManager = new DeviceMappingManager(nmContext);
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingManager);

  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingManager);

  // Bootstrap registers the plugin's devices.
  pluginAdapter.initialize(nmContext);
  pluginAdapter.createResourceHandler(nmContext,
      mockCGroupsHandler, mockPrivilegedExecutor);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);

  // Container c0 requests one device; preStart is expected to fail.
  Container c0 = mockContainerWithDeviceRequest(0, resName, 1, false);
  boolean preStartFailed = false;
  try {
    pluginAdapter.getDeviceResourceHandler().preStart(c0);
  } catch (ResourceHandlerException e) {
    preStartFailed = true;
  }
  assertTrue(preStartFailed, "Should throw exception in preStart");

  // All devices are still free.
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
  assertEquals(0, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAvailableDevices(resName));
}
/**
 * Checks that the plugin's own allocateDevices is invoked once during
 * preStart when the plugin is registered as a device plugin scheduler.
 */
@Test
public void testPreferPluginScheduler() throws IOException, YarnException {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  NMStateStoreService stateStore = mock(NMStateStoreService.class);
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  when(nmContext.getConf()).thenReturn(this.conf);
  doNothing().when(stateStore).storeAssignedResources(
      isA(Container.class), isA(String.class), isA(ArrayList.class));

  DeviceMappingManager mappingManager = new DeviceMappingManager(nmContext);
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingManager);

  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  // Register the plugin as the scheduler for its own resource type.
  mappingManager.getDevicePluginSchedulers()
      .put(MyPlugin.RESOURCE_NAME, pluginSpy);

  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingManager);
  // Bootstrap registers the plugin's devices.
  pluginAdapter.initialize(nmContext);
  pluginAdapter.createResourceHandler(nmContext,
      mockCGroupsHandler, mockPrivilegedExecutor);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);
  assertEquals(3, mappingManager.getAvailableDevices(resName));

  // Container c1 requests one device.
  Container c1 = mockContainerWithDeviceRequest(0, resName, 1, false);
  pluginAdapter.getDeviceResourceHandler().preStart(c1);

  // The customized scheduler did the allocation.
  verify(pluginSpy, times(1)).allocateDevices(
      anySet(), anyInt(), anyMap());
  assertEquals(2, mappingManager.getAvailableDevices(resName));
  assertEquals(1, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());
}
/**
 * Builds a mocked container whose resource requests the given number of
 * devices.
 *
 * @param id sequence number passed to getContainerId
 * @param resourceName resource type the devices are requested under
 * @param numDeviceRequest number of devices requested
 * @param dockerContainerEnabled when true, the launch environment marks the
 *     container type as Docker
 * @return the mocked container
 */
private static Container mockContainerWithDeviceRequest(int id,
    String resourceName,
    int numDeviceRequest,
    boolean dockerContainerEnabled) {
  // Requested resource: 1024 memory, 1 vcore plus the device count.
  Resource requested = Resource.newInstance(1024, 1);
  requested.setResourceValue(resourceName, numDeviceRequest);

  // Launch environment, optionally selecting the Docker runtime.
  Map<String, String> launchEnv = new HashMap<>();
  if (dockerContainerEnabled) {
    launchEnv.put(ContainerRuntimeConstants.ENV_CONTAINER_TYPE,
        ContainerRuntimeConstants.CONTAINER_RUNTIME_DOCKER);
  }
  ContainerLaunchContext launchContext = mock(ContainerLaunchContext.class);
  when(launchContext.getEnvironment()).thenReturn(launchEnv);

  Container container = mock(Container.class);
  when(container.getContainerId()).thenReturn(getContainerId(id));
  when(container.getResource()).thenReturn(requested);
  when(container.getResourceMappings()).thenReturn(new ResourceMappings());
  when(container.getLaunchContext()).thenReturn(launchContext);
  return container;
}
/**
 * Ensures the resource info returned by the adapter reports the assigned
 * and total devices correctly before and after a container completes.
 */
@Test
public void testNMResourceInfoRESTAPI() throws IOException, YarnException {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  NMStateStoreService stateStore = mock(NMStateStoreService.class);
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  when(nmContext.getConf()).thenReturn(this.conf);
  doNothing().when(stateStore).storeAssignedResources(
      isA(Container.class), isA(String.class), isA(ArrayList.class));

  DeviceMappingManager mappingManager = new DeviceMappingManager(nmContext);
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingManager);

  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingManager);
  // Bootstrap registers the plugin's devices.
  pluginAdapter.initialize(nmContext);
  pluginAdapter.createResourceHandler(nmContext,
      mockCGroupsHandler, mockPrivilegedExecutor);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);
  assertEquals(3, mappingManager.getAvailableDevices(resName));

  // Container c1 requests one device.
  Container c1 = mockContainerWithDeviceRequest(0, resName, 1, false);
  pluginAdapter.getDeviceResourceHandler().preStart(c1);
  // bookkeeping after allocation
  assertEquals(2, mappingManager.getAvailableDevices(resName));
  assertEquals(1, mappingManager.getAllUsedDevices().get(resName).size());
  assertEquals(3, mappingManager.getAllAllowedDevices().get(resName).size());

  // The REST view must agree with the mapping manager.
  NMDeviceResourceInfo info =
      (NMDeviceResourceInfo) pluginAdapter.getNMResourceInfo();
  assertEquals(1, info.getAssignedDevices().size());
  assertEquals(3, info.getTotalDevices().size());
  Device assigned = info.getAssignedDevices().get(0).getDevice();
  String ownerId = info.getAssignedDevices().get(0).getContainerId();
  assertTrue(mappingManager.getAllAllowedDevices().get(resName)
      .contains(assigned));
  assertTrue(mappingManager.getAllUsedDevices().get(resName)
      .containsValue(ContainerId.fromString(ownerId)));

  // After the container finishes nothing is assigned any more.
  pluginAdapter.getDeviceResourceHandler().postComplete(getContainerId(0));
  info = (NMDeviceResourceInfo) pluginAdapter.getNMResourceInfo();
  assertEquals(0, info.getAssignedDevices().size());
  assertEquals(3, info.getTotalDevices().size());
}
/**
 * Tests how the Docker command plugin creates a volume command and updates
 * the run command for a container using the Docker runtime, with the device
 * plugin version set to "v1" (a plugin like Nvidia Docker v1).
 */
@Test
public void testDeviceResourceDockerRuntimePlugin1() throws Exception {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  NMStateStoreService stateStore = mock(NMStateStoreService.class);
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  when(nmContext.getConf()).thenReturn(this.conf);
  doNothing().when(stateStore).storeAssignedResources(
      isA(Container.class), isA(String.class), isA(ArrayList.class));

  // Spied mapping manager so lookups of allocations can be counted.
  DeviceMappingManager mappingSpy = spy(new DeviceMappingManager(nmContext));
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingSpy);

  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingSpy);
  pluginAdapter.initialize(nmContext);
  // NOTE(review): initialize() is called a second time here; it looks
  // redundant -- confirm before removing.
  pluginAdapter.initialize(nmContext);
  pluginAdapter.createResourceHandler(nmContext,
      mockCGroupsHandler, mockPrivilegedExecutor);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);

  // Container c1 uses the Docker runtime and requests one device.
  Container c1 = mockContainerWithDeviceRequest(1, resName, 1, true);
  // Have the plugin generate its spec in the "v1" flavour.
  pluginSpy.setDevicePluginVersion("v1");
  // Allocation happens in preStart.
  pluginAdapter.getDeviceResourceHandler().preStart(c1);
  Set<Device> allocated =
      mappingSpy.getAllocatedDevices(resName, c1.getContainerId());
  reset(mappingSpy);

  // For a Docker container the cgroup is created, but no privileged
  // cgroups update operation is issued.
  verify(mockCGroupsHandler).createCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c1.getContainerId().toString());
  verify(mockPrivilegedExecutor, times(0))
      .executePrivilegedOperation(
          any(PrivilegedOperation.class), anyBoolean());

  DockerCommandPlugin dockerPlugin =
      pluginAdapter.getDockerCommandPluginInstance();
  // Step 1: the runtime asks for the volume creation command.
  DockerVolumeCommand volumeCommand =
      dockerPlugin.getCreateDockerVolumeCommand(c1);
  // The allocation is fetched once from the mapping manager ...
  verify(mappingSpy).getAllocatedDevices(resName, c1.getContainerId());
  // ... and the plugin was told about the allocation for both runtime types.
  verify(pluginSpy).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DEFAULT);
  verify(pluginSpy).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DOCKER);
  assertEquals("nvidia-docker", volumeCommand.getDriverName());
  assertEquals("create", volumeCommand.getSubCommand());
  assertEquals("nvidia_driver_352.68", volumeCommand.getVolumeName());

  // Step 2: the runtime updates the docker run command.
  DockerRunCommand runCommand = new DockerRunCommand(
      c1.getContainerId().toString(), "user", "image/tensorflow");
  // Forget the interactions recorded so far.
  reset(pluginSpy);
  reset(mappingSpy);
  dockerPlugin.updateDockerRunCommand(runCommand, c1);
  // The spec from step 1 is expected to be reused, so neither the plugin
  // nor the mapping manager is consulted again.
  verify(pluginSpy, times(0)).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DOCKER);
  verify(mappingSpy, times(0)).getAllocatedDevices(resName,
      c1.getContainerId());
  String runCommandLine = runCommand.toString();
  assertTrue(
      runCommandLine.contains("nvidia_driver_352.68:/usr/local/nvidia:ro"));
  assertTrue(runCommandLine.contains("/dev/hdwA0:/dev/hdwA0"));

  // Step 3: cleanup notifies the plugin that the devices were released.
  dockerPlugin.getCleanupDockerVolumesCommand(c1);
  verify(pluginSpy).onDevicesReleased(allocated);

  // Running c1 again after cleanup goes back to the mapping manager and the
  // plugin rather than to a cached allocation/spec.
  dockerPlugin.getCreateDockerVolumeCommand(c1);
  verify(mappingSpy).getAllocatedDevices(resName, c1.getContainerId());
  verify(pluginSpy).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DOCKER);
}
/**
 * Tests how the Docker command plugin updates the run command for a
 * container using the Docker runtime, with the device plugin version set to
 * "v2" (a plugin like Nvidia Docker v2), where no volume is created.
 */
@Test
public void testDeviceResourceDockerRuntimePlugin2() throws Exception {
  NodeManager.NMContext nmContext = mock(NodeManager.NMContext.class);
  NMStateStoreService stateStore = mock(NMStateStoreService.class);
  when(nmContext.getNMStateStore()).thenReturn(stateStore);
  when(nmContext.getConf()).thenReturn(this.conf);
  doNothing().when(stateStore).storeAssignedResources(
      isA(Container.class), isA(String.class), isA(ArrayList.class));

  // Spied mapping manager so lookups of allocations can be counted.
  DeviceMappingManager mappingSpy = spy(new DeviceMappingManager(nmContext));
  ResourcePluginManager pluginManager = mock(ResourcePluginManager.class);
  when(pluginManager.getDeviceMappingManager()).thenReturn(mappingSpy);

  MyPlugin pluginSpy = spy(new MyPlugin());
  String resName = MyPlugin.RESOURCE_NAME;
  DevicePluginAdapter pluginAdapter =
      new DevicePluginAdapter(resName, pluginSpy, mappingSpy);
  pluginAdapter.initialize(nmContext);
  // NOTE(review): initialize() is called a second time here; it looks
  // redundant -- confirm before removing.
  pluginAdapter.initialize(nmContext);
  pluginAdapter.createResourceHandler(nmContext,
      mockCGroupsHandler, mockPrivilegedExecutor);
  pluginAdapter.getDeviceResourceHandler().bootstrap(conf);

  // Container c1 uses the Docker runtime and requests two devices.
  Container c1 = mockContainerWithDeviceRequest(1, resName, 2, true);
  // Have the plugin generate its spec in the "v2" flavour.
  pluginSpy.setDevicePluginVersion("v2");
  // Allocation happens in preStart.
  pluginAdapter.getDeviceResourceHandler().preStart(c1);
  Set<Device> allocated =
      mappingSpy.getAllocatedDevices(resName, c1.getContainerId());
  reset(mappingSpy);

  // For a Docker container the cgroup is created, but no privileged
  // cgroups update operation is issued.
  verify(mockCGroupsHandler).createCGroup(
      CGroupsHandler.CGroupController.DEVICES,
      c1.getContainerId().toString());
  verify(mockPrivilegedExecutor, times(0))
      .executePrivilegedOperation(
          any(PrivilegedOperation.class), anyBoolean());

  DockerCommandPlugin dockerPlugin =
      pluginAdapter.getDockerCommandPluginInstance();
  // Step 1: the runtime asks for the volume creation command.
  DockerVolumeCommand volumeCommand =
      dockerPlugin.getCreateDockerVolumeCommand(c1);
  // The allocation is fetched once from the mapping manager ...
  verify(mappingSpy).getAllocatedDevices(resName, c1.getContainerId());
  // ... and the plugin was told about the allocation for both runtime types.
  verify(pluginSpy).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DEFAULT);
  verify(pluginSpy).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DOCKER);
  // With "v2" there is no volume to create.
  assertNull(volumeCommand);

  // Step 2: the runtime updates the docker run command.
  DockerRunCommand runCommand = new DockerRunCommand(
      c1.getContainerId().toString(), "user", "image/tensorflow");
  // Forget the interactions recorded so far.
  reset(pluginSpy);
  reset(mappingSpy);
  dockerPlugin.updateDockerRunCommand(runCommand, c1);
  // The spec from step 1 is expected to be reused, so neither the plugin
  // nor the mapping manager is consulted again.
  verify(pluginSpy, times(0)).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DOCKER);
  verify(mappingSpy, times(0)).getAllocatedDevices(resName,
      c1.getContainerId());
  assertEquals("0,1", runCommand.getEnv().get("NVIDIA_VISIBLE_DEVICES"));
  assertTrue(runCommand.toString().contains("runtime=nvidia"));

  // Step 3: cleanup notifies the plugin that the devices were released.
  dockerPlugin.getCleanupDockerVolumesCommand(c1);
  verify(pluginSpy).onDevicesReleased(allocated);

  // Running c1 again after cleanup goes back to the mapping manager and the
  // plugin rather than to a cached allocation/spec.
  dockerPlugin.getCreateDockerVolumeCommand(c1);
  verify(mappingSpy).getAllocatedDevices(resName, c1.getContainerId());
  verify(pluginSpy).onDevicesAllocated(
      allocated, YarnRuntimeType.RUNTIME_DOCKER);
}
/**
 * Builds a {@link ContainerId} for the given container number.
 * The application ({@code 1234L, 1}) and the attempt number ({@code 1})
 * are fixed; only {@code id} varies between calls.
 *
 * @param id the container number passed to
 *           {@code ContainerId.newContainerId}
 * @return the resulting container id
 */
private static ContainerId getContainerId(int id) {
  ApplicationId applicationId = ApplicationId.newInstance(1234L, 1);
  ApplicationAttemptId attemptId =
      ApplicationAttemptId.newInstance(applicationId, 1);
  return ContainerId.newContainerId(attemptId, id);
}
/**
 * Test-only device plugin that also serves as its own scheduler.
 * It reports three healthy fake devices ({@code /dev/hdwA0..2}) and, for the
 * Docker runtime, produces a {@link DeviceRuntimeSpec} in one of two
 * flavours selected by {@link #setDevicePluginVersion(String)}.
 */
private class MyPlugin implements DevicePlugin, DevicePluginScheduler {
  private final static String RESOURCE_NAME = "cmpA.com/hdwA";
  // v1 means the vendor works in a way similar to Nvidia Docker v1
  // v2 means the vendor works in a way similar to Nvidia Docker v2
  private String devicePluginVersion = "v2";

  /**
   * Selects which spec flavour {@code onDevicesAllocated} generates.
   *
   * @param version "v1" or "v2"; any other value (including null) results
   *                in an empty spec
   */
  public void setDevicePluginVersion(String version) {
    devicePluginVersion = version;
  }

  /**
   * @return a register request carrying {@link #RESOURCE_NAME} and plugin
   *         version "v1.0"
   */
  @Override
  public DeviceRegisterRequest getRegisterRequestInfo() {
    return DeviceRegisterRequest.Builder.newInstance()
        .setResourceName(RESOURCE_NAME)
        .setPluginVersion("v1.0").build();
  }

  /**
   * @return a new sorted set holding the three fake devices with ids 0..2
   */
  @Override
  public Set<Device> getDevices() {
    Set<Device> devices = new TreeSet<>();
    for (int index = 0; index < 3; index++) {
      devices.add(newFakeDevice(index));
    }
    return devices;
  }

  // Builds healthy fake device number `index`; id, minor number, device
  // path suffix and bus id slot are all derived from the index.
  private Device newFakeDevice(int index) {
    return Device.Builder.newInstance()
        .setId(index)
        .setDevPath("/dev/hdwA" + index)
        .setMajorNumber(256)
        .setMinorNumber(index)
        .setBusID("0000:80:0" + index + ".0")
        .setHealthy(true)
        .build();
  }

  /**
   * Produces the runtime spec for an allocation.
   *
   * @param allocatedDevices devices assigned to the container
   * @param yarnRuntime runtime the container is launched with
   * @return a spec for {@code RUNTIME_DOCKER}; {@code null} for every other
   *         runtime (including {@code RUNTIME_DEFAULT})
   */
  @Override
  public DeviceRuntimeSpec onDevicesAllocated(Set<Device> allocatedDevices,
      YarnRuntimeType yarnRuntime) throws Exception {
    if (yarnRuntime == YarnRuntimeType.RUNTIME_DOCKER) {
      return generateSpec(devicePluginVersion, allocatedDevices);
    }
    return null;
  }

  /**
   * Builds the spec for the given flavour.
   *
   * @param version "v1" adds a volume, a volume mount and one device mount
   *                per allocated device; "v2" sets the container runtime and
   *                an environment variable listing the minor numbers;
   *                anything else leaves the spec empty
   * @param allocatedDevices devices assigned to the container
   * @return the built spec
   */
  private DeviceRuntimeSpec generateSpec(String version,
      Set<Device> allocatedDevices) {
    DeviceRuntimeSpec.Builder builder =
        DeviceRuntimeSpec.Builder.newInstance();
    // Constant-first equals so a null version cannot throw.
    if ("v1".equals(version)) {
      // Nvidia v1 examples like below. This info is obtained from the
      // Nvidia v1 RESTful API.
      // --device=/dev/nvidiactl --device=/dev/nvidia-uvm
      // --device=/dev/nvidia0
      // --volume-driver=nvidia-docker
      // --volume=nvidia_driver_352.68:/usr/local/nvidia:ro
      String volumeDriverName = "nvidia-docker";
      String volumeToBeCreated = "nvidia_driver_352.68";
      String volumePathInContainer = "/usr/local/nvidia";
      // describe volumes to be created and mounted
      builder.addVolumeSpec(
          VolumeSpec.Builder.newInstance()
              .setVolumeDriver(volumeDriverName)
              .setVolumeName(volumeToBeCreated)
              .setVolumeOperation(VolumeSpec.CREATE).build())
          .addMountVolumeSpec(
              MountVolumeSpec.Builder.newInstance()
                  .setHostPath(volumeToBeCreated)
                  .setMountPath(volumePathInContainer)
                  .setReadOnly(true).build());
      // describe devices to be mounted
      for (Device device : allocatedDevices) {
        builder.addMountDeviceSpec(
            MountDeviceSpec.Builder.newInstance()
                .setDevicePathInHost(device.getDevPath())
                .setDevicePathInContainer(device.getDevPath())
                .setDevicePermission(MountDeviceSpec.RW).build());
      }
    } else if ("v2".equals(version)) {
      String nvidiaRuntime = "nvidia";
      String nvidiaVisibleDevices = "NVIDIA_VISIBLE_DEVICES";
      // Join minor numbers with commas. The separator is written before
      // every element except the first, so an empty allocation yields ""
      // instead of failing while stripping a trailing comma.
      StringBuilder minorNumbers = new StringBuilder();
      for (Device device : allocatedDevices) {
        if (minorNumbers.length() > 0) {
          minorNumbers.append(',');
        }
        minorNumbers.append(device.getMinorNumber());
      }
      // setting the runtime and the environment variable is enough for
      // a plugin like Nvidia Docker v2
      builder.addEnv(nvidiaVisibleDevices, minorNumbers.toString())
          .setContainerRuntime(nvidiaRuntime);
    }
    return builder.build();
  }

  @Override
  public void onDevicesReleased(Set<Device> releasedDevices) {
    // nothing to do
  }

  /**
   * Picks the first {@code count} devices in the iteration order of
   * {@code availableDevices}.
   *
   * @param availableDevices candidate devices
   * @param count number of devices wanted; a value of zero or less yields
   *              an empty set
   * @param env unused by this plugin
   * @return a new sorted set with at most {@code count} devices
   */
  @Override
  public Set<Device> allocateDevices(Set<Device> availableDevices,
      int count, Map<String, String> env) {
    Set<Device> allocated = new TreeSet<>();
    int taken = 0;
    for (Device candidate : availableDevices) {
      // Check before adding so that count <= 0 never hands out a device.
      if (taken >= count) {
        break;
      }
      allocated.add(candidate);
      taken++;
    }
    return allocated;
  }
} // MyPlugin
}