TestPolicyGenerator.java
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.server.globalpolicygenerator.policygenerator;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.server.federation.policies.manager.FederationPolicyManager;
import org.apache.hadoop.yarn.server.federation.policies.manager.WeightedLocalityPolicyManager;
import org.apache.hadoop.yarn.server.federation.store.FederationStateStore;
import org.apache.hadoop.yarn.server.federation.store.records.GetSubClusterPolicyConfigurationRequest;
import org.apache.hadoop.yarn.server.federation.store.records.GetSubClusterPolicyConfigurationResponse;
import org.apache.hadoop.yarn.server.federation.store.records.GetSubClustersInfoResponse;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterId;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterInfo;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterPolicyConfiguration;
import org.apache.hadoop.yarn.server.federation.store.records.SubClusterState;
import org.apache.hadoop.yarn.server.federation.utils.FederationStateStoreFacade;
import org.apache.hadoop.yarn.server.globalpolicygenerator.GPGContext;
import org.apache.hadoop.yarn.server.globalpolicygenerator.GPGContextImpl;
import org.apache.hadoop.yarn.server.globalpolicygenerator.GPGPolicyFacade;
import org.apache.hadoop.yarn.server.globalpolicygenerator.GPGUtils;
import org.apache.hadoop.yarn.server.resourcemanager.ResourceManager;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceScheduler;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CapacityScheduler;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CapacitySchedulerConfiguration;
import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.QueuePath;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.RMWSConsts;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.CapacitySchedulerInfo;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.ClusterMetricsInfo;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.SchedulerInfo;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.SchedulerTypeInfo;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.CapacitySchedulerQueueInfoList;
import org.apache.hadoop.yarn.server.resourcemanager.webapp.dao.CapacitySchedulerQueueInfo;
import org.apache.hadoop.yarn.webapp.util.WebAppUtils;
import org.glassfish.jersey.jettison.JettisonConfig;
import org.glassfish.jersey.jettison.JettisonJaxbContext;
import org.glassfish.jersey.jettison.JettisonUnmarshaller;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import javax.xml.bind.JAXBException;
import java.io.IOException;
import java.io.StringReader;
import java.net.InetSocketAddress;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
* Unit test for GPG Policy Generator.
*/
public class TestPolicyGenerator {
private static final int NUM_SC = 3;
private Configuration conf;
private FederationStateStore stateStore;
private FederationStateStoreFacade facade;
private List<SubClusterId> subClusterIds;
private Map<SubClusterId, SubClusterInfo> subClusterInfos;
private Map<SubClusterId, Map<Class, Object>> clusterInfos;
private Map<SubClusterId, SchedulerInfo> schedulerInfos;
private GPGContext gpgContext;
private PolicyGenerator policyGenerator;
public TestPolicyGenerator() {
conf = new Configuration();
conf.setInt(YarnConfiguration.FEDERATION_CACHE_TIME_TO_LIVE_SECS, 0);
facade = FederationStateStoreFacade.getInstance(conf);
gpgContext = new GPGContextImpl();
gpgContext.setPolicyFacade(new GPGPolicyFacade(facade, conf));
gpgContext.setStateStoreFacade(facade);
}
@BeforeEach
public void setUp() throws IOException, YarnException, JAXBException {
subClusterIds = new ArrayList<>();
subClusterInfos = new HashMap<>();
clusterInfos = new HashMap<>();
schedulerInfos = new HashMap<>();
CapacitySchedulerInfo sti1 =
readJSON("src/test/resources/schedulerInfo1.json",
CapacitySchedulerInfo.class);
CapacitySchedulerInfo sti2 =
readJSON("src/test/resources/schedulerInfo2.json",
CapacitySchedulerInfo.class);
// Set up sub clusters
for (int i = 0; i < NUM_SC; ++i) {
// Sub cluster Id
SubClusterId id = SubClusterId.newInstance("sc" + i);
subClusterIds.add(id);
// Sub cluster info
SubClusterInfo cluster = SubClusterInfo
.newInstance(id, "amrm:" + i, "clientrm:" + i, "rmadmin:" + i,
"rmweb:" + i, SubClusterState.SC_RUNNING, 0, "");
subClusterInfos.put(id, cluster);
// Cluster metrics info
ClusterMetricsInfo metricsInfo = new ClusterMetricsInfo();
metricsInfo.setAppsPending(2000);
if (!clusterInfos.containsKey(id)) {
clusterInfos.put(id, new HashMap<>());
}
clusterInfos.get(id).put(ClusterMetricsInfo.class, metricsInfo);
schedulerInfos.put(id, sti1);
}
// Change one of the sub cluster schedulers
schedulerInfos.put(subClusterIds.get(0), sti2);
stateStore = mock(FederationStateStore.class);
when(stateStore.getSubClusters(any()))
.thenReturn(GetSubClustersInfoResponse.newInstance(
new ArrayList<>(subClusterInfos.values())));
facade.reinitialize(stateStore, conf);
}
@AfterEach
public void tearDown() throws Exception {
stateStore.close();
stateStore = null;
}
private <T> T readJSON(String pathname, Class<T> classy)
throws IOException, JAXBException {
JettisonJaxbContext jaxbContext = new JettisonJaxbContext(JettisonConfig.DEFAULT, classy);
String contents = new String(Files.readAllBytes(Paths.get(pathname)));
JettisonUnmarshaller unmarshaller = jaxbContext.createJsonUnmarshaller();
return unmarshaller.unmarshalFromJSON(new StringReader(contents), classy);
}
@Test
public void testPolicyGenerator() {
policyGenerator = new TestablePolicyGenerator();
policyGenerator.setPolicy(mock(GlobalPolicy.class));
policyGenerator.run();
verify(policyGenerator.getPolicy(), times(1))
.updatePolicy("default", clusterInfos, null);
verify(policyGenerator.getPolicy(), times(1))
.updatePolicy("default2", clusterInfos, null);
}
@Test
public void testBlacklist() {
conf.set(YarnConfiguration.GPG_POLICY_GENERATOR_BLACKLIST,
subClusterIds.get(0).toString());
Map<SubClusterId, Map<Class, Object>> blacklistedCMI =
new HashMap<>(clusterInfos);
blacklistedCMI.remove(subClusterIds.get(0));
policyGenerator = new TestablePolicyGenerator();
policyGenerator.setPolicy(mock(GlobalPolicy.class));
policyGenerator.run();
verify(policyGenerator.getPolicy(), times(1))
.updatePolicy("default", blacklistedCMI, null);
verify(policyGenerator.getPolicy(), times(0))
.updatePolicy("default", clusterInfos, null);
}
@Test
public void testBlacklistTwo() {
conf.set(YarnConfiguration.GPG_POLICY_GENERATOR_BLACKLIST,
subClusterIds.get(0).toString() + "," + subClusterIds.get(1)
.toString());
Map<SubClusterId, Map<Class, Object>> blacklistedCMI =
new HashMap<>(clusterInfos);
blacklistedCMI.remove(subClusterIds.get(0));
blacklistedCMI.remove(subClusterIds.get(1));
policyGenerator = new TestablePolicyGenerator();
policyGenerator.setPolicy(mock(GlobalPolicy.class));
policyGenerator.run();
verify(policyGenerator.getPolicy(), times(1))
.updatePolicy("default", blacklistedCMI, null);
verify(policyGenerator.getPolicy(), times(0))
.updatePolicy("default", clusterInfos, null);
}
@Test
public void testExistingPolicy() throws YarnException {
WeightedLocalityPolicyManager manager = new WeightedLocalityPolicyManager();
// Add a test policy for test queue
manager.setQueue("default");
manager.getWeightedPolicyInfo().setAMRMPolicyWeights(GPGUtils
.createUniformWeights(new HashSet<>(subClusterIds)));
manager.getWeightedPolicyInfo().setRouterPolicyWeights(GPGUtils
.createUniformWeights(new HashSet<>(subClusterIds)));
SubClusterPolicyConfiguration testConf = manager.serializeConf();
when(stateStore.getPolicyConfiguration(
GetSubClusterPolicyConfigurationRequest.newInstance("default")))
.thenReturn(
GetSubClusterPolicyConfigurationResponse.newInstance(testConf));
policyGenerator = new TestablePolicyGenerator();
policyGenerator.setPolicy(mock(GlobalPolicy.class));
policyGenerator.run();
ArgumentCaptor<FederationPolicyManager> argCaptor =
ArgumentCaptor.forClass(FederationPolicyManager.class);
verify(policyGenerator.getPolicy(), times(1))
.updatePolicy(eq("default"), eq(clusterInfos), argCaptor.capture());
argCaptor.getValue().setWeightedPolicyInfo(manager.getWeightedPolicyInfo());
assertEquals(argCaptor.getValue().getClass(), manager.getClass());
assertEquals(argCaptor.getValue().serializeConf(), manager.serializeConf());
}
@Test
public void testCallRM() {
CapacitySchedulerConfiguration csConf =
new CapacitySchedulerConfiguration();
final QueuePath a = new QueuePath(CapacitySchedulerConfiguration.ROOT + ".a");
final QueuePath b = new QueuePath(CapacitySchedulerConfiguration.ROOT + ".b");
final QueuePath a1 = new QueuePath(a + ".a1");
final QueuePath a2 = new QueuePath(a + ".a2");
final QueuePath b1 = new QueuePath(b + ".b1");
final QueuePath b2 = new QueuePath(b + ".b2");
final QueuePath b3 = new QueuePath(b + ".b3");
float aCapacity = 10.5f;
float bCapacity = 89.5f;
float a1Capacity = 30;
float a2Capacity = 70;
float b1Capacity = 79.2f;
float b2Capacity = 0.8f;
float b3Capacity = 20;
// Define top-level queues
csConf.setQueues(new QueuePath(CapacitySchedulerConfiguration.ROOT),
new String[] {"a", "b"});
csConf.setCapacity(a, aCapacity);
csConf.setCapacity(b, bCapacity);
// Define 2nd-level queues
csConf.setQueues(a, new String[] {"a1", "a2"});
csConf.setCapacity(a1, a1Capacity);
csConf.setUserLimitFactor(a1, 100.0f);
csConf.setCapacity(a2, a2Capacity);
csConf.setUserLimitFactor(a2, 100.0f);
csConf.setQueues(b, new String[] {"b1", "b2", "b3"});
csConf.setCapacity(b1, b1Capacity);
csConf.setUserLimitFactor(b1, 100.0f);
csConf.setCapacity(b2, b2Capacity);
csConf.setUserLimitFactor(b2, 100.0f);
csConf.setCapacity(b3, b3Capacity);
csConf.setUserLimitFactor(b3, 100.0f);
YarnConfiguration rmConf = new YarnConfiguration(csConf);
ResourceManager resourceManager = new ResourceManager();
rmConf.setClass(YarnConfiguration.RM_SCHEDULER, CapacityScheduler.class,
ResourceScheduler.class);
resourceManager.init(rmConf);
resourceManager.start();
String rmAddress = WebAppUtils.getRMWebAppURLWithScheme(this.conf);
String webAppAddress = getServiceAddress(NetUtils.createSocketAddr(rmAddress));
SchedulerTypeInfo sti = GPGUtils.invokeRMWebService(webAppAddress, RMWSConsts.SCHEDULER,
SchedulerTypeInfo.class, conf);
Assertions.assertNotNull(sti);
SchedulerInfo schedulerInfo = sti.getSchedulerInfo();
Assertions.assertTrue(schedulerInfo instanceof CapacitySchedulerInfo);
CapacitySchedulerInfo capacitySchedulerInfo = (CapacitySchedulerInfo) schedulerInfo;
Assertions.assertNotNull(capacitySchedulerInfo);
CapacitySchedulerQueueInfoList queues = capacitySchedulerInfo.getQueues();
Assertions.assertNotNull(queues);
ArrayList<CapacitySchedulerQueueInfo> queueInfoList = queues.getQueueInfoList();
Assertions.assertNotNull(queueInfoList);
Assertions.assertEquals(2, queueInfoList.size());
CapacitySchedulerQueueInfo queueA = queueInfoList.get(0);
Assertions.assertNotNull(queueA);
Assertions.assertEquals("root.a", queueA.getQueuePath());
Assertions.assertEquals(10.5f, queueA.getCapacity(), 0.00001);
CapacitySchedulerQueueInfoList queueAQueues = queueA.getQueues();
Assertions.assertNotNull(queueAQueues);
ArrayList<CapacitySchedulerQueueInfo> queueInfoAList = queueAQueues.getQueueInfoList();
Assertions.assertNotNull(queueInfoAList);
Assertions.assertEquals(2, queueInfoAList.size());
CapacitySchedulerQueueInfo queueA1 = queueInfoAList.get(0);
Assertions.assertNotNull(queueA1);
Assertions.assertEquals(30f, queueA1.getCapacity(), 0.00001);
CapacitySchedulerQueueInfo queueA2 = queueInfoAList.get(1);
Assertions.assertNotNull(queueA2);
Assertions.assertEquals(70f, queueA2.getCapacity(), 0.00001);
CapacitySchedulerQueueInfo queueB = queueInfoList.get(1);
Assertions.assertNotNull(queueB);
Assertions.assertEquals("root.b", queueB.getQueuePath());
Assertions.assertEquals(89.5f, queueB.getCapacity(), 0.00001);
CapacitySchedulerQueueInfoList queueBQueues = queueB.getQueues();
Assertions.assertNotNull(queueBQueues);
ArrayList<CapacitySchedulerQueueInfo> queueInfoBList = queueBQueues.getQueueInfoList();
Assertions.assertNotNull(queueInfoBList);
Assertions.assertEquals(3, queueInfoBList.size());
CapacitySchedulerQueueInfo queueB1 = queueInfoBList.get(0);
Assertions.assertNotNull(queueB1);
Assertions.assertEquals(79.2f, queueB1.getCapacity(), 0.00001);
CapacitySchedulerQueueInfo queueB2 = queueInfoBList.get(1);
Assertions.assertNotNull(queueB2);
Assertions.assertEquals(0.8f, queueB2.getCapacity(), 0.00001);
CapacitySchedulerQueueInfo queueB3 = queueInfoBList.get(2);
Assertions.assertNotNull(queueB3);
Assertions.assertEquals(20f, queueB3.getCapacity(), 0.00001);
}
private String getServiceAddress(InetSocketAddress address) {
InetSocketAddress socketAddress = NetUtils.getConnectAddress(address);
return socketAddress.getAddress().getHostAddress() + ":" + socketAddress.getPort();
}
/**
* Testable policy generator overrides the methods that communicate
* with the RM REST endpoint, allowing us to inject faked responses.
*/
class TestablePolicyGenerator extends PolicyGenerator {
TestablePolicyGenerator() {
super(conf, gpgContext);
}
@Override
protected Map<SubClusterId, Map<Class, Object>> getInfos(
Map<SubClusterId, SubClusterInfo> activeSubClusters) {
Map<SubClusterId, Map<Class, Object>> ret = new HashMap<>();
for (SubClusterId id : activeSubClusters.keySet()) {
if (!ret.containsKey(id)) {
ret.put(id, new HashMap<>());
}
ret.get(id).put(ClusterMetricsInfo.class,
clusterInfos.get(id).get(ClusterMetricsInfo.class));
}
return ret;
}
@Override
protected Map<SubClusterId, SchedulerInfo> getSchedulerInfo(
Map<SubClusterId, SubClusterInfo> activeSubClusters) {
Map<SubClusterId, SchedulerInfo> ret = new HashMap<>();
for (SubClusterId id : activeSubClusters.keySet()) {
ret.put(id, schedulerInfos.get(id));
}
return ret;
}
}
}