BasePBImplRecordsTest.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.yarn.api;

import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import org.apache.hadoop.thirdparty.com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.commons.lang3.Range;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.Lists;
import org.apache.hadoop.util.Sets;
import org.apache.hadoop.yarn.api.resource.PlacementConstraint;
import org.apache.hadoop.yarn.api.resource.PlacementConstraints;

import static org.apache.hadoop.yarn.api.resource.PlacementConstraints.NODE;
import static org.apache.hadoop.yarn.api.resource.PlacementConstraints.PlacementTargets.allocationTag;
import static org.apache.hadoop.yarn.api.resource.PlacementConstraints.targetIn;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Generic helper class to validate protocol records.
 *
 * <p>Subclasses call {@link #validatePBImplRecord(Class, Class)} with a record
 * implementation class and its protobuf message class. The helper discovers
 * matching getter/setter pairs by reflection, populates a record with
 * generated values, round-trips it through {@code getProto()} and the
 * proto-accepting constructor, and asserts that the result is equal.
 */
public class BasePBImplRecordsTest {
  private static final Logger LOG =
      LoggerFactory.getLogger(BasePBImplRecordsTest.class);

  /**
   * Sample values keyed by the type they stand in for. Consulted first by
   * {@link #genTypeValue(Type)} and filled by both {@link #genTypeValue(Type)}
   * and {@link #generateByNewInstance(Class)}. Values for types that cannot be
   * generated here (for example {@code Range}, which is only ever resolved
   * from this cache) must be registered by the subclass beforehand.
   */
  @SuppressWarnings("checkstyle:visibilitymodifier")
  protected static HashMap<Type, Object> typeValueCache =
      new HashMap<Type, Object>();

  /**
   * Property names (getter name without the {@code get} prefix) that must be
   * skipped during validation, keyed by record class.
   */
  @SuppressWarnings("checkstyle:visibilitymodifier")
  protected static HashMap<Type, List<String>> excludedPropertiesMap =
      new HashMap<>();

  private static final Random RAND = new Random();
  private static final byte[] BYTES = new byte[] {'1', '2', '3', '4'};

  /**
   * Produces a sample value for the given type.
   *
   * <p>Cached values win. Primitives, their supported wrappers, strings,
   * {@code Text} and {@code ByteBuffer} get a fresh value on every call and
   * are never cached. Arrays, enums, {@code PlacementConstraint} and the
   * supported generic containers are generated once and then cached.
   *
   * @param type the type to generate a value for
   * @return a non-null sample value
   * @throws IllegalArgumentException if the type is not supported
   */
  @SuppressWarnings({"rawtypes", "unchecked"})
  private static Object genTypeValue(Type type) {
    Object ret = typeValueCache.get(type);
    if (ret != null) {
      return ret;
    }
    // only use positive primitive values
    if (type.equals(boolean.class)) {
      return RAND.nextBoolean();
    } else if (type.equals(byte.class)) {
      return BYTES[RAND.nextInt(4)];
    } else if (type.equals(int.class) || type.equals(Integer.class)) {
      return RAND.nextInt(1000000);
    } else if (type.equals(long.class) || type.equals(Long.class)) {
      return Long.valueOf(RAND.nextInt(1000000));
    } else if (type.equals(float.class)) {
      return RAND.nextFloat();
    } else if (type.equals(double.class)) {
      return RAND.nextDouble();
    } else if (type.equals(String.class)) {
      return String.format("%c%c%c",
          'a' + RAND.nextInt(26),
          'a' + RAND.nextInt(26),
          'a' + RAND.nextInt(26));
    } else if (type.equals(Float.class)) {
      return RAND.nextFloat();
    } else if (type.equals(Text.class)) {
      return new Text('a' + String.valueOf(RAND.nextInt(1000000)));
    } else if (type instanceof Class) {
      Class clazz = (Class) type;
      if (clazz.isArray()) {
        Class compClass = clazz.getComponentType();
        if (compClass != null) {
          // two-element array of generated component values
          ret = Array.newInstance(compClass, 2);
          Array.set(ret, 0, genTypeValue(compClass));
          Array.set(ret, 1, genTypeValue(compClass));
        }
      } else if (clazz.isEnum()) {
        Object[] values = clazz.getEnumConstants();
        ret = values[RAND.nextInt(values.length)];
      } else if (clazz.equals(ByteBuffer.class)) {
        // return new ByteBuffer every time
        // to prevent potential side effects
        ByteBuffer buff = ByteBuffer.allocate(4);
        RAND.nextBytes(buff.array());
        return buff;
      } else if (type.equals(PlacementConstraint.class)) {
        PlacementConstraint.AbstractConstraint sConstraintExpr =
            targetIn(NODE, allocationTag("foo"));
        ret = PlacementConstraints.build(sConstraintExpr);
      }
    } else if (type instanceof ParameterizedType) {
      ParameterizedType pt = (ParameterizedType) type;
      Type rawType = pt.getRawType();
      Type[] params = pt.getActualTypeArguments();
      // only support EnumSet<T>, List<T>, Set<T>, Map<K,V>, Range<T>
      if (rawType.equals(EnumSet.class)) {
        if (params[0] instanceof Class) {
          Class c = (Class) (params[0]);
          return EnumSet.allOf(c);
        }
      } else if (rawType.equals(List.class)) {
        ret = Lists.newArrayList(genTypeValue(params[0]));
      } else if (rawType.equals(Set.class)) {
        ret = Sets.newHashSet(genTypeValue(params[0]));
      } else if (rawType.equals(Map.class)) {
        Map<Object, Object> map = Maps.newHashMap();
        map.put(genTypeValue(params[0]), genTypeValue(params[1]));
        ret = map;
      } else if (rawType.equals(Range.class)) {
        // Range values are never generated here; only a value registered
        // under the raw Range type in the cache can satisfy the request.
        ret = typeValueCache.get(rawType);
        if (ret != null) {
          return ret;
        }
      }
    }
    if (ret == null) {
      throw new IllegalArgumentException("type " + type + " is not supported");
    }
    typeValueCache.put(type, ret);
    return ret;
  }

  /**
   * Generates a record instance by reflectively calling the public static
   * {@code newInstance} method declared on {@code clazz} that takes the most
   * parameters, and registers the generated value in
   * {@link #typeValueCache}. A value already cached for {@code clazz} is
   * returned as is.
   *
   * @param clazz the record class to instantiate
   * @return the cached or newly created instance
   * @throws IllegalArgumentException if {@code clazz} declares no public
   *     static {@code newInstance} method, or an argument type is unsupported
   * @throws Exception if the reflective invocation fails
   */
  @SuppressWarnings("rawtypes")
  protected static Object generateByNewInstance(Class clazz) throws Exception {
    Object ret = typeValueCache.get(clazz);
    if (ret != null) {
      return ret;
    }
    Method newInstance = null;
    Type[] paramTypes = new Type[0];
    // get newInstance method with most parameters
    for (Method m : clazz.getMethods()) {
      int mod = m.getModifiers();
      if (m.getDeclaringClass().equals(clazz) &&
          Modifier.isPublic(mod) &&
          Modifier.isStatic(mod) &&
          m.getName().equals("newInstance")) {
        Type[] pts = m.getGenericParameterTypes();
        if (newInstance == null
            || (pts.length > paramTypes.length)) {
          newInstance = m;
          paramTypes = pts;
        }
      }
    }
    if (newInstance == null) {
      throw new IllegalArgumentException("type " + clazz.getName() +
          " does not have newInstance method");
    }
    Object[] args = new Object[paramTypes.length];
    for (int i = 0; i < args.length; i++) {
      args[i] = genTypeValue(paramTypes[i]);
    }
    ret = newInstance.invoke(null, args);
    typeValueCache.put(clazz, ret);
    return ret;
  }

  /**
   * A property discovered on a record class: its getter, the setter with a
   * matching parameter type (if any), and the value used to exercise it.
   * Declared static because it needs no reference to the enclosing test.
   */
  private static final class GetSetPair {
    private String propertyName;
    private Method getMethod;
    private Method setMethod;
    private Type type;
    private Object testValue;

    @Override
    public String toString() {
      return String.format("{ name=%s, class=%s, value=%s }", propertyName,
          type, testValue);
    }
  }

  /**
   * Tells whether {@code m} is a public, non-static method declared directly
   * on {@code owner}.
   */
  private static boolean isPublicInstanceMethodOf(Method m, Class<?> owner) {
    int mod = m.getModifiers();
    return m.getDeclaringClass().equals(owner)
        && Modifier.isPublic(mod)
        && !Modifier.isStatic(mod);
  }

  /**
   * Collects the testable properties of a record class.
   *
   * <p>A property is kept when the class declares a public no-arg
   * {@code getXxx} (other than {@code getProto}) and a public one-arg
   * {@code setXxx} whose generic parameter type equals the getter's generic
   * return type, and the property is not listed in
   * {@link #excludedPropertiesMap}. Each kept property gets a generated test
   * value.
   *
   * @param recordClass the record implementation class to inspect
   * @return property name to getter/setter pair, test values populated
   * @throws Exception declared for callers; value generation may throw
   *     {@link IllegalArgumentException} for unsupported property types
   */
  private <R> Map<String, GetSetPair> getGetSetPairs(Class<R> recordClass)
      throws Exception {
    Map<String, GetSetPair> ret = new HashMap<String, GetSetPair>();

    // Look the exclusions up by the record class itself. The previous lookup
    // used recordClass.getClass(), which is always Class.class, so entries
    // registered under the actual record class were never found. The old key
    // is still honoured as a fallback for callers that registered that way.
    List<String> excluded = excludedPropertiesMap.get(recordClass);
    if (excluded == null) {
      excluded = excludedPropertiesMap.get(recordClass.getClass());
    }

    Method[] methods = recordClass.getDeclaredMethods();
    // get all get methods
    for (Method m : methods) {
      if (!isPublicInstanceMethodOf(m, recordClass)) {
        continue;
      }
      String name = m.getName();
      if (name.equals("getProto")) {
        continue;
      }
      if ((name.length() > 3) && name.startsWith("get") &&
          (m.getParameterTypes().length == 0)) {
        String propertyName = name.substring(3);
        GetSetPair p = ret.get(propertyName);
        if (p == null) {
          p = new GetSetPair();
          p.propertyName = propertyName;
          p.type = m.getGenericReturnType();
          p.getMethod = m;
          ret.put(propertyName, p);
        } else {
          fail("Multiple get method with same name: " + recordClass
              + p.propertyName);
        }
      }
    }
    // match get methods with set methods
    for (Method m : methods) {
      if (!isPublicInstanceMethodOf(m, recordClass)) {
        continue;
      }
      String name = m.getName();
      if ((name.length() > 3) && name.startsWith("set") &&
          (m.getParameterTypes().length == 1)) {
        Type valueType = m.getGenericParameterTypes()[0];
        GetSetPair p = ret.get(name.substring(3));
        if (p != null && p.type.equals(valueType)) {
          p.setMethod = m;
        }
      }
    }
    // exclude incomplete get/set pair, and generate test value
    Iterator<Map.Entry<String, GetSetPair>> itr = ret.entrySet().iterator();
    while (itr.hasNext()) {
      GetSetPair gsp = itr.next().getValue();
      if ((gsp.getMethod == null) ||
          (gsp.setMethod == null)) {
        LOG.info("Exclude potential property: {}\n", gsp.propertyName);
        itr.remove();
      } else if (excluded != null && excluded.contains(gsp.propertyName)) {
        LOG.info(
            "Excluding potential property(present in exclusion list): {}\n",
            gsp.propertyName);
        itr.remove();
      } else {
        LOG.info("New property: {} type: {}", gsp, gsp.type);
        gsp.testValue = genTypeValue(gsp.type);
        LOG.info(" testValue: {}\n", gsp.testValue);
      }
    }
    return ret;
  }

  /**
   * Validates that a record implementation survives a protobuf round trip.
   *
   * <p>Creates a record through its no-arg constructor, sets every discovered
   * property to a generated value, obtains the proto via {@code getProto()},
   * rebuilds a record through the constructor taking {@code protoClass}, and
   * asserts that both the records and each individual property are equal.
   *
   * @param recordClass the record implementation class; must have a public
   *     no-arg constructor, a public constructor taking {@code protoClass},
   *     and declare {@code getProto()}
   * @param protoClass the protobuf message class {@code getProto()} must
   *     return an instance of
   * @throws Exception if a required constructor or method is missing, or a
   *     reflective call fails
   */
  protected <R, P> void validatePBImplRecord(Class<R> recordClass,
      Class<P> protoClass)
      throws Exception {
    LOG.info("Validate {} {}\n", recordClass.getName(),
        protoClass.getName());
    Constructor<R> emptyConstructor = recordClass.getConstructor();
    Constructor<R> pbConstructor = recordClass.getConstructor(protoClass);
    Method getProto = recordClass.getDeclaredMethod("getProto");
    Map<String, GetSetPair> getSetPairs = getGetSetPairs(recordClass);
    R origRecord = emptyConstructor.newInstance();
    for (GetSetPair gsp : getSetPairs.values()) {
      gsp.setMethod.invoke(origRecord, gsp.testValue);
    }
    Object ret = getProto.invoke(origRecord);
    assertNotNull(ret, recordClass.getName() + "#getProto returns null");
    if (!(protoClass.isAssignableFrom(ret.getClass()))) {
      fail("Illegal getProto method return type: " + ret.getClass());
    }
    R deserRecord = pbConstructor.newInstance(ret);
    assertEquals(origRecord, deserRecord, "whole " + recordClass + " records should be equal");
    for (GetSetPair gsp : getSetPairs.values()) {
      Object origValue = gsp.getMethod.invoke(origRecord);
      Object deserValue = gsp.getMethod.invoke(deserRecord);
      assertEquals(origValue, deserValue, "property " + recordClass.getName() + "#"
          + gsp.propertyName + " should be equal");
    }
  }
}