WebApps.java

/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements.  See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership.  The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.  You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.yarn.webapp;

import static org.apache.hadoop.util.Preconditions.checkNotNull;

import java.io.IOException;
import java.net.ConnectException;
import java.net.URI;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import javax.servlet.http.HttpServlet;

import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configuration.IntegerRanges;
import org.apache.hadoop.http.HttpConfig.Policy;
import org.apache.hadoop.http.HttpServer2;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.authorize.AccessControlList;
import org.apache.hadoop.security.http.RestCsrfPreventionFilter;
import org.apache.hadoop.yarn.api.ApplicationClientProtocol;
import org.apache.hadoop.yarn.conf.YarnConfiguration;
import org.apache.hadoop.yarn.webapp.util.WebAppUtils;
import org.eclipse.jetty.webapp.WebAppContext;
import org.glassfish.jersey.server.ResourceConfig;
import org.glassfish.jersey.servlet.ServletProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.inject.AbstractModule;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.servlet.GuiceFilter;

/**
 * Helpers to create an embedded webapp.
 *
 * <b>Quick start:</b>
 * <pre>
 *   WebApp wa = WebApps.$for(myApp).start();</pre>
 * This starts a webapp with default routes, bound to 0.0.0.0 (all network
 * interfaces) on an ephemeral port, which can be obtained with:<pre>
 *   int port = wa.port();</pre>
 * <b>With more options:</b>
 * <pre>
 *   WebApp wa = WebApps.$for(myApp).at(address, port).
 *                        with(configuration).
 *                        start(new WebApp() {
 *     &#064;Override public void setup() {
 *       route("/foo/action", FooController.class);
 *       route("/foo/:id", FooController.class, "show");
 *     }
 *   });</pre>
 */
@InterfaceAudience.LimitedPrivate({"YARN", "MapReduce"})
public class WebApps {
  /** Logger shared by this class and the nested {@link Builder}. */
  static final Logger LOG = LoggerFactory.getLogger(WebApps.class);
  public static class Builder<T> {

    /**
     * Plain holder describing one servlet registered through
     * {@code withServlet}; the fields are read back by {@code build}.
     */
    static class ServletStruct {
      /** Name the servlet is registered under. */
      public String name;

      /** Path spec the servlet is mapped to. */
      public String spec;

      /** Servlet implementation class. */
      public Class<? extends HttpServlet> clazz;

      /**
       * Parameters handed to {@code addInternalServlet}; left {@code null}
       * by the three-argument {@code withServlet}.
       */
      public Map<String, String> params;

      /**
       * When {@code true} (the default) the servlet is added with
       * {@code addServlet}; when {@code false} it is added with
       * {@code addInternalServlet} and its spec is appended to the path list.
       */
      public boolean loadExistingFilters = true;
    }
    
    // Webapp name; build() also derives the base path "/" + name from it.
    final String name;
    // Web-services name handed to WebApp.setWebServices(); null when the
    // three-argument constructor is used.
    final String wsName;
    // API class bound to the application instance in the Guice module
    // (the binding is skipped when this is null).
    final Class<T> api;
    // Application instance hosting the webapp; when null, build() infers the
    // host class from the call stack instead.
    final T application;
    // Address the HTTP endpoint is created on; 0.0.0.0 by default.
    String bindAddress = "0.0.0.0";
    // Endpoint port; 0 by default.
    int port = 0;
    // Forwarded to HttpServer2.Builder.setFindPort().
    boolean findPort = false;
    // Configuration used by build(); a fresh Configuration is created there
    // if this is still null.
    Configuration conf;
    // Explicit HTTP policy; when null the scheme prefix is derived from conf.
    Policy httpPolicy = null;
    // Forwarded to HttpServer2.Builder.needsClientAuth() when serving HTTPS.
    boolean needsClientAuth = false;
    // Config key of an optional port range; when it resolves to a non-empty
    // range, the range start replaces the configured port.
    String portRangeConfigKey = null;
    // Turns on the dev-mode handling in build().
    boolean devMode = false;
    // Config keys for the SPNEGO principal and keytab; only used when both
    // keys are set and both resolve to a value in conf.
    private String spnegoPrincipalKey;
    private String spnegoKeytabKey;
    // Config prefixes used to look up the CSRF and XFS filter parameters.
    private String csrfConfigPrefix;
    private String xfsConfigPrefix;
    // Servlets registered through withServlet(); added to the server in build().
    private final HashSet<ServletStruct> servlets = new HashSet<>();
    // Attributes registered through withAttribute(); set on the server in build().
    private final HashMap<String, Object> attributes = new HashMap<>();
    // Optional client protocol instance bound in the Guice module when non-null.
    private ApplicationClientProtocol appClientProtocol;
    // Jersey resource configuration passed to addJerseyResourceConfig() in build().
    private ResourceConfig config;
    /**
     * Creates a builder with an explicit web-services name.
     *
     * @param name webapp name
     * @param api api class of the application, may be {@code null}
     * @param application application instance, may be {@code null}
     * @param wsName web-services name, may be {@code null}
     */
    Builder(String name, Class<T> api, T application, String wsName) {
      this.wsName = wsName;
      this.application = application;
      this.api = api;
      this.name = name;
    }

