TestRaceWhenRelogin.java
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security;
import static org.junit.Assert.assertTrue;
import java.io.File;
import java.io.IOException;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.security.auth.kerberos.KerberosTicket;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.minikdc.KerberosSecurityTestcase;
import org.apache.hadoop.security.SaslRpcServer.AuthMethod;
import org.apache.hadoop.security.SaslRpcServer.QualityOfProtection;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.junit.Before;
import org.junit.Test;
/**
* Testcase for HADOOP-13433 that confirms that tgt will always be the first
* ticket after relogin.
*/
public class TestRaceWhenRelogin extends KerberosSecurityTestcase {
private int numThreads = 10;
private String clientPrincipal = "client";
private String serverProtocol = "server";
private String[] serverProtocols;
private String host = "localhost";
private String serverPrincipal = serverProtocol + "/" + host;
private String[] serverPrincipals;
private File keytabFile;
private Configuration conf = new Configuration();
private Map<String, String> props;
private UserGroupInformation ugi;
@Before
public void setUp() throws Exception {
keytabFile = new File(getWorkDir(), "keytab");
serverProtocols = new String[numThreads];
serverPrincipals = new String[numThreads];
for (int i = 0; i < numThreads; i++) {
serverProtocols[i] = serverProtocol + i;
serverPrincipals[i] = serverProtocols[i] + "/" + host;
}
String[] principals =
Arrays.copyOf(serverPrincipals, serverPrincipals.length + 2);
principals[numThreads] = serverPrincipal;
principals[numThreads + 1] = clientPrincipal;
getKdc().createPrincipal(keytabFile, principals);
SecurityUtil.setAuthenticationMethod(AuthenticationMethod.KERBEROS, conf);
UserGroupInformation.setConfiguration(conf);
UserGroupInformation.setShouldRenewImmediatelyForTests(true);
props = new HashMap<String, String>();
props.put(Sasl.QOP, QualityOfProtection.AUTHENTICATION.saslQop);
ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(clientPrincipal,
keytabFile.getAbsolutePath());
}
private void relogin(AtomicBoolean pass) {
for (int i = 0; i < 100; i++) {
try {
ugi.reloginFromKeytab();
} catch (IOException e) {
}
KerberosTicket tgt = ugi.getSubject().getPrivateCredentials().stream()
.filter(c -> c instanceof KerberosTicket).map(c -> (KerberosTicket) c)
.findFirst().get();
if (!tgt.getServer().getName().startsWith("krbtgt")) {
pass.set(false);
return;
}
try {
Thread.sleep(50);
} catch (InterruptedException e) {
}
}
}
private void getServiceTicket(AtomicBoolean running, String serverProtocol) {
while (running.get()) {
try {
ugi.doAs(new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws Exception {
SaslClient client = Sasl.createSaslClient(
new String[] {AuthMethod.KERBEROS.getMechanismName()},
clientPrincipal, serverProtocol, host, props, null);
client.evaluateChallenge(new byte[0]);
client.dispose();
return null;
}
});
} catch (Exception e) {
}
try {
Thread.sleep(ThreadLocalRandom.current().nextInt(100));
} catch (InterruptedException e) {
}
}
}
@Test
public void test() throws InterruptedException, IOException {
AtomicBoolean pass = new AtomicBoolean(true);
Thread reloginThread = new Thread(() -> relogin(pass), "Relogin");
AtomicBoolean running = new AtomicBoolean(true);
Thread[] getServiceTicketThreads = new Thread[numThreads];
for (int i = 0; i < numThreads; i++) {
String serverProtocol = serverProtocols[i];
getServiceTicketThreads[i] =
new Thread(() -> getServiceTicket(running, serverProtocol),
"GetServiceTicket-" + i);
}
for (Thread getServiceTicketThread : getServiceTicketThreads) {
getServiceTicketThread.start();
}
reloginThread.start();
reloginThread.join();
running.set(false);
for (Thread getServiceTicketThread : getServiceTicketThreads) {
getServiceTicketThread.join();
}
assertTrue("tgt is not the first ticket after relogin", pass.get());
}
}