TestCsiAdaptorService.java
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.yarn.csi.adaptor;
import csi.v0.Csi;
import org.apache.commons.io.FileUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.service.ServiceStateException;
import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableList;
import org.apache.hadoop.thirdparty.com.google.common.collect.ImmutableMap;
import org.apache.hadoop.yarn.api.CsiAdaptorProtocol;
import org.apache.hadoop.yarn.api.CsiAdaptorPlugin;
import org.apache.hadoop.yarn.api.impl.pb.client.CsiAdaptorProtocolPBClientImpl;
import org.apache.hadoop.yarn.api.protocolrecords.GetPluginInfoRequest;
import org.apache.hadoop.yarn.api.protocolrecords.GetPluginInfoResponse;
import org.apache.hadoop.yarn.api.protocolrecords.NodePublishVolumeRequest;
import org.apache.hadoop.yarn.api.protocolrecords.NodePublishVolumeResponse;
import org.apache.hadoop.yarn.api.protocolrecords.NodeUnpublishVolumeRequest;
import org.apache.hadoop.yarn.api.protocolrecords.NodeUnpublishVolumeResponse;
import org.apache.hadoop.yarn.api.protocolrecords.ValidateVolumeCapabilitiesRequest;
import org.apache.hadoop.yarn.api.protocolrecords.ValidateVolumeCapabilitiesResponse;
import org.apache.hadoop.yarn.api.protocolrecords.impl.pb.ValidateVolumeCapabilitiesRequestPBImpl;
import org.apache.hadoop.yarn.client.NMProxy;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.exceptions.YarnException;
import org.apache.hadoop.yarn.ipc.YarnRPC;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.apache.hadoop.yarn.api.protocolrecords.ValidateVolumeCapabilitiesRequest.AccessMode.MULTI_NODE_MULTI_WRITER;
import static org.apache.hadoop.yarn.api.protocolrecords.ValidateVolumeCapabilitiesRequest.VolumeType.FILE_SYSTEM;
/**
* Unit tests for {@link CsiAdaptorProtocolService} and {@link CsiAdaptorServices}.
*/
public class TestCsiAdaptorService {
/** Scratch directory for this test class; assigned in setUp, deleted in tearDown. */
private static File testRoot;
/** "unix://" URI of the csi.sock path under testRoot; assigned in setUp. */
private static String domainSocket;
/**
 * Resolves the test directory, makes sure the directory that will hold the
 * socket file exists, and records the socket location as a "unix://" URI.
 *
 * @throws IOException propagated from FileUtils.forceMkdirParent
 */
@BeforeAll
public static void setUp() throws IOException {
  testRoot = GenericTestUtils.getTestDir("csi-test");
  File sockFile = new File(testRoot, "csi.sock");
  // Only the parent directory is created; the socket file itself is not.
  FileUtils.forceMkdirParent(sockFile);
  String sockPath = sockFile.getAbsolutePath();
  domainSocket = "unix://" + sockPath;
}
/**
 * Removes the test directory once all tests have run.
 *
 * @throws IOException propagated from FileUtils.deleteDirectory
 */
@AfterAll
public static void tearDown() throws IOException {
  // Nothing to clean up if setUp never assigned the directory.
  if (testRoot == null) {
    return;
  }
  FileUtils.deleteDirectory(testRoot);
}
/**
 * Stub {@link CsiAdaptorPlugin} for these tests. Every method has a default
 * that does nothing or returns {@code null}, so a test only overrides the
 * calls it cares about.
 */
private interface FakeCsiAdaptor extends CsiAdaptorPlugin {

  /** No-op initialization. */
  default void init(String driverName, Configuration conf)
      throws YarnException {
  }

  /** @return {@code null} unless overridden */
  default String getDriverName() {
    return null;
  }

  /** @return {@code null} unless overridden */
  default GetPluginInfoResponse getPluginInfo(GetPluginInfoRequest request)
      throws YarnException, IOException {
    return null;
  }

  /** @return {@code null} unless overridden */
  default ValidateVolumeCapabilitiesResponse validateVolumeCapacity(
      ValidateVolumeCapabilitiesRequest request)
      throws YarnException, IOException {
    return null;
  }

  /** @return {@code null} unless overridden */
  default NodePublishVolumeResponse nodePublishVolume(
      NodePublishVolumeRequest request)
      throws YarnException, IOException {
    return null;
  }