    /**
     * Creates a builder without a web-services name ({@code wsName} is
     * {@code null}).
     *
     * @param name webapp name
     * @param api api class of the application, may be {@code null}
     * @param application application instance, may be {@code null}
     */
    Builder(String name, Class<T> api, T application) {
      this(name, api, application, null);
    }

    /**
     * Sets the bind address from a {@code host} or {@code host:port} string.
     *
     * <p>When the string splits on {@code ':'} into exactly two parts, the
     * second part is parsed as the port and port discovery is enabled only if
     * that port is 0. Any other shape is used verbatim as the address, with
     * port 0 and port discovery enabled.
     *
     * @param bindAddress address to bind to, optionally followed by
     *                    {@code :port}
     * @return this builder
     * @throws NullPointerException if {@code bindAddress} is {@code null}
     * @throws NumberFormatException if the port part is not an integer
     */
    public Builder<T> at(String bindAddress) {
      // StringUtils.split(null, ...) returns null, which used to surface as
      // an anonymous NPE on parts.length. Fail fast with the same message
      // that at(String, int, boolean) uses.
      checkNotNull(bindAddress, "bind address");
      String[] parts = StringUtils.split(bindAddress, ':');
      if (parts.length == 2) {
        // Named parsedPort so it does not shadow the builder's port field.
        int parsedPort = Integer.parseInt(parts[1]);
        return at(parts[0], parsedPort, parsedPort == 0);
      }
      return at(bindAddress, 0, true);
    }

    /**
     * Binds to 0.0.0.0 on the given port; port discovery is enabled only
     * when the port is 0.
     *
     * @param port port to listen on
     * @return this builder
     */
    public Builder<T> at(int port) {
      final boolean searchForPort = (port == 0);
      return at("0.0.0.0", port, searchForPort);
    }

    /**
     * Sets the bind address, port and port-discovery flag in one call.
     *
     * @param address bind address, must not be {@code null}
     * @param port port to listen on
     * @param findPort value later forwarded to the HTTP server builder's
     *                 {@code setFindPort}
     * @return this builder
     */
    public Builder<T> at(String address, int port, boolean findPort) {
      // Validate first so that nothing is modified when the address is null.
      final String checkedAddress = checkNotNull(address, "bind address");
      this.bindAddress = checkedAddress;
      this.port = port;
      this.findPort = findPort;
      return this;
    }

    /**
     * Registers an attribute that {@code build} sets on the HTTP server.
     * A later call with the same key replaces the earlier value.
     *
     * @param key attribute name
     * @param value attribute value
     * @return this builder
     */
    public Builder<T> withAttribute(String key, Object value) {
      this.attributes.put(key, value);
      return this;
    }
    
    /**
     * Registers a servlet without extra parameters; it keeps the default
     * {@code loadExistingFilters = true}.
     *
     * @param name servlet name
     * @param pathSpec path spec the servlet is mapped to
     * @param servlet servlet class
     * @return this builder
     */
    public Builder<T> withServlet(String name, String pathSpec,
        Class<? extends HttpServlet> servlet) {
      final ServletStruct entry = new ServletStruct();
      entry.name = name;
      entry.spec = pathSpec;
      entry.clazz = servlet;
      this.servlets.add(entry);
      return this;
    }

