DnsSrvHostProviderTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zookeeper.client;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import java.net.InetSocketAddress;
import java.util.HashSet;
import java.util.Set;
import org.apache.zookeeper.client.DnsSrvHostProvider.DnsSrvResolver;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.xbill.DNS.Name;
import org.xbill.DNS.SRVRecord;
public class DnsSrvHostProviderTest {
private static final String TEST_DNS_NAME = "_zookeeper._tcp.example.com.";
private static final long TEST_SEED = 12345L;
private DnsSrvResolver mockDnsSrvResolver;
@BeforeEach
public void setUp() {
mockDnsSrvResolver = mock(DnsSrvResolver.class);
}
@AfterEach
public void tearDown() {
System.clearProperty(ZKClientConfig.DNS_SRV_REFRESH_INTERVAL_SECONDS);
}
@Test
public void testBasic() throws Exception {
final SRVRecord[] srvRecords = createMockSrvRecords();
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(srvRecords);
try (final DnsSrvHostProvider hostProvider = new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null)) {
assertEquals(3, hostProvider.size());
assertNotNull(hostProvider.next(0));
}
}
@Test
public void testServerIteration() throws Exception {
final SRVRecord[] srvRecords = createMockSrvRecords();
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(srvRecords);
try (final DnsSrvHostProvider hostProvider = new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null)) {
final InetSocketAddress addr1 = hostProvider.next(0);
final InetSocketAddress addr2 = hostProvider.next(0);
final InetSocketAddress addr3 = hostProvider.next(0);
final Set<InetSocketAddress> actualAddresses = new HashSet<>();
actualAddresses.add(addr1);
actualAddresses.add(addr2);
actualAddresses.add(addr3);
final Set<InetSocketAddress> expectedAddresses = new HashSet<>();
expectedAddresses.add(new InetSocketAddress("server1.example.com", 2181));
expectedAddresses.add(new InetSocketAddress("server2.example.com", 2181));
expectedAddresses.add(new InetSocketAddress("server3.example.com", 2181));
assertEquals(expectedAddresses, actualAddresses);
// cycle back
final InetSocketAddress addr4 = hostProvider.next(0);
assertTrue(expectedAddresses.contains(addr4));
}
}
@Test
public void testEmptyDnsName() {
assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider("", TEST_SEED, mockDnsSrvResolver, null));
assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider(null, TEST_SEED, mockDnsSrvResolver, null));
assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider(" ", TEST_SEED, mockDnsSrvResolver, null));
}
@Test
public void testNoSrvRecords() throws Exception {
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(new SRVRecord[0]);
assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null));
}
@Test
public void testDnsLookupFailure() throws Exception {
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME))
.thenThrow(new java.io.IOException("DNS lookup failed"));
assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null));
}
@Test
public void testInvalidPortFiltering() throws Exception {
// Create SRV record with invalid port (0)
final SRVRecord invalidPortRecord = createMockSrvRecord("server1.example.com.", 0);
final SRVRecord[] srvRecords = new SRVRecord[]{invalidPortRecord};
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(srvRecords);
assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null));
}
@Test
public void testTrailingDotRemoval() throws Exception {
final SRVRecord recordWithDot = createMockSrvRecord("server1.example.com.", 2181);
final SRVRecord[] srvRecords = new SRVRecord[]{recordWithDot};
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(srvRecords);
try (final DnsSrvHostProvider hostProvider = new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null)) {
assertEquals(1, hostProvider.size());
final InetSocketAddress addr = hostProvider.next(0);
// validate trailing dot is removed
assertEquals("server1.example.com", addr.getHostString());
}
}
@Test
public void testRefreshIntervalZeroDisablesPeriodicRefresh() throws Exception {
// Set system property to disable refresh
System.setProperty(ZKClientConfig.DNS_SRV_REFRESH_INTERVAL_SECONDS, "0");
final SRVRecord[] srvRecords = createMockSrvRecords();
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(srvRecords);
try (final DnsSrvHostProvider hostProvider = new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null)) {
// Verify initial setup works
assertEquals(3, hostProvider.size());
// Wait to ensure no background refresh occurs
Thread.sleep(1000);
// Verify DNS resolver was only called once during initialization (no periodic refresh)
verify(mockDnsSrvResolver, times(1)).lookupSrvRecords(TEST_DNS_NAME);
// Verify host provider still works normally
assertNotNull(hostProvider.next(0));
// Test multiple next() calls to ensure functionality is not affected
for (int i = 0; i < 5; i++) {
assertNotNull(hostProvider.next(0));
}
// Verify no additional DNS calls were made
verify(mockDnsSrvResolver, times(1)).lookupSrvRecords(TEST_DNS_NAME);
}
}
@Test
public void testRefreshIntervalNegative() throws Exception {
System.setProperty(ZKClientConfig.DNS_SRV_REFRESH_INTERVAL_SECONDS, "-1");
final SRVRecord[] srvRecords = createMockSrvRecords();
when(mockDnsSrvResolver.lookupSrvRecords(TEST_DNS_NAME)).thenReturn(srvRecords);
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class,
() -> new DnsSrvHostProvider(TEST_DNS_NAME, TEST_SEED, mockDnsSrvResolver, null));
assertEquals("Invalid DNS SRV refresh interval: -1", exception.getMessage());
}
private SRVRecord[] createMockSrvRecords() {
return new SRVRecord[]{
createMockSrvRecord("server1.example.com.", 2181),
createMockSrvRecord("server2.example.com.", 2181),
createMockSrvRecord("server3.example.com.", 2181)
};
}
private SRVRecord createMockSrvRecord(final String target, final int port) {
try {
final Name targetName = Name.fromString(target);
final Name serviceName = Name.fromString(TEST_DNS_NAME);
return new SRVRecord(serviceName, 1, 300, 1, 1, port, targetName);
} catch (final Exception e) {
throw new RuntimeException("Failed to create mock SRV record", e);
}
}
}