  /** @return {@code null} unless overridden */
  default NodeUnpublishVolumeResponse nodeUnpublishVolume(
      NodeUnpublishVolumeRequest request)
      throws YarnException, IOException {
    return null;
  }
}
/**
 * Sends a ValidateVolumeCapabilitiesRequest through
 * {@link CsiAdaptorProtocolPBClientImpl} to a service backed by a fake
 * plugin, and checks that the request reaches the plugin with all fields
 * and that the plugin's canned response reaches the client.
 */
@Test
void testValidateVolume() throws IOException, YarnException {
  // Probe for a free port: bind to an ephemeral port, then release it.
  ServerSocket portProbe = new ServerSocket(0);
  portProbe.close();
  int adaptorPort = portProbe.getLocalPort();
  InetSocketAddress adaptorAddress = new InetSocketAddress(adaptorPort);

  String addressKey =
      YarnConfiguration.NM_CSI_ADAPTOR_PREFIX + "test-driver.address";
  String endpointKey =
      YarnConfiguration.NM_CSI_DRIVER_PREFIX + "test-driver.endpoint";
  Configuration conf = new Configuration();
  conf.setSocketAddr(addressKey, adaptorAddress);
  conf.set(endpointKey, "unix:///tmp/test-driver.sock");

  // Inject a fake CSI adaptor: it verifies that the
  // ValidateVolumeCapabilitiesRequest arrives with every field set by the
  // client, then replies with a fake response.
  CsiAdaptorPlugin plugin = new FakeCsiAdaptor() {
    @Override
    public String getDriverName() {
      return "test-driver";
    }

    @Override
    public GetPluginInfoResponse getPluginInfo(GetPluginInfoRequest request) {
      return GetPluginInfoResponse.newInstance("test-plugin", "0.1");
    }

    @Override
    public ValidateVolumeCapabilitiesResponse validateVolumeCapacity(
        ValidateVolumeCapabilitiesRequest request)
        throws YarnException, IOException {
      // Every field sent by the client must be present.
      assertEquals("volume-id-0000123", request.getVolumeId());
      assertEquals(1, request.getVolumeCapabilities().size());

      // Mode value 5 must resolve to the same name as the
      // MULTI_NODE_MULTI_WRITER access mode the client put in the request.
      String expectedMode = Csi.VolumeCapability.AccessMode.newBuilder()
          .setModeValue(5).build().getMode().name();
      ValidateVolumeCapabilitiesRequest.VolumeCapability capability =
          request.getVolumeCapabilities().get(0);
      assertEquals(expectedMode, capability.getAccessMode().name());

      assertEquals(2, capability.getMountFlags().size());
      assertTrue(capability.getMountFlags().contains("mountFlag1"));
      assertTrue(capability.getMountFlags().contains("mountFlag2"));

      assertEquals(2, request.getVolumeAttributes().size());
      assertEquals("v1", request.getVolumeAttributes().get("k1"));
      assertEquals("v2", request.getVolumeAttributes().get("k2"));

      // Canned reply checked by the client side below.
      return ValidateVolumeCapabilitiesResponse
          .newInstance(false, "this is a test");
    }
  };

  CsiAdaptorProtocolService service = new CsiAdaptorProtocolService(plugin);
  service.init(conf);
  service.start();

  try (CsiAdaptorProtocolPBClientImpl client =
      new CsiAdaptorProtocolPBClientImpl(
          1L, adaptorAddress, new Configuration())) {
    ValidateVolumeCapabilitiesRequest.VolumeCapability requested =
        new ValidateVolumeCapabilitiesRequest.VolumeCapability(
            MULTI_NODE_MULTI_WRITER, FILE_SYSTEM,
            ImmutableList.of("mountFlag1", "mountFlag2"));
    ValidateVolumeCapabilitiesRequest request =
        ValidateVolumeCapabilitiesRequestPBImpl.newInstance(
            "volume-id-0000123",
            ImmutableList.of(requested),
            ImmutableMap.of("k1", "v1", "k2", "v2"));

    ValidateVolumeCapabilitiesResponse response =
        client.validateVolumeCapacity(request);

    assertEquals(false, response.isSupported());
    assertEquals("this is a test", response.getResponseMessage());
  } finally {
    service.stop();
  }
}
/**
 * Same round trip as testValidateVolume, but the client is obtained through
 * {@link NMProxy#createNMProxy} instead of instantiating the PB client
 * directly.
 *
 * <p>The service is stopped in a {@code finally} block so that a failed
 * assertion or RPC error does not leave it running for later tests.
 */
@Test
void testValidateVolumeWithNMProxy() throws Exception {
  // Probe for a free port; try-with-resources guarantees the probe socket
  // is released before the service binds to the port.
  int adaptorPort;
  try (ServerSocket portProbe = new ServerSocket(0)) {
    adaptorPort = portProbe.getLocalPort();
  }
  InetSocketAddress address = new InetSocketAddress(adaptorPort);

  Configuration conf = new Configuration();
  conf.setSocketAddr(
      YarnConfiguration.NM_CSI_ADAPTOR_PREFIX + "test-driver.address",
      address);
  conf.set(
      YarnConfiguration.NM_CSI_DRIVER_PREFIX + "test-driver.endpoint",
      "unix:///tmp/test-driver.sock");

  // Inject a fake CSI adaptor: it verifies that the
  // ValidateVolumeCapabilitiesRequest arrives with every field set by the
  // client, then replies with a fake response.
  FakeCsiAdaptor plugin = new FakeCsiAdaptor() {
    @Override
    public String getDriverName() {
      return "test-driver";
    }

    @Override
    public GetPluginInfoResponse getPluginInfo(
        GetPluginInfoRequest request) throws YarnException, IOException {
      return GetPluginInfoResponse.newInstance("test-plugin", "0.1");
    }

    @Override
    public ValidateVolumeCapabilitiesResponse validateVolumeCapacity(
        ValidateVolumeCapabilitiesRequest request)
        throws YarnException, IOException {
      // validate we get all info from the request
      assertEquals("volume-id-0000123", request.getVolumeId());
      assertEquals(1, request.getVolumeCapabilities().size());
      // Mode value 5 must resolve to the same name as the
      // MULTI_NODE_MULTI_WRITER access mode the client put in the request.
      assertEquals(
          Csi.VolumeCapability.AccessMode.newBuilder().setModeValue(5)
              .build().getMode().name(),
          request.getVolumeCapabilities().get(0).getAccessMode().name());
      assertEquals(2,
          request.getVolumeCapabilities().get(0).getMountFlags().size());
      assertTrue(request.getVolumeCapabilities().get(0).getMountFlags()
          .contains("mountFlag1"));
      assertTrue(request.getVolumeCapabilities().get(0).getMountFlags()
          .contains("mountFlag2"));
      assertEquals(2, request.getVolumeAttributes().size());
      assertEquals("v1", request.getVolumeAttributes().get("k1"));
      assertEquals("v2", request.getVolumeAttributes().get("k2"));
      // return a fake result
      return ValidateVolumeCapabilitiesResponse
          .newInstance(false, "this is a test");
    }
  };

  CsiAdaptorProtocolService service =
      new CsiAdaptorProtocolService(plugin);
  service.init(conf);
  service.start();
  try {
    YarnRPC rpc = YarnRPC.create(conf);
    UserGroupInformation currentUser = UserGroupInformation.getCurrentUser();
    CsiAdaptorProtocol adaptorClient = NMProxy
        .createNMProxy(conf, CsiAdaptorProtocol.class, currentUser, rpc,
            NetUtils.createSocketAddrForHost("localhost", adaptorPort));

    ValidateVolumeCapabilitiesRequest request =
        ValidateVolumeCapabilitiesRequestPBImpl
            .newInstance("volume-id-0000123",
                ImmutableList.of(new ValidateVolumeCapabilitiesRequest
                    .VolumeCapability(
                        MULTI_NODE_MULTI_WRITER, FILE_SYSTEM,
                        ImmutableList.of("mountFlag1", "mountFlag2"))),
                ImmutableMap.of("k1", "v1", "k2", "v2"));

    ValidateVolumeCapabilitiesResponse response = adaptorClient
        .validateVolumeCapacity(request);

    assertEquals(false, response.isSupported());
    assertEquals("this is a test", response.getResponseMessage());
  } finally {
    // Always stop the service, even when an assertion above fails.
    service.stop();
  }
}
/**
 * Initializing the service with a configuration that sets nothing for the
 * adaptor is expected to fail with {@link ServiceStateException}.
 */
@Test
void testMissingConfiguration() {
  assertThrows(ServiceStateException.class, () -> {
    Configuration emptyConf = new Configuration();
    CsiAdaptorPlugin fakePlugin = new FakeCsiAdaptor() {
    };
    new CsiAdaptorProtocolService(fakePlugin).init(emptyConf);
  });
}
/**
 * Initializing the service with a negative port in the adaptor address is
 * expected to fail with {@link ServiceStateException}.
 */
@Test
void testInvalidServicePort() {
  assertThrows(ServiceStateException.class, () -> {
    String addressKey = YarnConfiguration.NM_CSI_ADAPTOR_PREFIX
        + "test-driver-0001.address";
    Configuration conf = new Configuration();
    // The port -100 makes this an invalid address.
    conf.set(addressKey, "0.0.0.0:-100");
    // NOTE(review): this fake keeps the default getDriverName() (null), while
    // the key above names "test-driver-0001" - confirm the failure really
    // comes from the invalid port and not from a driver-name mismatch.
    CsiAdaptorPlugin fakePlugin = new FakeCsiAdaptor() {
    };
    new CsiAdaptorProtocolService(fakePlugin).init(conf);
  });
}
/**
 * Initializing the service with a malformed host in the adaptor address is
 * expected to fail with {@link ServiceStateException}.
 */
@Test
void testInvalidHost() {
  assertThrows(ServiceStateException.class, () -> {
    String addressKey = YarnConfiguration.NM_CSI_ADAPTOR_PREFIX
        + "test-driver-0001.address";
    Configuration conf = new Configuration();
    // "192.0.1" is intended to be an invalid IP address.
    conf.set(addressKey, "192.0.1:8999");
    // NOTE(review): this fake keeps the default getDriverName() (null), while
    // the key above names "test-driver-0001" - confirm the failure really
    // comes from the invalid host and not from a driver-name mismatch.
    CsiAdaptorPlugin fakePlugin = new FakeCsiAdaptor() {
    };
    new CsiAdaptorProtocolService(fakePlugin).init(conf);
  });
}
/**
 * Configures a single driver whose adaptor class is MockCsiAdaptor, starts
 * {@link CsiAdaptorServices}, and checks getPluginInfo and
 * validateVolumeCapacity through an NMProxy client.
 *
 * <p>The services are stopped in a {@code finally} block so that a failed
 * assertion or RPC error does not leave them running for later tests.
 */
@Test
void testCustomizedAdaptor() throws IOException, YarnException {
  // Probe for a free port; try-with-resources guarantees the probe socket
  // is released before the adaptor service binds to the port.
  int adaptorPort;
  try (ServerSocket portProbe = new ServerSocket(0)) {
    adaptorPort = portProbe.getLocalPort();
  }
  InetSocketAddress address = new InetSocketAddress(adaptorPort);

  Configuration conf = new Configuration();
  conf.set(YarnConfiguration.NM_CSI_DRIVER_NAMES, "customized-driver");
  conf.setSocketAddr(
      YarnConfiguration.NM_CSI_ADAPTOR_PREFIX + "customized-driver.address",
      address);
  conf.set(
      YarnConfiguration.NM_CSI_ADAPTOR_PREFIX + "customized-driver.class",
      "org.apache.hadoop.yarn.csi.adaptor.MockCsiAdaptor");
  conf.set(
      YarnConfiguration.NM_CSI_DRIVER_PREFIX + "customized-driver.endpoint",
      "unix:///tmp/customized-driver.sock");

  CsiAdaptorServices services = new CsiAdaptorServices();
  services.init(conf);
  services.start();
  try {
    YarnRPC rpc = YarnRPC.create(conf);
    UserGroupInformation currentUser = UserGroupInformation.getCurrentUser();
    CsiAdaptorProtocol adaptorClient = NMProxy
        .createNMProxy(conf, CsiAdaptorProtocol.class, currentUser, rpc,
            NetUtils.createSocketAddrForHost("localhost", adaptorPort));

    // Test getPluginInfo
    GetPluginInfoResponse pluginInfo =
        adaptorClient.getPluginInfo(GetPluginInfoRequest.newInstance());
    assertThat(pluginInfo.getDriverName()).isEqualTo("customized-driver");
    assertThat(pluginInfo.getVersion()).isEqualTo("1.0");

    // Test validateVolumeCapacity
    ValidateVolumeCapabilitiesRequest request =
        ValidateVolumeCapabilitiesRequestPBImpl
            .newInstance("volume-id-0000123",
                ImmutableList.of(new ValidateVolumeCapabilitiesRequest
                    .VolumeCapability(
                        MULTI_NODE_MULTI_WRITER, FILE_SYSTEM,
                        ImmutableList.of("mountFlag1", "mountFlag2"))),
                ImmutableMap.of("k1", "v1", "k2", "v2"));

    ValidateVolumeCapabilitiesResponse response = adaptorClient
        .validateVolumeCapacity(request);

    assertTrue(response.isSupported());
    assertEquals("verified via MockCsiAdaptor",
        response.getResponseMessage());
  } finally {
    // Always stop the services, even when an assertion above fails.
    services.stop();
  }
}
/**
 * Configures two drivers, each with MockCsiAdaptor on its own port, starts
 * {@link CsiAdaptorServices}, and checks that a client connected to each
 * port talks to the matching driver.
 *
 * <p>The probe sockets are closed via try-with-resources and the services
 * are stopped in a {@code finally} block, so nothing is leaked when the
 * test fails part-way.
 */
@Test
void testMultipleCsiAdaptors() throws IOException, YarnException {
  // Hold both probe sockets open at the same time so the two ephemeral
  // ports are guaranteed to differ, then release them together.
  int port1;
  int port2;
  try (ServerSocket driver1Probe = new ServerSocket(0);
       ServerSocket driver2Probe = new ServerSocket(0)) {
    port1 = driver1Probe.getLocalPort();
    port2 = driver2Probe.getLocalPort();
  }
  InetSocketAddress address1 = new InetSocketAddress(port1);
  InetSocketAddress address2 = new InetSocketAddress(port2);

  Configuration conf = new Configuration();

  // Two csi-drivers configured
  conf.set(YarnConfiguration.NM_CSI_DRIVER_NAMES,
      "customized-driver-1,customized-driver-2");

  // customized-driver-1
  conf.setSocketAddr(YarnConfiguration.NM_CSI_ADAPTOR_PREFIX
      + "customized-driver-1.address", address1);
  conf.set(YarnConfiguration.NM_CSI_ADAPTOR_PREFIX
          + "customized-driver-1.class",
      "org.apache.hadoop.yarn.csi.adaptor.MockCsiAdaptor");
  conf.set(YarnConfiguration.NM_CSI_DRIVER_PREFIX
          + "customized-driver-1.endpoint",
      "unix:///tmp/customized-driver-1.sock");

  // customized-driver-2
  conf.setSocketAddr(YarnConfiguration.NM_CSI_ADAPTOR_PREFIX
      + "customized-driver-2.address", address2);
  conf.set(YarnConfiguration.NM_CSI_ADAPTOR_PREFIX
          + "customized-driver-2.class",
      "org.apache.hadoop.yarn.csi.adaptor.MockCsiAdaptor");
  conf.set(YarnConfiguration.NM_CSI_DRIVER_PREFIX
          + "customized-driver-2.endpoint",
      "unix:///tmp/customized-driver-2.sock");

  CsiAdaptorServices services = new CsiAdaptorServices();
  services.init(conf);
  services.start();
  try {
    YarnRPC rpc = YarnRPC.create(conf);
    UserGroupInformation currentUser = UserGroupInformation.getCurrentUser();

    // ***************************************************
    // Verify talking with customized-driver-1
    // ***************************************************
    CsiAdaptorProtocol client1 = NMProxy
        .createNMProxy(conf, CsiAdaptorProtocol.class, currentUser, rpc,
            NetUtils.createSocketAddrForHost("localhost", port1));

    // Test getPluginInfo
    GetPluginInfoResponse pluginInfo =
        client1.getPluginInfo(GetPluginInfoRequest.newInstance());
    assertThat(pluginInfo.getDriverName()).isEqualTo("customized-driver-1");
    assertThat(pluginInfo.getVersion()).isEqualTo("1.0");

    // Test validateVolumeCapacity
    ValidateVolumeCapabilitiesRequest request =
        ValidateVolumeCapabilitiesRequestPBImpl
            .newInstance("driver-1-volume-00001",
                ImmutableList.of(new ValidateVolumeCapabilitiesRequest
                    .VolumeCapability(
                        MULTI_NODE_MULTI_WRITER, FILE_SYSTEM,
                        ImmutableList.of())), ImmutableMap.of());

    ValidateVolumeCapabilitiesResponse response = client1
        .validateVolumeCapacity(request);

    assertTrue(response.isSupported());
    assertEquals("verified via MockCsiAdaptor",
        response.getResponseMessage());

    // ***************************************************
    // Verify talking with customized-driver-2
    // ***************************************************
    CsiAdaptorProtocol client2 = NMProxy
        .createNMProxy(conf, CsiAdaptorProtocol.class, currentUser, rpc,
            NetUtils.createSocketAddrForHost("localhost", port2));
    GetPluginInfoResponse pluginInfo2 =
        client2.getPluginInfo(GetPluginInfoRequest.newInstance());
    assertThat(pluginInfo2.getDriverName()).isEqualTo("customized-driver-2");
    assertThat(pluginInfo2.getVersion()).isEqualTo("1.0");
  } finally {
    // Always stop the services, even when an assertion above fails.
    services.stop();
  }
}
}