    /**
     * Registers a servlet together with its parameters and filter handling.
     *
     * @param name servlet name
     * @param pathSpec path spec the servlet is mapped to
     * @param servlet servlet class
     * @param params parameters passed on when the servlet is added as an
     *               internal servlet
     * @param loadExistingFilters {@code false} to add the servlet through
     *                            {@code addInternalServlet} in {@code build},
     *                            {@code true} to add it through
     *                            {@code addServlet}
     * @return this builder
     */
    public Builder<T> withServlet(String name, String pathSpec,
        Class<? extends HttpServlet> servlet,
        Map<String, String> params, boolean loadExistingFilters) {
      final ServletStruct entry = new ServletStruct();
      entry.name = name;
      entry.spec = pathSpec;
      entry.clazz = servlet;
      entry.params = params;
      entry.loadExistingFilters = loadExistingFilters;
      this.servlets.add(entry);
      return this;
    }

    /**
     * Sets the Jersey resource configuration that {@code build} registers
     * on the HTTP server.
     *
     * @param resourceConfig resource configuration to register
     * @return this builder
     */
    public Builder<T> withResourceConfig(ResourceConfig resourceConfig) {
      config = resourceConfig;
      return this;
    }

    /**
     * Sets the configuration used when building the webapp.
     *
     * @param conf configuration; if left {@code null}, {@code build} creates
     *             a new one
     * @return this builder
     */
    public Builder<T> with(Configuration conf) {
      final Configuration supplied = conf;
      this.conf = supplied;
      return this;
    }

    /**
     * Sets the configuration together with an explicit HTTP policy. With a
     * non-null policy, {@code build} picks the HTTPS prefix for
     * {@code HTTPS_ONLY} and the HTTP prefix otherwise.
     *
     * @param conf configuration to use
     * @param httpPolicy policy that decides the scheme
     * @return this builder
     */
    public Builder<T> withHttpPolicy(Configuration conf, Policy httpPolicy) {
      this.httpPolicy = httpPolicy;
      this.conf = conf;
      return this;
    }

    /**
     * Sets the client-authentication flag; {@code build} only applies it
     * when the webapp is served over HTTPS.
     *
     * @param needsClientAuth whether client authentication is required
     * @return this builder
     */
    public Builder<T> needsClientAuth(boolean needsClientAuth) {
      final boolean required = needsClientAuth;
      this.needsClientAuth = required;
      return this;
    }

    /**
     * Sets the configuration and the key under which a port range is looked
     * up. If the key resolves to a non-empty range, {@code build} starts at
     * the beginning of that range instead of the configured port.
     *
     * @param configuration configuration holding the port range
     * @param portRangeConfKey config key of the port range
     * @return this builder
     */
    public Builder<T> withPortRange(Configuration configuration,
        String portRangeConfKey) {
      this.portRangeConfigKey = portRangeConfKey;
      this.conf = configuration;
      return this;
    }

    /**
     * Sets the config key of the SPNEGO principal. {@code build} only uses
     * it when the keytab key is set as well and both keys have a value in
     * the configuration.
     *
     * @param spnegoPrincipalKey config key of the SPNEGO principal
     * @return this builder
     */
    public Builder<T> withHttpSpnegoPrincipalKey(String spnegoPrincipalKey) {
      final String principalKey = spnegoPrincipalKey;
      this.spnegoPrincipalKey = principalKey;
      return this;
    }
    
    /**
     * Sets the config key of the SPNEGO keytab. {@code build} only uses it
     * when the principal key is set as well and both keys have a value in
     * the configuration.
     *
     * @param spnegoKeytabKey config key of the SPNEGO keytab
     * @return this builder
     */
    public Builder<T> withHttpSpnegoKeytabKey(String spnegoKeytabKey) {
      final String keytabKey = spnegoKeytabKey;
      this.spnegoKeytabKey = keytabKey;
      return this;
    }

    /**
     * Sets the config prefix for the CSRF filter. The filter itself is only
     * installed when the parameters found under this prefix contain
     * {@code enabled} set to {@code true}.
     *
     * @param prefix config prefix that identifies the CSRF parameters
     *               applicable to this filter instance
     * @return this builder
     */
    public Builder<T> withCSRFProtection(String prefix) {
      csrfConfigPrefix = prefix;
      return this;
    }

    /**
     * Sets the config prefix for the XFS (X-Frame-Options) filter.
     * {@code build} reads the {@code xframe-options} value from the
     * parameters found under this prefix.
     *
     * @param prefix config prefix that identifies the XFS parameters
     *               applicable to this filter instance
     * @return this builder
     */
    public Builder<T> withXFSProtection(String prefix) {
      xfsConfigPrefix = prefix;
      return this;
    }

    /**
     * Switches the builder into dev mode. In dev mode {@code build} tries
     * to stop an instance already listening on the configured port and
     * requires that port to be greater than 0.
     *
     * @return this builder
     */
    public Builder<T> inDevMode() {
      this.devMode = true;
      return this;
    }

    /**
     * Sets the {@link ApplicationClientProtocol} instance that is bound in
     * the Guice module created by {@code build} (only when non-null).
     *
     * @param appClientProto client protocol instance to bind
     * @return this builder
     */
    public Builder<T> withAppClientProtocol(
        ApplicationClientProtocol appClientProto) {
      appClientProtocol = appClientProto;
      return this;
    }

    /**
     * Builds the webapp and the {@link HttpServer2} behind it without
     * starting the server.
     *
     * <p>The method wires up, in order: the serve path specs for the webapp
     * and its web services, the host class, dev-mode handling, the HTTP
     * scheme, X-Frame-Options, the port range, SPNEGO, SSL stores, the
     * registered servlets and attributes, the CSRF and Guice filters, and
     * finally the Guice injector.
     *
     * @param webapp the webapp to configure; when {@code null} a webapp with
     *               an empty {@code setup()} is used
     * @return the configured webapp, holding the (not yet started) server
     * @throws WebAppException if the host class cannot be loaded or the
     *                         HTTP server cannot be built
     */
    public WebApp build(WebApp webapp) {
      if (webapp == null) {
        webapp = new WebApp() {
          @Override
          public void setup() {
            // Defaults should be fine in usual cases
          }
        };
      }
      webapp.setName(name);
      webapp.setWebServices(wsName);
      String basePath = "/" + name;
      webapp.setRedirectPath(basePath);
      List<String> pathList = new ArrayList<>();
      if (basePath.equals("/")) {
        webapp.addServePathSpec("/*");
        pathList.add("/*");
      } else {
        webapp.addServePathSpec(basePath);
        webapp.addServePathSpec(basePath + "/*");
        pathList.add(basePath + "/*");
      }
      if (wsName != null && !wsName.equals(basePath)) {
        if (wsName.equals("/")) {
          webapp.addServePathSpec("/*");
          pathList.add("/*");
        } else {
          webapp.addServePathSpec("/" + wsName);
          webapp.addServePathSpec("/" + wsName + "/*");
          pathList.add("/" + wsName + "/*");
        }
      }

      for (ServletStruct s : servlets) {
        // The servlet told us to not load-existing filters, but we still want
        // to add the default authentication filter always, so add its spec to
        // the pathList unless it is already there.
        if (!s.loadExistingFilters && !pathList.contains(s.spec)) {
          pathList.add(s.spec);
        }
      }
      if (conf == null) {
        conf = new Configuration();
      }
      try {
        if (application != null) {
          webapp.setHostClass(application.getClass());
        } else {
          String cls = inferHostClass();
          LOG.debug("setting webapp host class to {}", cls);
          webapp.setHostClass(Class.forName(cls));
        }
        if (devMode) {
          stopExistingDevInstance();
        }
        String httpScheme;
        if (this.httpPolicy == null) {
          httpScheme = WebAppUtils.getHttpSchemePrefix(conf);
        } else {
          httpScheme =
              (httpPolicy == Policy.HTTPS_ONLY) ? WebAppUtils.HTTPS_PREFIX
                  : WebAppUtils.HTTP_PREFIX;
        }
        HttpServer2.Builder builder = new HttpServer2.Builder()
            .setName(name).setConf(conf).setFindPort(findPort)
            .setACL(new AccessControlList(conf.get(
                YarnConfiguration.YARN_ADMIN_ACL,
                YarnConfiguration.DEFAULT_YARN_ADMIN_ACL)))
            .setPathSpec(pathList.toArray(new String[0]));

        // Set the X-FRAME-OPTIONS header, use the HttpServer2 default if
        // the header value is not specified
        Map<String, String> xfsParameters =
            getConfigParameters(xfsConfigPrefix);

        if (xfsParameters != null) {
          String xFrameOptions = xfsParameters.get("xframe-options");
          if (xFrameOptions != null) {
            builder.configureXFrame(hasXFSEnabled())
                .setXFrameOption(xFrameOptions);
          }
        }
        // Get port ranges from config.
        IntegerRanges ranges = null;
        if (portRangeConfigKey != null) {
          ranges = conf.getRange(portRangeConfigKey, "");
        }
        int startPort = port;
        if (ranges != null && !ranges.isEmpty()) {
          // Set port ranges if it's configured.
          startPort = ranges.getRangeStart();
          builder.setPortRanges(ranges);
        }
        builder.addEndpoint(URI.create(httpScheme + bindAddress +
            ":" + startPort));
        boolean hasSpnegoConf = spnegoPrincipalKey != null
            && conf.get(spnegoPrincipalKey) != null && spnegoKeytabKey != null
            && conf.get(spnegoKeytabKey) != null;

        if (hasSpnegoConf) {
          builder.setUsernameConfKey(spnegoPrincipalKey)
              .setKeytabConfKey(spnegoKeytabKey)
              .setSecurityEnabled(UserGroupInformation.isSecurityEnabled());
        }

        if (httpScheme.equals(WebAppUtils.HTTPS_PREFIX)) {
          configureSsl(builder);
        }

        HttpServer2 server = builder.build();

        for (ServletStruct struct : servlets) {
          if (!struct.loadExistingFilters) {
            server.addInternalServlet(struct.name, struct.spec,
                struct.clazz, struct.params);
          } else {
            server.addServlet(struct.name, struct.spec, struct.clazz);
          }
        }
        for (Map.Entry<String, Object> entry : attributes.entrySet()) {
          server.setAttribute(entry.getKey(), entry.getValue());
        }
        Map<String, String> params = getConfigParameters(csrfConfigPrefix);

        if (hasCSRFEnabled(params)) {
          LOG.info("CSRF Protection has been enabled for the {} application. "
                   + "Please ensure that there is an authentication mechanism "
                   + "enabled (kerberos, custom, etc).",
                   name);
          String restCsrfClassName = RestCsrfPreventionFilter.class.getName();
          HttpServer2.defineFilter(server.getWebAppContext(), restCsrfClassName,
                                   restCsrfClassName, params,
                                   new String[] {"/*"});
        }

        final Map<String, String> guiceFilterParams = new HashMap<>();
        guiceFilterParams.put(ServletProperties.FILTER_FORWARD_ON_404, "true");
        HttpServer2.defineFilter(server.getWebAppContext(), "guice",
            GuiceFilter.class.getName(), guiceFilterParams, new String[]{"/*"});
        server.addJerseyResourceConfig(config, "/*", null);
        webapp.setConf(conf);
        webapp.setHttpServer(server);
      } catch (ClassNotFoundException | IOException e) {
        throw new WebAppException("Error starting http server", e);
      }
      Injector injector = Guice.createInjector(webapp, new AbstractModule() {
        @Override
        protected void configure() {
          if (api != null) {
            bind(api).toInstance(application);
          }
          if (appClientProtocol != null) {
            bind(ApplicationClientProtocol.class).toInstance(appClientProtocol);
          }
        }
      });
      LOG.info("Registered webapp guice modules");
      // save a guice filter instance for webapp stop (mostly for unit tests)
      webapp.setGuiceFilter(injector.getInstance(GuiceFilter.class));
      if (devMode) {
        injector.getInstance(Dispatcher.class).setDevMode(devMode);
        LOG.info("in dev mode!");
      }
      return webapp;
    }

    /**
     * Dev-mode helper: asks an instance already listening on the configured
     * port to stop, or exits the JVM when no fixed port was configured.
     * Failures to reach or stop the old instance are logged and ignored.
     */
    private void stopExistingDevInstance() {
      if (port > 0) {
        try {
          new URL("http://localhost:"+ port +"/__stop").getContent();
          LOG.info("stopping existing webapp instance");
          Thread.sleep(100);
        } catch (ConnectException e) {
          LOG.info("no existing webapp instance found: {}", e.toString());
        } catch (InterruptedException e) {
          // Still not fatal, but restore the interrupt flag so the caller
          // can observe the interruption instead of it being swallowed.
          Thread.currentThread().interrupt();
          LOG.warn("error stopping existing instance: {}", e.toString());
        } catch (Exception e) {
          // should not be fatal
          LOG.warn("error stopping existing instance: {}", e.toString());
        }
      } else {
        LOG.error("dev mode does NOT work with ephemeral port!");
        System.exit(1);
      }
    }

    /**
     * Configures the key store (and, with client auth, the trust store) on
     * the given server builder. Store locations and passwords are read from
     * the environment first and from system properties when the environment
     * value is blank; without a key store location the standard SSL
     * configuration is loaded from {@code conf}.
     *
     * @param builder HTTP server builder to configure
     * @throws IOException declared so that any I/O failure from the
     *                     underlying calls reaches the handler in build()
     */
    private void configureSsl(HttpServer2.Builder builder) throws IOException {
      String amKeystoreLoc = System.getenv("KEYSTORE_FILE_LOCATION");
      if (StringUtils.isBlank(amKeystoreLoc)) {
        amKeystoreLoc = System.getProperty("KEYSTORE_FILE_LOCATION");
      }
      if (amKeystoreLoc != null) {
        LOG.info("Setting keystore location to {}", amKeystoreLoc);
        String password = System.getenv("KEYSTORE_PASSWORD");
        if (StringUtils.isBlank(password)) {
          password = System.getProperty("KEYSTORE_PASSWORD");
        }
        builder.keyStore(amKeystoreLoc, password, "jks");
      } else {
        LOG.info("Loading standard ssl config");
        WebAppUtils.loadSslConfiguration(builder, conf);
      }
      builder.needsClientAuth(needsClientAuth);
      if (needsClientAuth) {
        String amTruststoreLoc = System.getenv("TRUSTSTORE_FILE_LOCATION");
        if (StringUtils.isBlank(amTruststoreLoc)) {
          amTruststoreLoc = System.getProperty("TRUSTSTORE_FILE_LOCATION");
        }
        if (amTruststoreLoc != null) {
          LOG.info("Setting truststore location to {}", amTruststoreLoc);
          String password = System.getenv("TRUSTSTORE_PASSWORD");
          if (StringUtils.isBlank(password)) {
            password = System.getProperty("TRUSTSTORE_PASSWORD");
          }
          builder.trustStore(amTruststoreLoc, password, "jks");
        }
      }
    }

    /**
     * Tells whether the given CSRF parameters switch the filter on.
     *
     * @param params CSRF parameters, may be {@code null}
     * @return {@code true} only if {@code params} is non-null and its
     *         {@code enabled} entry parses as {@code true}
     */
    private boolean hasCSRFEnabled(Map<String, String> params) {
      if (params == null) {
        return false;
      }
      return Boolean.parseBoolean(params.get("enabled"));
    }

    /**
     * Reads the XFS-enabled flag from the configuration. The flag defaults
     * to {@code true}, so XFS counts as enabled unless the configuration
     * explicitly turns it off.
     *
     * @return true if XFS is enabled, false otherwise.
     */
    private boolean hasXFSEnabled() {
      final boolean enabledByDefault = true;
      return conf.getBoolean(YarnConfiguration.YARN_XFS_ENABLED,
          enabledByDefault);
    }

    /**
     * Looks up the configuration properties under the given prefix.
     *
     * @param configPrefix config prefix, may be {@code null}
     * @return the properties found for the prefix, or {@code null} when no
     *         prefix was given
     */
    private Map<String, String> getConfigParameters(String configPrefix) {
      if (configPrefix == null) {
        return null;
      }
      return conf.getPropsWithPrefix(configPrefix);
    }

    /**
     * Builds and starts a webapp with default routes.
     *
     * @return the started webapp
     */
    public WebApp start() {
      final WebApp defaultWebApp = null;
      return start(defaultWebApp);
    }

    /**
     * Builds and starts the given webapp without an additional UI context.
     *
     * @param webapp webapp to start, or {@code null} for the default one
     * @return the started webapp
     */
    public WebApp start(WebApp webapp) {
      final WebAppContext noExtraContext = null;
      return start(webapp, noExtraContext);
    }

    /**
     * Builds the webapp, optionally puts an extra context in front of the
     * server's handlers, and starts the HTTP server.
     *
     * @param webapp webapp to start, or {@code null} for the default one
     * @param ui2Context extra context to add at the front; skipped when
     *                   {@code null}
     * @return the started webapp
     * @throws WebAppException if the HTTP server fails to start
     */
    public WebApp start(WebApp webapp, WebAppContext ui2Context) {
      final WebApp built = build(webapp);
      final HttpServer2 server = built.httpServer();
      if (ui2Context != null) {
        // The extra context gets the CSRF filter before it is attached.
        addFiltersForNewContext(ui2Context);
        server.addHandlerAtFront(ui2Context);
      }
      try {
        server.start();
        LOG.info("Web app {} started at {}.", name,
            server.getConnectorAddress(0).getPort());
      } catch (IOException e) {
        throw new WebAppException("Error starting http server", e, built);
      }
      return built;
    }

    /**
     * Installs the CSRF prevention filter on the given context when the
     * CSRF parameters enable it; otherwise does nothing.
     *
     * @param ui2Context context to add the filter to
     */
    private void addFiltersForNewContext(WebAppContext ui2Context) {
      final Map<String, String> csrfParams =
          getConfigParameters(csrfConfigPrefix);
      if (!hasCSRFEnabled(csrfParams)) {
        return;
      }
      LOG.info("CSRF Protection has been enabled for the {} application. "
          + "Please ensure that there is an authentication mechanism "
          + "enabled (kerberos, custom, etc).", name);
      final String filterClassName = RestCsrfPreventionFilter.class.getName();
      final String[] allPaths = new String[]{"/*"};
      HttpServer2.defineFilter(ui2Context, filterClassName, filterClassName,
          csrfParams, allPaths);
    }

    /**
     * Infers the host class from the current call stack: the class of the
     * first stack frame that does not belong to this builder's class.
     *
     * @return the inferred class name, or this builder's class name when
     *         every frame belongs to the builder
     */
    private String inferHostClass() {
      final String builderClassName = this.getClass().getName();
      final Throwable probe = new Throwable();
      for (StackTraceElement frame : probe.getStackTrace()) {
        final String frameClassName = frame.getClassName();
        if (!frameClassName.equals(builderClassName)) {
          return frameClassName;
        }
      }
      LOG.warn("could not infer host class from", probe);
      return builderClassName;
    }
  }

  /**
   * Creates a webapp builder with an explicit web-services prefix.
   * @see WebApps for a complete example
   * @param <T> type of the application holding the embedded webapp
   * @param prefix prefix (name) of the webapp
   * @param api the api class for the application
   * @param app the application instance
   * @param wsPrefix the prefix for the webservice api for this app
   * @return a webapp builder
   */
  public static <T> Builder<T> $for(String prefix, Class<T> api, T app,
      String wsPrefix) {
    return new Builder<T>(prefix, api, app, wsPrefix);
  }

  /**
   * Creates a webapp builder without a web-services prefix.
   * @see WebApps for a complete example
   * @param <T> type of the application holding the embedded webapp
   * @param prefix prefix (name) of the webapp
   * @param api the api class for the application
   * @param app the application instance
   * @return a webapp builder
   */
  public static <T> Builder<T> $for(String prefix, Class<T> api, T app) {
    return new Builder<T>(prefix, api, app);
  }

  /**
   * Shortcut, mostly for tests and demos: uses the runtime class of
   * {@code app} as the api class.
   * @param <T> type of the application holding the embedded webapp
   * @param prefix prefix (name) of the webapp
   * @param app the application instance, must not be {@code null}
   * @return a webapp builder
   */
  @SuppressWarnings("unchecked")
  public static <T> Builder<T> $for(String prefix, T app) {
    final Class<T> api = (Class<T>) app.getClass();
    return $for(prefix, api, app);
  }

  /**
   * Shortcut, mostly for tests and demos: builds for {@code app} with an
   * empty prefix.
   * @param <T> type of the application holding the embedded webapp
   * @param app the application instance, must not be {@code null}
   * @return a webapp builder
   */
  public static <T> Builder<T> $for(T app) {
    final String emptyPrefix = "";
    return $for(emptyPrefix, app);
  }

  /**
   * Creates a webapp builder that has neither an api class nor an
   * application instance.
   * @param <T> type of the application holding the embedded webapp
   * @param prefix prefix (name) of the webapp
   * @return a webapp builder
   */
  public static <T> Builder<T> $for(String prefix) {
    final Class<T> noApi = null;
    final T noApp = null;
    return $for(prefix, noApi, noApp);
  }
}