{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ZrwVQsM9TiUw" }, "source": [ "##### Copyright 2020 The TensorFlow Probability Authors.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:08.053110Z", "iopub.status.busy": "2021-01-28T12:27:08.052553Z", "iopub.status.idle": "2021-01-28T12:27:08.054632Z", "shell.execute_reply": "2021-01-28T12:27:08.054195Z" }, "id": "CpDUTVKYTowI" }, "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ltPJCG6pAUoc" }, "source": [ "# A Tour of Oryx" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Cvrh7Ppuwlbb" }, "source": [ "## What is Oryx?" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "F_n9c7K3xdKQ" }, "source": [ "Oryx is an experimental library that extends [JAX](https://github.com/google/jax) to applications ranging from building and training complex neural networks to approximate Bayesian inference in deep generative models. Just as JAX provides `jit`, `vmap`, and `grad`, Oryx provides a set of **composable function transformations** that let you write simple code and transform it to build up complexity, while remaining fully interoperable with JAX.\n", "\n", "JAX can only safely transform pure, functional code (i.e., code without side effects). While pure code can be easier to write and reason about, \"impure\" code is often more concise and expressive.\n", "\n", "At its core, Oryx is a library that enables \"augmenting\" pure functional code to accomplish tasks like defining state or pulling out intermediate values. Its goal is to be as thin a layer on top of JAX as possible, leveraging JAX's minimalist approach to numerical computing. Oryx is conceptually divided into several \"layers\", each building on the one below it.\n", "\n", "The source code for Oryx can be found [on GitHub](https://github.com/tensorflow/probability/tree/master/spinoffs/oryx)."
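, "\n",
"\n",
"To make *composable function transformations* concrete, here is a minimal plain-JAX sketch (it does not use Oryx). Each of `grad`, `vmap`, and `jit` takes a pure function and returns a new function, so they can be stacked; the result is checked against the closed-form derivative `d/dx sin(x)**2 = sin(2x)`.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def f(x):\n",
"  # Pure: the output depends only on the input, and there are no side effects.\n",
"  return jnp.sin(x) ** 2\n",
"\n",
"# grad differentiates the scalar function, vmap maps it over a batch, and jit\n",
"# compiles the result; this nesting order is one valid choice among several.\n",
"df = jax.jit(jax.vmap(jax.grad(f)))\n",
"\n",
"xs = jnp.linspace(0., 1., 5)\n",
"print(df(xs))  # agrees with jnp.sin(2 * xs) up to float32 rounding\n",
"```\n",
"\n",
"The Oryx transformations introduced below are designed to compose with these JAX transformations in the same way."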
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8cloSFmOiJqn" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:08.065072Z", "iopub.status.busy": "2021-01-28T12:27:08.064521Z", "iopub.status.idle": "2021-01-28T12:27:13.020838Z", "shell.execute_reply": "2021-01-28T12:27:13.020323Z" }, "id": "cdhNEzj6iJCc" }, "outputs": [], "source": [ "!pip install -q oryx 1>/dev/null" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:13.027186Z", "iopub.status.busy": "2021-01-28T12:27:13.026557Z", "iopub.status.idle": "2021-01-28T12:27:20.486316Z", "shell.execute_reply": "2021-01-28T12:27:20.485801Z" }, "id": "Ve8yVrLbiOXv" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "sns.set(style='whitegrid')\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import random\n", "from jax import vmap\n", "from jax import jit\n", "from jax import grad\n", "\n", "import oryx\n", "\n", "tfd = oryx.distributions\n", "\n", "state = oryx.core.state\n", "ppl = oryx.core.ppl\n", "\n", "inverse = oryx.core.inverse\n", "ildj = oryx.core.ildj\n", "plant = oryx.core.plant\n", "reap = oryx.core.reap\n", "sow = oryx.core.sow\n", "unzip = oryx.core.unzip\n", "\n", "nn = oryx.experimental.nn\n", "mcmc = oryx.experimental.mcmc\n", "optimizers = oryx.experimental.optimizers" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AF05PEzd8QFI" }, "source": [ "## Layer 0: Base function transformations\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8WVTh54ZBJvq" }, "source": [ "At its base, Oryx defines several new function transformations. These transformations are implemented using JAX's tracing machinery and are interoperable with existing JAX transformations such as `jit`, `grad`, and `vmap`.\n", "\n", "### Automatic function inversion\n", "`oryx.core.inverse` and `oryx.core.ildj` are function transformations that programmatically invert a function and compute its inverse log-det Jacobian (ILDJ), respectively. These transformations are useful in probabilistic modeling for computing log-probabilities using the change-of-variables formula. However, they only support certain kinds of functions (see [the documentation](https://tensorflow.org/probability/oryx/api_docs/python/oryx/core/interpreters/inverse) for details)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:20.498061Z", "iopub.status.busy": "2021-01-28T12:27:20.497439Z", "iopub.status.idle": "2021-01-28T12:27:20.546051Z", "shell.execute_reply": "2021-01-28T12:27:20.546403Z" }, "id": "YxbReBYs5OpM" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.6931472\n", "-0.6931472\n" ] } ], "source": [ "def f(x):\n", "  return jnp.exp(x) + 2.\n", "print(inverse(f)(4.))  # ln(2)\n", "print(ildj(f)(4.))  # -ln(2)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-U08JAgs5w5p" }, "source": [ "### Harvest\n", "`oryx.core.harvest` enables tagging values in functions, collecting them (\"reaping\"), and injecting values in their place (\"planting\"). We tag values using the `sow` function." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:20.551809Z", "iopub.status.busy": "2021-01-28T12:27:20.551143Z", "iopub.status.idle": "2021-01-28T12:27:20.573068Z", "shell.execute_reply": "2021-01-28T12:27:20.572649Z" }, "id": "pFJNr4SR5_vl" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reap: {'y': DeviceArray(2., dtype=float32)}\n", "Plant: 25.0\n" ] } ], "source": [ "def f(x):\n", "  y = sow(x + 1., name='y', tag='intermediate')\n", "  return y ** 2\n", "print('Reap:', reap(f, tag='intermediate')(1.))  # Pulls out 'y'\n", "print('Plant:', plant(f, tag='intermediate')(dict(y=5.), 1.))  # Injects 5. for 'y'" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ffR6Emmm5OVI" }, "source": [ "### Unzip\n", "`oryx.core.unzip` splits a function in two along a set of values tagged as intermediates, returning the functions `init_f` and `apply_f`. `init_f` takes in a key argument and returns the intermediates. `apply_f` takes the intermediates as its first argument, followed by the function's remaining inputs, and returns the original function's output." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:20.578079Z", "iopub.status.busy": "2021-01-28T12:27:20.577477Z", "iopub.status.idle": "2021-01-28T12:27:20.639802Z", "shell.execute_reply": "2021-01-28T12:27:20.639353Z" }, "id": "ojFVr_ZKm0UX" }, "outputs": [], "source": [ "def f(key, x):\n", "  w = sow(random.normal(key), tag='variable', name='w')\n", "  return w * x\n", "init_f, apply_f = unzip(f, tag='variable')(random.PRNGKey(0), 1.)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jUJ5isbLjGy8" }, "source": [ "The `init_f` function runs `f` but only returns its variables."
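, "\n",
"\n",
"For intuition, the two halves correspond roughly to the hand-written functions in the plain-JAX sketch below. The names `f_untagged`, `manual_init`, and `manual_apply` are hypothetical and used only for illustration; `unzip` derives its real `init_f` and `apply_f` automatically from the `sow` tag.\n",
"\n",
"```python\n",
"from jax import random\n",
"\n",
"def f_untagged(key, x):\n",
"  # The same computation as `f` above, without the `sow` tag.\n",
"  w = random.normal(key)\n",
"  return w * x\n",
"\n",
"def manual_init(key):\n",
"  # Init half: only the part of f that produces the tagged value 'w'.\n",
"  return {'w': random.normal(key)}\n",
"\n",
"def manual_apply(params, x):\n",
"  # Apply half: the rest of f, with 'w' supplied by the caller.\n",
"  return params['w'] * x\n",
"\n",
"key = random.PRNGKey(0)\n",
"# Running the two halves back to back reproduces the original function.\n",
"assert manual_apply(manual_init(key), 1.) == f_untagged(key, 1.)\n",
"print(manual_apply({'w': 2.}, 2.))  # 4.0\n",
"```"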
] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:20.644016Z", "iopub.status.busy": "2021-01-28T12:27:20.643462Z", "iopub.status.idle": "2021-01-28T12:27:20.727518Z", "shell.execute_reply": "2021-01-28T12:27:20.727923Z" }, "id": "26VUK0nTjLcO" }, "outputs": [ { "data": { "text/plain": [ "{'w': DeviceArray(-0.20584226, dtype=float32)}" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "init_f(random.PRNGKey(0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0KWemKR2jOn6" }, "source": [ "`apply_f` takes a set of variables as its first input and executes `f` with those variables." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:20.732354Z", "iopub.status.busy": "2021-01-28T12:27:20.731758Z", "iopub.status.idle": "2021-01-28T12:27:20.740731Z", "shell.execute_reply": "2021-01-28T12:27:20.741076Z" }, "id": "SpKFfQZqiDAR" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(4., dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "apply_f(dict(w=2.), 2.)  # Runs f with `w = 2`." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Q0EtM2bj64fc" }, "source": [ "## Layer 1: Higher-level transformations" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DexZ_6Ds69J4" }, "source": [ "Oryx builds on the low-level `inverse`, `harvest`, and `unzip` function transformations to offer several higher-level transformations for writing stateful computations and for probabilistic programming."
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "3zEucvAN7WJX" }, "source": [ "### Stateful functions (`core.state`)\n", "We're often interested in expressing stateful computations where we initialize a set of parameters and express a computation in terms of those parameters. In `oryx.core.state`, Oryx provides an `init` transformation that converts a function into one that initializes a `Module`, a container for state.\n", "\n", "Oryx `Module`s resemble PyTorch and TensorFlow modules (`torch.nn.Module`, `tf.Module`), except that they are immutable." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:20.747579Z", "iopub.status.busy": "2021-01-28T12:27:20.746983Z", "iopub.status.idle": "2021-01-28T12:27:21.173314Z", "shell.execute_reply": "2021-01-28T12:27:21.173691Z" }, "id": "cmV2jLSr62Le" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "layer: FunctionModule(dict_keys(['w', 'b']))\n", "layer.w: [[-2.6105583 0.03385283 1.0863334 -1.4802988 0.48895672]\n", " [ 1.062516 0.5417484 0.0170228 0.2722685 0.30522448]]\n", "layer.b: [0.59902626 0.2172144 2.4202902 0.03266738 1.2164948 ]\n" ] } ], "source": [ "def make_dense(dim_out):\n", "  def forward(x, init_key=None):\n", "    w_key, b_key = random.split(init_key)\n", "    dim_in = x.shape[0]\n", "    w = state.variable(random.normal(w_key, (dim_in, dim_out)), name='w')\n", "    # Use b_key (not w_key) so the bias is initialized independently of w.\n", "    b = state.variable(random.normal(b_key, (dim_out,)), name='b')\n", "    return jnp.dot(x, w) + b\n", "  return forward\n", "\n", "layer = state.init(make_dense(5))(random.PRNGKey(0), jnp.zeros(2))\n", "print('layer:', layer)\n", "print('layer.w:', layer.w)\n", "print('layer.b:', layer.b)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dM02YafPiyVR" }, "source": [ "`Module`s are registered as JAX pytrees and can be used as inputs to JAX-transformed functions. Oryx provides a convenient `call` function that executes a `Module`." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:21.179283Z", "iopub.status.busy": "2021-01-28T12:27:21.178736Z", "iopub.status.idle": "2021-01-28T12:27:21.236364Z", "shell.execute_reply": "2021-01-28T12:27:21.235942Z" }, "id": "DRYp96JFizoU" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[-0.94901603, 0.7928156 , 3.5236464 , -1.1753628 ,\n", " 2.010676 ],\n", " [-0.94901603, 0.7928156 , 3.5236464 , -1.1753628 ,\n", " 2.010676 ],\n", " [-0.94901603, 0.7928156 , 3.5236464 , -1.1753628 ,\n", " 2.010676 ],\n", " [-0.94901603, 0.7928156 , 3.5236464 , -1.1753628 ,\n", " 2.010676 ],\n", " [-0.94901603, 0.7928156 , 3.5236464 , -1.1753628 ,\n", " 2.010676 ]], dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vmap(state.call, in_axes=(None, 0))(layer, jnp.ones((5, 2)))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "p_ZPPibD-NI4" }, "source": [ "The `state` API also enables writing stateful updates (like running averages) using the `assign` function. The resulting `Module` has an `update` method with the same input signature as its `__call__`, but instead of returning an output it returns a new copy of the `Module` with updated state."
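, "\n",
"\n",
"The copy-on-update behavior can be mimicked with a frozen Python dataclass. This is only an analogy, not how Oryx implements `Module`: the hypothetical `Counter` class below hard-codes the increment that the `counter` function in the next cell expresses with `state.assign`.\n",
"\n",
"```python\n",
"import dataclasses\n",
"\n",
"@dataclasses.dataclass(frozen=True)\n",
"class Counter:\n",
"  # Hypothetical immutable container holding one piece of state.\n",
"  count: float = 0.\n",
"\n",
"  def __call__(self, x):\n",
"    # Uses the incremented count, mirroring `counter`, but leaves self unchanged.\n",
"    return x + (self.count + 1.)\n",
"\n",
"  def update(self, x):\n",
"    # Same signature as __call__; returns a new Counter with advanced state.\n",
"    return dataclasses.replace(self, count=self.count + 1.)\n",
"\n",
"c0 = Counter()\n",
"c1 = c0.update(0.)\n",
"print(c0.count, c1.count)  # 0.0 1.0 (c0 is unchanged)\n",
"print(c1(1.))  # 3.0\n",
"```"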
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:21.242080Z", "iopub.status.busy": "2021-01-28T12:27:21.241319Z", "iopub.status.idle": "2021-01-28T12:27:21.247529Z", "shell.execute_reply": "2021-01-28T12:27:21.247098Z" }, "id": "fXnL3ZvD-UKx" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0\n", "1.0\n", "3.0\n" ] } ], "source": [ "def counter(x, init_key=None):\n", "  count = state.variable(0., key=init_key, name='count')\n", "  count = state.assign(count + 1., name='count')\n", "  return x + count\n", "layer = state.init(counter)(random.PRNGKey(0), 0.)\n", "print(layer.count)\n", "updated_layer = layer.update(0.)\n", "print(updated_layer.count)  # Count has advanced!\n", "print(updated_layer.call(1.))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VO_VdtAA70EN" }, "source": [ "\n", "### Probabilistic programming" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-bYaYxDA-5yz" }, "source": [ "In `oryx.core.ppl`, Oryx provides a set of tools built on top of `harvest` and `inverse` that aim to make writing and transforming probabilistic programs intuitive and easy.\n", "\n", "In Oryx, a probabilistic program is a JAX function that takes a source of randomness as its first argument and returns a sample from a distribution, i.e., `f :: Key -> Sample`. To write these programs, Oryx wraps [TensorFlow Probability](https://www.tensorflow.org/probability) distributions and provides a simple function, `random_variable`, that converts a distribution into a probabilistic program."
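, "\n",
"\n",
"Stripped of the distribution wrapper, a program of type `Key -> Sample` is just a pure function of a PRNG key. The plain-JAX sketch below uses `jax.random.normal` directly instead of a TFP distribution (the name `normal_program` is hypothetical) to show the two properties this relies on: the same key always yields the same sample, and mapping over split keys yields independent draws.\n",
"\n",
"```python\n",
"import jax\n",
"from jax import random\n",
"\n",
"def normal_program(key):\n",
"  # Key -> Sample: a standard normal draw, fully determined by the key.\n",
"  return random.normal(key)\n",
"\n",
"key = random.PRNGKey(0)\n",
"same = bool(normal_program(key) == normal_program(key))  # True: deterministic in the key\n",
"samples = jax.vmap(normal_program)(random.split(key, 10000))\n",
"print(samples.mean(), samples.std())  # close to 0 and 1\n",
"```"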
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:21.252219Z", "iopub.status.busy": "2021-01-28T12:27:21.251576Z", "iopub.status.idle": "2021-01-28T12:27:21.377046Z", "shell.execute_reply": "2021-01-28T12:27:21.377442Z" }, "id": "fh8AFQq771VJ" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(-0.20584235, dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def sample(key):\n", "  return ppl.random_variable(tfd.Normal(0., 1.))(key)\n", "sample(random.PRNGKey(0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "JWnPjFxx_i5I" }, "source": [ "What can we do with probabilistic programs? The simplest thing is to take a probabilistic program (i.e., a sampling function) and convert it into one that computes the log-density of a sample." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:21.381750Z", "iopub.status.busy": "2021-01-28T12:27:21.381168Z", "iopub.status.idle": "2021-01-28T12:27:21.409106Z", "shell.execute_reply": "2021-01-28T12:27:21.409475Z" }, "id": "h6U4_pAp_huX" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(-1.4189385, dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ppl.log_prob(sample)(1.)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "51yfR5Sm2ZuD" }, "source": [ "The new log-probability function is compatible with other JAX transformations like `vmap` and `grad`."
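, "\n",
"\n",
"Why should the gradient computed next equal `-x` elementwise? For a standard normal, the log-density is `-x**2/2 - log(2*pi)/2`, whose derivative is `-x`. The plain-JAX sketch below checks this with `jax.scipy.stats.norm.logpdf` standing in for `ppl.log_prob(sample)`; it is a sanity check under that substitution, not how Oryx computes log-densities.\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax.scipy.stats import norm\n",
"\n",
"def log_prob(x):\n",
"  # Standard normal log-density, -x**2/2 - log(2*pi)/2, standing in for\n",
"  # ppl.log_prob(sample) from the cells above.\n",
"  return norm.logpdf(x, 0., 1.)\n",
"\n",
"print(log_prob(1.))  # about -1.4189385, as computed above\n",
"xs = jnp.arange(10.)\n",
"# d/dx log_prob(x) = -x, so the gradient of the summed batch is -xs.\n",
"g = jax.grad(lambda s: jax.vmap(log_prob)(s).sum())(xs)\n",
"print(g)\n",
"```"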
] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:21.414180Z", "iopub.status.busy": "2021-01-28T12:27:21.413593Z", "iopub.status.idle": "2021-01-28T12:27:21.522419Z", "shell.execute_reply": "2021-01-28T12:27:21.521984Z" }, "id": "je3wggIi2Ytm" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([-0., -1., -2., -3., -4., -5., -6., -7., -8., -9.], dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grad(lambda s: vmap(ppl.log_prob(sample))(s).sum())(jnp.arange(10.))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "wEqAS9AfAPCh" }, "source": [ "Using the `ildj` transformation, `log_prob` can also handle programs that invertibly transform their samples. In the example below, the program returns `exp(x / 2) + 2` for a standard normal `x`; the top panel shows a histogram of 1000 samples, and the bottom panel plots the density obtained by exponentiating `ppl.log_prob(sample)`." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:21.540544Z", "iopub.status.busy": "2021-01-28T12:27:21.539819Z", "iopub.status.idle": "2021-01-28T12:27:22.537555Z", "shell.execute_reply": "2021-01-28T12:27:22.537972Z" }, "id": "2SGe1YZ5AUP1" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX8AAAD7CAYAAACCEpQdAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAArPklEQVR4nO3deXxTZb4/8E+SJum+QtuwtVBpCRQEiiwqcsFRuNpOuTpemI6jPwUdxYVRGekM2pbFpTiXC7K4X0fmOjqDXkAKKjBFWUaBtoDWFgqlUGjThbSFrmmbPL8/SmNL9zTNyfJ5v159NTlLzvfp8snJc855jkwIIUBERC5FLnUBRERkewx/IiIXxPAnInJBDH8iIhfE8CcickFuUhfQE5PJhNraWiiVSshkMqnLISJyCEIINDU1wcvLC3J5x/18uw//2tpa5OXlSV0GEZFDioyMhI+PT4fpdh/+SqUSQEsDVCqVeXp2djaio6OlKsuqnKktgHO1h22xT87UFmBg2tPY2Ii8vDxzht7I7sO/tatHpVJBrVa3m3fjc0fmTG0BnKs9bIt9cqa2AAPXnq66y3nAl4jIBTH8iYhcEMPfRhqbjF3OG6MdZ8NKiIgcoM/fWaiUCsS9sLPTebv+K97G1RCRq+OePxGRC2L4ExG5IIY/EZELYvhboKuDt90d1CUisic84GuBrg7e8sAtETkK7vkTEbkghj8RkQti+BMRuSCGPxGRC2L4ExG5IIY/EZELYvgTEbkghj8RkQti+BMRuSCGvxVZOrxDd+txyAgiGgh9Ht4hNTUVX3/9NYqKirBr1y5ERkYCAAoKCpCYmIiqqir4+/sjNTUV4eHhPc5zJpaO2c+x/onI1vq853/nnXfi448/xtChQ9tNT05ORkJCAr7++mskJCQgKSmpV/OIiMj2+hz+U6ZMgUajaTdNr9cjJycHsbGxAIDY2Fjk5OSgoqKi23lERCQNq4zqqdPpEBISAoVCAQBQKBQIDg6GTqeDEKLLeYGBgdbYPBER9ZHDDOmcnZ3dYVpmZqYElQAxMTE23Z5U7ewPR6y5K2yLfXKmtgC2b49Vwl+j0aC0tBRGoxEKhQJGoxFlZWXQaDQQQnQ5ry+io6OhVqvNzzMzM20ewlJxtHY60++GbbFPztQWYGDaYzAYOt1pbmWVUz2DgoKg1WqRlpYGAEhLS4NWq0VgYGC384iISBp93vNfs2YN9u7diytXruCRRx6Bv78/du/ejZSUFCQmJmLLli3w9fVFamqqeZ3u5hERke31OfxfeuklvPTSSx2mR0REYNu2bZ2u0908IiKyPV7hS0Tkghj+REQuiOHfBY6pQ0TOzGHO87c1jrdDRM6Me/5ERC6I4U9E5IIY/nauq2MPPCZBRP3BPn8719WxBx53IKL+4J4/EZELYvgTEbkghj8RkQti+BMRuSCGPxGRC2L4ExG5IIY/EZELYvgTEbkghj8RkQti+BMRuSCGPxGRC2L4ExG5IIY/EZELYvgTEbkghr+D6m48f471T0Q94Xj+Dor3GCai/uCePxGRC2L4ExG5IIY/EZELYvg7IR4MJqKeWP2A75w5c6BSqaBWqwEAy5Ytw8yZM3Hy5EkkJSXBYDBg6NCheOONNxAUFGTtzRN4MJiIejYgZ/u8+eabiIyMND83mUz4wx/+gNdeew1TpkzBli1b8Oc//xmvvfbaQGyeiIh6YJNun+zsbKjVakyZMgUAsHDhQnz11Ve22DQREXViQPb8ly1bBiEEYmJi8Pzzz0On02HIkCHm+YGBgTCZTKiqqoK/v3+vXjM7O7vDtMzMTGuV3EFMTMyAvbbUBvLnZstt2ArbYp+cqS2A7dtj9fD/+OOPodFo0NjYiFdeeQWrVq3CXXfd1e/XjY6ONh9HAFp+UM4c0ANpoH9uzvS7YVvskzO1BRiY9hgMhk53mltZvdtHo9EAAFQqFRISEpCVlQWNRoPi4mLzMhUVFZDL5b3e6yciIuuyavjX1dWhuro
aACCEwJ49e6DVahEdHY2GhgZkZGQAAD799FPMmzfPmpsmIqI+sGq3j16vxzPPPAOj0QiTyYSIiAgkJydDLpdj7dq1SE5ObneqJxERScOq4T98+HDs2LGj03mTJ0/Grl27rLk5IiKykEtf4curXYnIVbn0kM68EpaIXJVL7/kTEbkqhr+L4aBvRAS4eLePK2JXFxEB3PMnInJJDH8iIhfE8CcickEMfyIiF8TwJ7OuzvbhWUBEzodn+5BZV2cCff56bJfrGJqMUCsV7aa1Dk3b2GSE6oZ5RGQfGP7Uo55OD+Wpo0SOh90+REQuiOFPROSCGP5ERC6I4U9E5IIY/kRELojhT0Tkghj+REQuiOFPROSCGP5ERC6I4U9E5IIY/jRgeMtIIvvl9GP7cHAx6XQ3JlBfB4sD+LsksianD3/es9Y+WTJYHH9fRNbDbh8iIhfE8CenwOMLRH1js26fgoICJCYmoqqqCv7+/khNTUV4eLitNk9OjscXiPrGZuGfnJyMhIQExMfHY+fOnUhKSsLWrVtttXlyYTy+QNSRTbp99Ho9cnJyEBvbsgcWGxuLnJwcVFRU2GLz5CRs2X3Tuq3WW1K2ZbCwi6mred29nqXbIuqJTfb8dTodQkJCoFC0fIxWKBQIDg6GTqdDYGBgt+sKIQAAjY2NHeYZDIZebd/fq/OP7waDwarzrP16rrit7uYJUzN+m/Rlp+u8v+Ium26rq3lvLZ8DYeq6u6iz9bp7vZ7mGQzNXW6rM739n3EEUralqdkIpVvH33NX03syKmJ0l+2x9DVbM7M1Q28kE13NsaLs7GwsX74cu3fvNk+755578MYbb2DcuHHdrltdXY28vLyBLpGIyClFRkbCx8enw3Sb7PlrNBqUlpbCaDRCoVDAaDSirKwMGo2mx3W9vLwQGRkJpVIJmUxmg2qJiByfEAJNTU3w8vLqdL5Nwj8oKAharRZpaWmIj49HWloatFptj10+ACCXyzt91yIiou65u7t3Oc8m3T4AkJ+fj8TERFy7dg2+vr5ITU3FqFGjbLFpIiK6gc3Cn4iI7Aev8CUickEMfyIiF8TwJyJyQQx/IiIXxPAnInJBDnUzl8rKSrz44osoLCyESqVCWFgYVq1a1avrBezVkiVLcPnyZcjlcnh6euLll1+GVquVuqx+2bRpEzZu3Ihdu3YhMjJS6nIsMmfOHKhUKqjVagDAsmXLMHPmTImrsozBYMCrr76K7777Dmq1GhMnTsTq1aulLssily9fxlNPPWV+Xl1djZqaGhw7dkzCqix34MABbNiwAUIICCHw9NNP4+6777bNxoUDqaysFN9//735+euvvy7++Mc/SlhR/127ds38eN++fWL+/PkSVtN/2dnZYtGiRWL27NnizJkzUpdjMUevv63Vq1eLV155RZhMJiGEEOXl5RJXZD1r1qwRK1eulLoMi5hMJjFlyhTz31lubq6YOHGiMBqNNtm+Q3X7+Pv7Y9q0aebnEydORHFxsYQV9V/bq5dramocegiLxsZGrFq1CikpKVKXQtfV1tZix44dWLp0qflva9CgQRJXZR2NjY3YtWsX7r//fqlLsZhcLkd1dTWAlk8xwcHBkMttE8sO1e3TlslkwieffII5c+ZIXUq/rVixAkeOHIEQAu+//77U5Vhsw4YN+OUvf4lhw4ZJXYpVLFu2DEIIxMTE4Pnnn4evr6/UJfXZpUuX4O/vj02bNuHo0aPw8vLC0qVLMWXKFKlL67f09HSEhIT0ODikvZLJZFi/fj2WLFkCT09P1NbW4t1337VdATb5fDEAUlJSxJNPPmmzj0i2sH37drF48WKpy7BIVlaWeOihh8xdC47ebVJcXCyEEMJgMIikpCTxwgsvSFyRZbKzs0VkZKT44osvhBBCnDx5UkyfPl1UV1dLXFn/LV68WHz00UdSl2GxpqYm8fDDD4uMjAwhhBAZGRli1qxZoqamxibbd6hun1apqam4ePEi1q9fb7OPSLYwf/5
8HD16FJWVlVKX0mfHjx9Hfn4+7rzzTsyZMwclJSVYtGgRDh8+LHVpFmkdcValUiEhIQFZWVkSV2QZjUYDNzc3842Ubr75ZgQEBKCgoEDiyvqntLQUx48fR1xcnNSlWCw3NxdlZWXmGwbFxMTAw8MD+fn5Ntm+wyXnunXrkJ2djc2bN0OlUkldTr/U1tZCp9OZn6enp8PPzw/+/v7SFWWhxx9/HIcPH0Z6ejrS09MRGhqKDz74ALfffrvUpfVZXV2duR9WCIE9e/Y47BlYgYGBmDZtGo4cOQKg5V7aer0eYWFhElfWP9u3b8esWbMQEBAgdSkWCw0NRUlJCc6fPw+gZfBLvV6PESNG2GT7DtXnf/bsWbzzzjsIDw/HwoULAQDDhg3D5s2bJa7MMvX19Vi6dCnq6+shl8vh5+eHt99+26EP+joDvV6PZ555BkajESaTCREREUhOTpa6LIutXLkSf/rTn5Camgo3NzesXbvWIY9ftLV9+3asWLFC6jL6ZfDgwUhJSWl3MP7VV1+12c6f3Y/qaTKZUFtby5u5EBH1gWhzM5fOusftfs+/traWt3EkIrKQpLdx7A+lUgmgpQF97ePPzs5GdHT0QJRlc2yLfWJb7BPb0nIdRF5enjlDb2T34d/a1dP2Uvu+sGQde8W22Ce2xT6xLS266i53uLN9iIio/+x+z9/ZfPejDl8cyoeXuxL+Pmr4+6ihCfLCv00eBoWC78VEZBsMfxs6kHkJ6z/JQnCgJ2pUTThzsRJXaw0QAjh4sgjLfzsFnu6d988REVkTw99Gvv7+AjZ/dgrjIwbhpUenwUPd8qM3Gk3Yf7wQWz7/Acs3HUbSoukYHOAhcbVE5OzYz2ADXxzKx6ZtpzA5KhhJi6ebgx8AFAo55k4PR8ri6SirrMOyN7/FuUtV0hVLRC6B4T/A9h+7iPd2ZGPGeA1WPDIVaqWi0+UmRQVj7TMz4aaQI3HLYZy+WGHjSonIlTD8B5DRaMLf9p5BVFgAXvztFCjdOg/+VmGhvvjz0jvg563GG3/NQE19k40qJSJXw/AfQP/6UYfyynr8as5ouPXyTJ4AH3e8+GAM9FcbsPEfJ2Dno28QkYNi+A8QIQR2fHsOmkFeuGVsaJ/WjQoLxEP3aPGvH3T46rsLA1MgEbk0hv8AOX2hEnmFVYifOQoKed8HpJs/6yZMjgrGezuzUVB8dQAqJCJXxvAfIDsOnoO3hxJ33mLZ2NxyuQzP/XoyvD2UWPvXDDQ2m6xcIRG5Mob/ACjR1+L7H3WYNyMc7mrLL6Xw91Hjhd/EoKi8Bv88dc2KFRKRq2P4D4Bdh85DJpMh9vaR/X6tm0cPxrzp4TiWV4PzRez+ISLr6FX4FxQUYMGCBZg7dy4WLFiACxcudFjmxRdfRHx8vPlrzJgx+Oc//wkA2LhxI2bMmGGet3LlSqs2wp7U1Ddh37GLmDlpKIL8rHOl7kP3aOGhkuOtz0/BZOLZP0TUf73qk0hOTkZCQgLi4+Oxc+dOJCUlYevWre2WWbt2rfnx6dOn8fDDD2PmzJnmafPnz8fy5cutVLb92vv9RdQbjIi/I8Jqr+ntqcLdk/yw4/tK7DtWiLnTHfv+q0QkvR73/PV6PXJychAbGwsAiI2NRU5ODioqur4C9bPPPkNcXJzD32DdEodOXkZUWABuGuZv1de9eaQnxo0Kwke7f8LVGoNVX5uIXE+P4a/T6RASEgKFouXqVIVCgeDgYOh0uk6Xb2xsxK5du3D//fe3m757927ExcXh0UcfxYkTJ6xQuv2pqjbg3OWruEUbYvXXlslkePK+CahraMZHu3Os/vpE5FqsPqrn/v37MWTIEGi1WvO0hQsX4oknnoBSqcSRI0ewZMkS7NmzBwEBAb1+3ezsbIvqyczMtGg9S/xQUAcA8BAVA7LdK8VnMS3KC/uOFWK4Xz1GDHbcOxXZ8vcy0NgW+8S2dK/H8NdoNCgtLYXRaIRCoYDRaERZWRk0Gk2ny3/++ec
d9voHDx5sfnzbbbdBo9Hg7NmzmDp1aq8LjY6O7vOtzDIzMxETE9OndfrjmzOZ8PWqQewvpkNuwYVd3Wlty9joZuSl/hMHcxvxX3fPsPp2bMHWv5eBxLbYJ7YFMBgM3e4099jtExQUBK1Wi7S0NABAWloatFotAgMDOyxbUlKCzMxMxMXFtZteWlpqfpybm4uioiKMHNn/0yDtickkcPJMOSZFBg9oIHuo3fDwvWNx7vJVHMi8NGDbISLn1qtun5SUFCQmJmLLli3w9fVFamoqAOCxxx7Ds88+i/HjxwMAtm/fjtmzZ8PPz6/d+uvWrcNPP/0EuVwOpVKJtWvXtvs04AzOF19FVY0Bk8cED/i27pg0DGmHC7B1Tw5unTCk3f0BiIh6o1epERERgW3btnWY/t5777V7/uSTT3a6fuubhTM7caYMADApauDf1ORyGRbHR+MPGw/h8/SzePDftT2vRETUBq/wtZLM02UYNdQPAT7uNtnemPBA3DFpKLZ/cw5llXU22SYROQ+GvxXUNTTh9IUKTI4a+C6fth6+dywA8NRPIuozhr8VnDp7BUaTsEl/f1vBAZ74j9k34eCJIuQW8LaPRNR7DH8ryDpTBg+1G8aEdTwDaqDdP3s0An3VeP+LHznuDxH1GsO/n4QQyDpdigk3DYLSzfY/Tg+1Gx66ZyzyCqtw8MRlm2+fiBwTw7+fisprUFZZjxgbd/m0NTtmOG4a5oe/7M5Bg6FZsjqIyHEw/Psp63TrKZ7ShX/LqZ/job/agO3fnJOsDiJyHAz/fjqRV46hg70QGuQlaR3jRgXhtpuH4PNvzuFKVb2ktRCR/WP494PJJJBboEd0xCCpSwEA/L97x8JoFNi6h6d+ElH3GP79UFReg9qGZowJ6/3opAMpNMgL82dF4EDmZeQVVkpdDhHZMYZ/P5y52BKwURKc4tmVB+4cDX8fNd7bwVM/iahrDP9+OFNYCS93Nwwd7C11KWae7ko8fI8Wpy9W4pssjvpJRJ1j+PfDmYsViBwRYHdj6s+ZMgJRIwLwYVoOauubpC6HiOwQw99C9YZmXNRds6sun1ZyuQy/u288rtYY8MneM1KXQ0R2iOFvoXOXqmASQJSdHOy90ejhAbh7Whh2HT6PiyXXpC6HiOwMw99Cpy+2DKQWOcI+wx8AfvvvWniq3fDu9h8hBA/+EtHPehX+BQUFWLBgAebOnYsFCxbgwoULHZbZuHEjZsyYgfj4eMTHx2PlypXmefX19fj973+Pu+66C/PmzcOBAwes1gCp5BVWYsggL/h6qaQupUt+3mo8+O9a/HDuCg6fKpa6HCKyI726k1dycjISEhIQHx+PnTt3IikpCVu3bu2w3Pz587F8+fIO0z/44AN4e3tj3759uHDhAn7zm99g79698PKS9qpYSwkhcOZiJSZG2v+tKOfNCMfe7y/i/Z3ZiBkTDE93pdQlEZEd6HHPX6/XIycnB7GxsQCA2NhY5OTkoKKi9+PHf/nll1iwYAEAIDw8HNHR0Th48KCFJUuvvLIeldUGuzzYeyOFXIYlv5qAyuoG3vSFiMx6DH+dToeQkBAoFAoAgEKhQHBwMHQ6XYdld+/ejbi4ODz66KM4ceKEeXpxcTGGDh1qfq7RaFBSUmKN+iXx88Vd9tvf31ZUWCBibx+FL7+7wJu+EBGAXnb79MbChQvxxBNPQKlU4siRI1iyZAn27NmDgADrBGR2drZF62VmZlpl+20dzKyCm0KGCt05ZJba7hz//rQlOtSEbz0UeOOv3+F380LgppD22oSB+L1IhW2xT2xL93oMf41Gg9LSUhiNRigUChiNRpSVlUGj0bRbbvDgn/u/b7vtNmg0Gpw9exZTp07FkCFDUFRUhMDAlm4SnU6HadOm9anQ6OhoqNXqPq2TmZmJmJiYPq3TG58cOYiosEBMvWWK1V+7K9Zoi9K3BKs+OIrzVT749d1RVqqs7wbq9yIFtsU+sS2AwWDodqe5x26foKAgaLV
apKWlAQDS0tKg1WrNQd6qtLTU/Dg3NxdFRUUYOXIkAGDevHn4+9//DgC4cOECfvzxR8ycObPPjbEHTc1GnC+6iig7PsWzK7eMDcXMiUPxj/15uFRaLXU5RCShXnX7pKSkIDExEVu2bIGvry9SU1MBAI899hieffZZjB8/HuvWrcNPP/0EuVwOpVKJtWvXmj8NLFq0CImJibjrrrsgl8uxatUqeHvbz3g4fVFQfA1NzSaH6e+/0WPzo3HiTBk2/uMkXnvqdijsbGgKIrKNXoV/REQEtm3b1mH6e++9Z37c+obQGU9PT7z55psWlGd/Wi/uctTwD/Bxx+L4aKz/9AR2fnsO980eLXVJRCQBXuHbR2cuVmKQnzuC/DykLsVic6YMx4zxGvz1y1wUFF+VuhwikgDDv4/OXKxEpIPu9beSyWR46lc3w8dThT9/nInGJqPUJRGRjTH8++BqjQGlFXWIGmH/F3f1xM9bjWcXTEJhSTW27smVuhwisjGGfx+cvVQFAIgc4S9pHdYyRRuCe24Nx86D+TiVVy51OURkQwz/PsgrrIRcBkQM85e6FKt5JG4chg72xn9/moWrNQapyyEiG2H490FeYSVGhPrCQ221C6Ml565yw7IHY3C1phF//jgTRt73l8glMPx7SQiBvMIqjB7uL3UpVnfTMH88cd8EnMwrxydfn5a6HCKyAYZ/L5VW1KG6rhGjHfDK3t6YOz0Md00dgb/vz8Oxnxx30D0i6h2Gfy/lFbaM5BnphHv+rX533wREDPPDur9lovhKjdTlENEAYvj3Ul5hFVRucoRpfKUuZcColQokPnQLZDIZXvvLcTQYmqUuiYgGCMO/l/IKKxExzB9uCuf+kYUGeWHZgzEoLLmGtf+bAaPRJHVJRDQAnDvJrMRoNCG/6CpGO8n5/T2JGROCx/9jAo7nlOId3vydyCk5zzmLA6iwtBqNTUZEDnfOg72dufe2kSivrMPnB85hcIAHHrgzUuqSiMiKGP69YD7Y66Rn+nTloXvGoryyHlv35GJwgCf+bfIwqUsiIith+PdCXmEVfDyVCA3ylLoUm5LLZfj9ryehoroBGz7Ngq+nCpPHBEtdFhFZAfv8eyGvsBKjRwRAJnO9G58o3RRY8f+mYkSIL9Z8eBRZp8ukLomIrKBX4V9QUIAFCxZg7ty5WLBgAS5cuNBhmc2bN+Pee+9FXFwc7rvvPhw6dMg8LzExEXfccQfi4+MRHx+Pt956y2oNGGgNhmYUllxzqf7+G3l7qrD6iVsxPNiHbwBETqJX3T7JyclISEhAfHw8du7ciaSkJGzdurXdMhMmTMCjjz4KDw8PnD59Gg8++CAOHz4Md3d3AMDjjz+OBx980PotGGD5RVdhEnCZM3264uvV8gbw8tv/wpoPj+KlR6axC4jIgfW456/X65GTk4PY2FgAQGxsLHJyclBRUdFuuZkzZ8LDo+XuVlFRURBCoKqqyvoV21jrwV5nHNOnr1rfAFo/AXAYCCLH1WP463Q6hISEQKFQAAAUCgWCg4Oh0+m6XGfHjh0YMWIEQkNDzdM+/PBDxMXFYcmSJcjPz7dC6baRV1iJ4AAPBPi4S12KXWh9AwjT+OKVD49i95ECqUsiIgtY/WyfY8eOYcOGDfif//kf87TnnnsOgwcPhlwux44dO7B48WLs37/f/IbSG9nZ2RbVk5mZadF65u2eK8WQIFW/X8ca7KGGVv85wxOfHW7A2//3A37IPY9fTPSDvA8HxO2pLf3FttgntqUHogdXrlwRMTExorm5WQghRHNzs4iJiRF6vb7DsllZWeKOO+4Q2dnZ3b7m1KlTxeXLl3vatBBCiIaGBpGRkSEaGhp6tXxbGRkZfV6nrbKKOhH7/A6x/Zuz/Xoda+hvWwZCc7NRbP7spIh9fod4/aNjwtDY3Kv17LEtlmJb7BPb0nN29tjtExQUBK1Wi7S0NABAWloatFotAgPb38f2hx9+wHPPPYc333wT48aNazevtLT
U/PjQoUOQy+UICQmxxnvXgDqZ13JWy8RIHtjsjEIhx5P3TcAjsWNx+FQxlm86hBJ9rdRlEVEv9KrbJyUlBYmJidiyZQt8fX2RmpoKAHjsscfw7LPPYvz48Vi5ciUaGhqQlJRkXm/t2rWIiorC8uXLodfrIZPJ4O3tjbfeegtubvZ/fdmJvHIE+KgRFuojdSl2SyaT4b7Zo6/fCvIEfr/uG/z+15MxPVojdWlE1I1eJXBERAS2bdvWYfp7771nfvz55593uf5f/vKXvlcmMZNJ4GReOW4ZG+KSF3f11bRoDdY/54vUv2bglQ+PYf6sCDx871inHwWVyFHxP7ML54uuorquEZMiB0tdisMIDfLC2qdvR+xtI7Hj23z84c2DuKC7JnVZRNQJhn8XTlzv77+Z4d8nSjcFfnffBPzx4VtQXlWP5/77G3zy9Wk0NfO+AET2xP473iVy4kw5Rg7x5fn9Frp1whCMGxWE93dm4297z+BfP+rwzH9OdLmRUYnsFff8O1FvaEbuBT0m8SyffvHzVuOF38Tg5UXTcK22ES9sOIh1f8vE1TreHpJIatzz78RP5/VoNgpMimKXjzVMHRuK6OVB+Cz9LHZ8m49DJwV0tadx3+yb4KHmnyCRFLjn34kTZ8qgcpNj7MggqUtxGp7uSjx0z1i8tfxORA11x6f7zuDxV/fj/w6cRV1Dk9TlEbkchn8nTuSVYdyoIKiUvR9+gnonJNATD9wehDeemYnwIb74MC0Hi1/Zh0/3nUFNPd8EiGyFn7lvcKWqHpdKa3DX1DCpS3FqY8IDsfp3t+LMxQr8Y/9ZfPzVafzfgbOYHTMc99w2EmGhvlKXSOTUGP43aB3SYVIUD/baQlRYIF5eNA3ni65i58F87DtWiD3/uoDoiCDcc+tITBsXyk9gRAOA4X+DE2c4pIMURg31w3O/noxH48Zh/7FC7PnuAtb+NQNe7m64dcIQzJo8DNERg6CQ82prImtg+LdhMgmcPFuOmDHBHNJBIn7eatw/ZzTm/9tN+OFsOb7JuozDp4qw71ghAn3dMS06FNPHaTD+piAo3fiJgMhSDP82jpwqxrXaRtyiDe15YRpQCrkMk6KCMSkqGE/ePwHHfyrFwZOXkZ5xCV/+6wI81G6YfH3+zaMHITTIS+qSiRwKw/86Q5MRH+7+CSOH+OLWm4dIXQ614a5yw8xJQzFz0lAYmoz44Ww5jv5UguM5JTjyQzGAlrOIbh49GONGBWJMeCA0QV789EbUDYb/dTu+OYfyyno8t3Ay+5XtmFqpwC1jQ3HL2FAIIXCptBqnzl7BqbPlOHKqCHuPXgTQcrvJMWGBGD3CH6OG+iFiqB8Cfd35hkB0HcMfgP5qPbaln8WM8RqMv2mQ1OVQL8lkMowI9cWIUF/EzRwFo0ngcmk1Tl+sQO6FCpy+UIljOT/fZN7PW4WRGj8MD/XB8BAfjAjxwdDB3vDzVvFNgVwOwx/A1j25MBoFHo0b1/PCZLcUchnCNL4I0/hi7vRwAEBdQxMKiq/hfNFVnC+6iosl17Dv6EU0NBrN63m6u0EzyAuaIC+EBnkhOMADgwM8zd85BAU5o179VRcUFCAxMRFVVVXw9/dHamoqwsPD2y1jNBqxZs0aHDp0CDKZDI8//jgeeOCBHudJLa+wEukZl/CrOaN50NAJeborMW5UEMaN+nmoDpNJ4MrVehSWVKO4vAa6K7Uo1tciv+gqvvtRB6NJ3PAabgjyc0egb8uXv487/L3V8PdRo1zXAL/gKvh6qeDjpYK7SsFPEeQQehX+ycnJSEhIQHx8PHbu3ImkpCRs3bq13TK7du1CYWEh9u7di6qqKsyfPx8zZszAsGHDup0npaZmI97b8SP8fdR44M7RktZCtiOXyxAc4IngAE9A2/5e0kaTQOW1BpRV1qG8sh7lVfWouNYA/dV66K82IPu8HlXVhnb3J/jfb741P1a6yeHtoYSXhxLeHkp4e6rg6e4GL3clPN3d4OHuBg+
1GzzVbnBXtzx2V7lBrVJArVSYH6uUcqjcFJDz+BMNkB7DX6/XIycnBx9++CEAIDY2FqtXr0ZFRUW7m7jv2bMHDzzwAORyOQIDA/GLX/wCX331FRYvXtztPFurrW9CRm4pjv5UgozcUtQbmrF0wUR4uittXgvZH4VchkH+Hhjk7wGM7HwZIQTqDc2oqjbg+4wfMGT4SFTXNuJabSOq6xpRU9+Emrom1NQ3orK6AUXlzahraEJdQ3Ofb2qjcpNDpWx5M1C6/fxd6SZv+VK0PHdzk8FNITd/KRQyKBVyKBRyuClkUMhbvsvlLY8VChkU8pYvuVwOhRwoLKxFtezy9WkyyGUyyGUtb5bm522+y2Ro91gma5nXOh3Xv7fOa5nU5jnQbhng5+myG6e1vBxaJ8p+fvjza12fIZPJ0GQUaGwyQtZ2Rtt1zJNl7Z63fU1n12P463Q6hISEQKFouaBGoVAgODgYOp2uXfjrdDoMGfLzKZIajQYlJSU9zhtIF8oM+K+Xv0RTsxECgDAJNBtNMAnA31uNOyYNxa0ThvBWjdQnMpkMnu5KeLorERasRkwfblbf1GxEXUMzGhqNaDA0o97QjIbG688bjTA0GmFobIahyYjGJhMam4xobDKiyWiCocmIpiYTGpuNaGo2oanZdP0NpRHNJhOam01oNrZ+CRiNJjSbBJqbTR26srr0XaaFPxU79Pciq71UhzeNG2e0ndTjIrJu57d9qlQq8OuZAYjpXZl94jBHsrKzs/u8TqC3GyaEq2G6/ocvkwFKhQwRGncMC1JBLjdC1FxCVtYla5c7IDIznecfk21pT3n9y6f1QZfksGQwXiEETAItXybR+XdxfTkTYBICQgACgMn08/qizWsJ89fPy4o2ywjRuu3W6TdMQ9vHLRPaTWtd9oZ5rfNb5sE8Q7QufcP7nLjhQZfP2/ysulu/i6edurGWDvN7WMFNIUOAt2JA/l96DH+NRoPS0lIYjUYoFAoYjUaUlZVBo9F0WK64uBgTJkwA0H5vv7t5vRUdHQ21Wt2ndTIzM5G4aE6f1rFXmZmZiIkZiPd/22Nb7BPbYp8sbYvBYOh2p7nHXYigoCBotVqkpaUBANLS0qDVatt1+QDAvHnzsG3bNphMJlRUVGD//v2YO3duj/OIiMj2etXtk5KSgsTERGzZsgW+vr5ITU0FADz22GN49tlnMX78eMTHx+PUqVO4++67AQBPPfUUhg8fDgDdzutJ60ewxsbGvrXsOoPBYNF69ohtsU9si31y9ba0ZuaN3VitZKKrOXaiuroaeXl5UpdBROSQIiMj4ePTcYh6uw9/k8mE2tpaKJVKlzj9iojIGoQQaGpqgpeXF+Tyjj38dh/+RERkfbyBOxGRC2L4ExG5IIY/EZELYvgTEbkghj8RkQti+BMRuSCGPxGRC3KYUT37orKyEi+++CIKCwuhUqkQFhaGVatWdRiPyFEsWbIEly9fhlwuh6enJ15++WVotVqpy7LYpk2bsHHjRuzatQuRkZFSl2OROXPmQKVSmQcbXLZsGWbOnClxVZYxGAx49dVX8d1330GtVmPixIlYvXq11GX12eXLl/HUU0+Zn1dXV6OmpgbHjh2TsCrLHThwABs2bLg+QqrA008/bR4ixyqEE6qsrBTff/+9+fnrr78u/vjHP0pYUf9cu3bN/Hjfvn1i/vz5ElbTP9nZ2WLRokVi9uzZ4syZM1KXYzFHr7+t1atXi1deeUWYTCYhhBDl5eUSV2Qda9asEStXrpS6DIuYTCYxZcoU899Ybm6umDhxojAajVbbhlN2+/j7+2PatGnm5xMnTkRxcbGEFfVP23E5ampqHHaYi8bGRqxatQopKSlSl0LX1dbWYseOHVi6dKn572rQoEESV9V/jY2N2LVrF+6//36pS7GYXC5HdXU1gJZPMcHBwZ0O02App+z2actkMuGTTz7BnDmOPa7/ihUrcOTIEQgh8P7770tdjkU2bNiAX/7yl5Lfu9lali1
bBiEEYmJi8Pzzz8PX11fqkvrs0qVL8Pf3x6ZNm3D06FF4eXlh6dKlmDJlitSl9Ut6ejpCQkIwbtw4qUuxiEwmw/r167FkyRJ4enqitrYW7777rnU3YrXPEHYqJSVFPPnkk1b9uCSl7du3i8WLF0tdRp9lZWWJhx56yNy14OjdJsXFxUIIIQwGg0hKShIvvPCCxBVZJjs7W0RGRoovvvhCCCHEyZMnxfTp00V1dbXElfXP4sWLxUcffSR1GRZramoSDz/8sMjIyBBCCJGRkSFmzZolampqrLYNp+z2aZWamoqLFy9i/fr1Vv24JKX58+fj6NGjqKyslLqUPjl+/Djy8/Nx5513Ys6cOSgpKcGiRYtw+PBhqUuzSOud7FQqFRISEpCVlSVxRZbRaDRwc3NDbGwsAODmm29GQEAACgoKJK7McqWlpTh+/Dji4uKkLsViubm5KCsrM9/BKyYmBh4eHsjPz7faNpwjETuxbt06ZGdnY/PmzVCpVFKXY7Ha2lrodDrz8/T0dPj5+cHf31+6oizw+OOP4/Dhw0hPT0d6ejpCQ0PxwQcf4Pbbb5e6tD6rq6sz98UKIbBnzx6HPfsqMDAQ06ZNw5EjRwAABQUF0Ov1CAsLk7gyy23fvh2zZs1CQECA1KVYLDQ0FCUlJTh//jwAID8/H3q9HiNGjLDaNpyyz//s2bN45513EB4ejoULFwIAhg0bhs2bN0tcWd/V19dj6dKlqK+vh1wuh5+fH95++22HPejrDPR6PZ555hkYjUaYTCZEREQgOTlZ6rIstnLlSvzpT39Camoq3NzcsHbtWoc8ftFq+/btWLFihdRl9MvgwYORkpLS7kD8q6++atWdPo7nT0Tkgpy224eIiLrG8CcickEMfyIiF8TwJyJyQQx/IiIXxPAnInJBDH8iIhfE8CcickH/HzTSu10kgOrAAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "def sample(key):\n", " x = ppl.random_variable(tfd.Normal(0., 1.))(key)\n", " return jnp.exp(x / 2.) + 2.\n", "_, ax = plt.subplots(2)\n", "ax[0].hist(jit(vmap(sample))(random.split(random.PRNGKey(0), 1000)),\n", " bins='auto')\n", "x = jnp.linspace(0, 8, 100)\n", "ax[1].plot(x, jnp.exp(jit(vmap(ppl.log_prob(sample)))(x)))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AEvnv1-__8jd" }, "source": [ "We can tag intermediate values in a probabilistic program with names and obtain joint sampling and joint log-prob functions." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:22.543343Z", "iopub.status.busy": "2021-01-28T12:27:22.542772Z", "iopub.status.idle": "2021-01-28T12:27:22.796378Z", "shell.execute_reply": "2021-01-28T12:27:22.795928Z" }, "id": "yDttqgL7_umZ" }, "outputs": [ { "data": { "text/plain": [ "{'x': DeviceArray(-1.1076484, dtype=float32),\n", " 'z': DeviceArray(0.14389044, dtype=float32)}" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def sample(key):\n", " z_key, x_key = random.split(key)\n", " z = ppl.random_variable(tfd.Normal(0., 1.), name='z')(z_key)\n", " x = ppl.random_variable(tfd.Normal(z, 1.), name='x')(x_key)\n", " return x\n", "ppl.joint_sample(sample)(random.PRNGKey(0))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Q45YW73E2uVK" }, "source": [ "Oryx also has a `joint_log_prob` function that composes `log_prob` with `joint_sample`." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:22.800806Z", "iopub.status.busy": "2021-01-28T12:27:22.800262Z", "iopub.status.idle": "2021-01-28T12:27:22.854941Z", "shell.execute_reply": "2021-01-28T12:27:22.854507Z" }, "id": "FjZIhP7n2uwm" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(-1.837877, dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ppl.joint_log_prob(sample)(dict(x=0., z=0.))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OP8boCwYA50n" }, "source": [ "To learn more, see the [documentation](https://tensorflow.org/probability/oryx/api_docs/python/oryx/core/ppl/transformations)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "eglTKzL6A72r" }, "source": [ "## Layer 2: Mini-libraries" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9LdSK3XzBMuV" }, "source": [ "Building further on top of the layers that handle state and probabilistic programming, Oryx provides experimental mini-libraries tailored for specific applications like deep learning and Bayesian inference." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iGXK3SHGBTqe" }, "source": [ "### Neural networks" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0l7OEJM2BYJu" }, "source": [ "In `oryx.experimental.nn`, Oryx provides a set of common neural network `Layer`s that fit neatly into the `state` API. These layers are built for single examples (not batches) but override batch behaviors to handle patterns like running averages in batch normalization. They also enable passing keyword arguments like `training=True/False` into modules.\n", "\n", "`Layer`s are initialized from a `Template` like `nn.Dense(200)` using `state.init`." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:22.860234Z", "iopub.status.busy": "2021-01-28T12:27:22.859675Z", "iopub.status.idle": "2021-01-28T12:27:23.155930Z", "shell.execute_reply": "2021-01-28T12:27:23.156300Z" }, "id": "a6c2IjijA7Sn" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dense(200) (50, 200) (200,)\n" ] } ], "source": [ "layer = state.init(nn.Dense(200))(random.PRNGKey(0), jnp.zeros(50))\n", "print(layer, layer.params.kernel.shape, layer.params.bias.shape)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1XKSMZyuiD6v" }, "source": [ "A `Layer` has a `call` method that runs its forward pass." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:23.162248Z", "iopub.status.busy": "2021-01-28T12:27:23.161684Z", "iopub.status.idle": "2021-01-28T12:27:23.213477Z", "shell.execute_reply": "2021-01-28T12:27:23.213821Z" }, "id": "z0n7l3DZiNre" }, "outputs": [ { "data": { "text/plain": [ "(200,)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "layer.call(jnp.ones(50)).shape" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "J73S0GXjCLQ2" }, "source": [ "Oryx also provides a `Serial` combinator." 
] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:23.219726Z", "iopub.status.busy": "2021-01-28T12:27:23.219161Z", "iopub.status.idle": "2021-01-28T12:27:25.071751Z", "shell.execute_reply": "2021-01-28T12:27:25.071299Z" }, "id": "xQhmJAHVB5iN" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([0.16362445, 0.21150257, 0.14715882, 0.10425295, 0.05952952,\n", " 0.07531884, 0.08368199, 0.0376978 , 0.0159679 , 0.10126514], dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mlp_template = nn.Serial([\n", " nn.Dense(200), nn.Relu(),\n", " nn.Dense(200), nn.Relu(),\n", " nn.Dense(10), nn.Softmax()\n", "])\n", "# OR\n", "mlp_template = (\n", " nn.Dense(200) >> nn.Relu()\n", " >> nn.Dense(200) >> nn.Relu()\n", " >> nn.Dense(10) >> nn.Softmax())\n", "mlp = state.init(mlp_template)(random.PRNGKey(0), jnp.ones(784))\n", "mlp(jnp.ones(784))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "g8h2nzyICpVd" }, "source": [ "We can interleave functions and combinators to create a flexible neural network \"meta language\"." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:25.078128Z", "iopub.status.busy": "2021-01-28T12:27:25.077473Z", "iopub.status.idle": "2021-01-28T12:27:28.033208Z", "shell.execute_reply": "2021-01-28T12:27:28.032720Z" }, "id": "NvLB8zxXChyr" }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([-0.03828401, 0.9046303 , 1.6083915 , -0.17005858,\n", " 3.889552 , 1.7427744 , -1.0567027 , 3.0192878 ,\n", " 0.28983995, 1.7103616 ], dtype=float32)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def resnet(template):\n", " def forward(x, init_key=None):\n", " layer = state.init(template, name='layer')(init_key, x)\n", " return x + layer(x)\n", " return forward\n", "\n", "big_resnet_template = nn.Serial([\n", " nn.Dense(50)\n", " >> resnet(nn.Dense(50) >> nn.Relu())\n", " >> resnet(nn.Dense(50) >> nn.Relu())\n", " >> nn.Dense(10)\n", "])\n", "network = state.init(big_resnet_template)(random.PRNGKey(0), jnp.ones(784))\n", "network(jnp.ones(784))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7-qBbDe_D8oV" }, "source": [ "### Optimizers" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "a3c3GW1LEGKm" }, "source": [ "In `oryx.experimental.optimizers`, Oryx provides a set of first-order optimizers built with the `state` API. Their design is based on JAX's [`optix` library](https://jax.readthedocs.io/en/latest/jax.experimental.optix.html), where optimizers maintain state across a sequence of gradient updates. Oryx's version manages that state through the `state` API." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:28.039726Z", "iopub.status.busy": "2021-01-28T12:27:28.039169Z", "iopub.status.idle": "2021-01-28T12:27:29.794202Z", "shell.execute_reply": "2021-01-28T12:27:29.794579Z" }, "id": "b7Gfm0d2EBC6" }, "outputs": [], "source": [ "network_key, opt_key = random.split(random.PRNGKey(0))\n", "def autoencoder_loss(network, x):\n", " return jnp.square(network.call(x) - x).mean()\n", "network = state.init(nn.Dense(200) >> nn.Relu() >> nn.Dense(2))(network_key, jnp.zeros(2))\n", "opt = state.init(optimizers.adam(1e-4))(opt_key, network, network)\n", "g = grad(autoencoder_loss)(network, jnp.zeros(2))\n", "\n", "g, opt = opt.call_and_update(network, g)\n", "network = optimizers.optix.apply_updates(network, g)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EGDs47TEFKXB" }, "source": [ "### Markov chain Monte Carlo" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "T7b6IdRwFP-k" }, "source": [ "In `oryx.experimental.mcmc`, Oryx provides a set of Markov chain Monte Carlo (MCMC) kernels. MCMC is an approach to approximate Bayesian inference in which we draw samples from a Markov chain whose stationary distribution is the posterior distribution of interest.\n", "\n", "Oryx's MCMC library builds on both the `state` and `ppl` APIs." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:29.799301Z", "iopub.status.busy": "2021-01-28T12:27:29.798719Z", "iopub.status.idle": "2021-01-28T12:27:29.800530Z", "shell.execute_reply": "2021-01-28T12:27:29.800865Z" }, "id": "wWTHfPWmGrAl" }, "outputs": [], "source": [ "def model(key):\n", " return jnp.exp(ppl.random_variable(tfd.MultivariateNormalDiag(\n", " jnp.zeros(2), jnp.ones(2)))(key))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hQB7rhQ5GmN8" }, "source": [ "#### Random walk Metropolis" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:29.806209Z", "iopub.status.busy": "2021-01-28T12:27:29.805465Z", "iopub.status.idle": "2021-01-28T12:27:30.737514Z", "shell.execute_reply": "2021-01-28T12:27:30.737923Z" }, "id": "O27O2oTJE1Nu" }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXQAAAD9CAYAAACsq4z3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAABWUElEQVR4nO29e3hc9X3n/zqXOXOVRhdblizZlm0QNthgcAghuGlwA6ThkkKzJWWXbrrJsn02bcM+6dMm2aaloUmWpskmW8gmbHp5mt9Ctt0CKXYaIAuksXGCY2JjgY18kWzLulm3kWbmzJzr74+jGc9Io/tIoxl9X88THI81M585OvP+fr6f7+ciua7rIhAIBIKyRy61AQKBQCAoDkLQBQKBoEIQgi4QCAQVghB0gUAgqBCEoAsEAkGFIARdIBAIKoRZBf2xxx5j7969XHXVVXR0dGQf7+zs5P777+eOO+7g/vvvp6urayntFAgEAsEsSLPlof/85z+nubmZf/tv/y3f+ta3aGtrA+C3fuu3+PVf/3U+/OEP8/3vf59/+qd/4u///u/n/MaO45BIJPD5fEiStLhPIRAIBKsE13UxTZNwOIws5/vk6mxPfte73jXlsaGhId5++23+9m//FoC77rqLRx99lOHhYerq6uZkVCKRyPP4BQKBQDB32traqKqqyntsVkEvRG9vL+vWrUNRFAAURaGhoYHe3t45C7rP58sapWnavN6/vb2dHTt2zM/oCkNcAw9xHcQ1yLBaroNhGHR0dGQ1NJcFCXoxyIRZFuqlt7e3F9OcskRcAw9xHcQ1yLCarkOhUPWCBL2pqYn+/n5s20ZRFGzbZmBggKampnm/1o4dO/D7/fN6zpEjR9i9e/e836uSENfAQ1wHcQ0yrJbrkE6np124FpS2WF9fz/bt29m3bx8A+/btY/v27XMOtwgEAoGg+Mzqof/5n/85L774IoODg/z2b/82NTU17N+/n0ceeYTPfOYzfPOb36S6uprHHntsOewVCAQCwTTMKuh//Md/zB//8R9PeXzr1q384z/+45IYJRAIBIL5U7JD0ZVMZ0+MQ8d7GRhJ0lAb4uadTWxeHy21WQKBQDAjovR/Ep09MZ599TTxpMGamiDxpMGzr56msydWatMEAoFgRoSgT+LQ8V4iQR+RkIYsSURCGpGgj0PHe0ttmkAgEMyIEPRJDIwkCQXzE/ZDQR8DI8kSWSQQCARzQwj6JBpqQyR1M++xpG7SUBsqkUUCgUAwN4SgT+LmnU3EdZN40sBxXeJJg7hucvPO+RdNCQQCwXIiBH0Sm9dHuff9VxAJaQyO6kRCGve+/wqR5SIQCFY8Im2xAJvXR4WACwSCskN46AKBQFAhCEEXCASCCkEIukAgEFQIQtAFAoGgQhCHoisI0UNGIBAsBuGhrxBEDxmBQLBYVoWH/szLHew72EVcN4gENe66pZX79raV2qw8cnvIANk/Dx3vFV66QCCYExXvoT/zcgdPvdRByrAIBVRShsVTL3XwzMsLm2W6VIgeMgKBYLFUvKDvO9iFpsoENBVZ8v7UVJl9B7tKbVoeooeMQCBYLBUv6HHdQPPlf0zNJxPXjRJZVBjRQ0YgECyWihf0SFDDMJ28xwzTIRLUSmRRYUQPGYFAsFgq/lD0rltaeeqlDsBC88kYpoNhOXzk1tZSmzaF5e4hI9IkBYLKouI99Pv2tvHAbW0ENJVkyiKgqTxwW9uKy3JZbkSapEBQeVS8hw6eqK92AZ+MSJMUCCqPivfQBYURaZICQeUhBH2VItIkBYLKQwj6KkWkSQoElYcQ9FWKSJMUCCqPVXEoKiiMGLUnEFQWwkMXCASCCkEIukAgEFQIQtAFAoGgQhCCLhAIBBWCEHSBQCCoEISgCwQCQYWw6LTFV155hW984xu4rovruvzu7/4ut99+ezFsEwgEAsE8WJSgu67LH/7hH/K///f/pq2tjZMnT/Kbv/mbfOADH0CWy9f5/+p3X+fAm33
YjosiS+y5tpFPP/jukto0udVtVFlZAzoEAkHpWbTqyrLM+Pg4AOPj4zQ0NJS9mL96tBfbcZEksB2XV4/28tXvvl4ymwq1un3tRFy0uhUIBHlIruu6i3mBQ4cO8fDDDxMKhUgkEjz55JPs2rVr1uel02na29sX89ZLwhee7sZxQZEuP2a7IEvwJ7/ZUhKbXnkzhm44BLXLC2Xm77deKyo9BYLVyI4dO/D7/XmPLSrkYlkW3/72t/nmN7/J7t27OXLkCA8//DD79+8nHA4v2KjZOHLkCLt3716IybPiPtWNLIEkX1Z02XFxXZbsPWfjJ6feYHNTEFm6bNPFixdRgzXs3n1DSWxaKSzlvVAuiGvgsVquw0zO8KJiIydOnGBgYCB7EXfv3k0wGOTMmTOLedmSosgSk/csrus9XioKtbpNma5odSsQCPJYlKA3NjbS19fH2bNnAThz5gxDQ0Ns3LixKMaVgj3XNuICtu3iOC627eJOPF4qCrW6TRmuaHUrEAjyWFTIZe3atTzyyCN86lOfQpoIB3zpS1+ipqamGLaVBC+bZWVluWRa3eZmubx3e0R0ShQIBHksOg/9nnvu4Z577imGLSuGTz/4bj5dpNeanG54886mBQnx5Fa3R44cKZKFAoGgUlj1/dCLJbjTvfazr54mEvRl0w2fffW0GCQhEAiWhPJNGC8ChfK7n331dNHyuw8d7yUS9BEJaciSRCSkEQn6OHS8tyivLxAIBLmsakFfasEdGEkSCvryHgsFfQyMJIvy+gKBQJDLqhb0pRbcQumGSd0U6YYCgWBJWNWCvtSCWyjdMK6bIt1QIBAsCata0JdacDPphpGQxuCoTiSkiQNRgUCwZKzqLJdC+d233bSpqII7Od1QIBAIlopVLeggBFcgEFQOqzrkIhAIBJWEEHSBQCCoEISgCwQCQYWw6mPo5UhnT4xX3ozxk1NvFL1dgUAgKF+Eh15mZNoV6IazJO0KBAJB+SIEvczItCsIarLoDyMQCPIQgl5miP4wAoFgOoSglxmiP4xAIJgOIehlRqZdgW44oj+MQCDIQwh6mZFpVxDUZNEfRiAQ5CHSFsuQzeuj3HptlN27byi1KQKBYAUhBL1CWMpRegKBoDxYVYL+nWeP8eLhbgzTRvMp3H5jC5+497pSm7VoxOxSgUAAqyiG/p1nj/H8wS4sy8anSliWzfMHu/jOs8dKbdqiEbNLBQIBrCJBf/FwN4os4VMVZEnGpyoossSLh7tLbdqiEbnpAoEAVpGgG6aNokh5jymKhGHaJbKoeIjcdIFAAKtI0DWfgm27eY/ZtovmU0pkUfEQs0sFAgGsgkPRA0e72X+wC8exMG2wHQvNJ2PbLrbj8qEbW0pt4qJZjlF6AoFg5VPRgn7gaDd/t/8EIb/C1pZazvXESKQdDNPBr6l8qEKyXKA4o/Qyi9/QmE59dZA7b2llz67yX/AEgtVCRQv6/oNdhPwK4aAGwObmWhK6QTio8eVP7imxdSuL3MWvtspPQjf4u/0nAISoCwRlQkXH0IfGdIKB/DUrGFAZGtNLZNHKJXfxkyWZcFAj5FfYf7Cr1KYJBII5UtGCXl8dRE9ZeY/pKYv66mCJLFq5iMVPICh/KlrQ77yllWTaJqEbOK5DQjdIpm3uvKW11KatOMTiJxCUPxUt6Ht2tfCxO7cTDmqMjKcJBzU+dud2ERMugFj8BILyZ9GHoul0mi996UscOnQIv9/Prl27ePTRR4thW1HYs6tFCPgcyFyj3CyX3/iAyHIRCMqJRQv6V77yFfx+Py+88AKSJDE4OFgMuwQlQCx+AkF5syhBTyQSPPfcc/z4xz9Gkryy+jVr1hTFMIFAIBDMj0XF0C9cuEBNTQ2PP/449913Hw8++CA///nPi2WbQCAQCOaB5LquO/uPFeatt97ivvvu4y//8i+5++67OXbsGL/zO7/DSy+9RCQSmfG56XSa9vb2hb71ktB+LsHhjgTjuk1VUOHGtjA7NoV
LbZZAIBBMYceOHfj9/rzHFhVyaWpqQlVV7rrrLgCuu+46amtr6ezsZOfOnQs2ajaOHDnC7t27523vTBw42s1PTpwg5PfTuFZFT1n85ESarVu2rMi48lJcg3JEXAdxDTKsluswkzO8qJBLXV0dN910EwcPHgSgs7OToaEhNm3atJiXLQmiUlIgEJQ7i85y+bM/+zM+97nP8dhjj6GqKn/xF39BdXV1MWxbVobGdGqr8ncKolJSIBCUE4sW9A0bNvDd7363GLaUhGyHwdEUw7EUa2sC1FR51ZGiUlIgEJQTFV0pOhuZDoMJ3WBtbQDLcugd0hkeS4pKSYFAUHZUdPvc2ciNm4cBWZLpH0lyaTTFVRvrRKWkQCAoK1a1oE+Om0cjfqrCPkbG06JfukAgKDtWdchFdBgUCASVxKoWdNFhUCAQVBKrOuQiOgwKBIJKYlULOpSuw2BnT4xDx3sZGEnSUBvi5p1Nix7yLBAIVjerOuRSKjp7Yjz76mniSYM1NUHiSYNnXz1NZ0+s1KYJBIIypiI99Gde7mDfwS7iukEkqHHXLa3ct7et1GZlOXS8l0jQRySkAWT/PHS8V3jpAoFgwVScoP/eX7xEV38y+3c9rfPUSx0AK0bUB0aSrKnJz6QJBX0MjCSneYZAIBDMTkWFXL763dfzxDxD2rDZt4KabDXUhkjqZt5jSd2koTZUIosEAkElUFGCfuDNvmn/La4by2jJzNy8s4m4bhJPGjiuSzxpENdNbt7ZVGrTBAJBGVNRgm4708/qiAS1ZbRkZjavj3Lv+68gEtIYHNWJhDTuff8VIn4uEAgWRUXF0BVZwppG1O9aYcVCm9dHhYALBIKiUlGCvufaRl492jvl8dZ1Ie7b28Z3nj3Gi4e7MUwbzadw+40tfOLe62Z9XZEzLhAIyoGKEvRPP/hu4HUOvNmH7bgossSeaxv59IPv5jvPHuP5g10osoRPlbAsm+cnDkpnEvVMzngk6GNNTZCLl8b52lO9NNQG2dpSs2TiPnkRCfhkDp8YyFa0bmtyWAXTtgQCwTyoKEEHT9Q/XeDxFw93T4i5AoCsApbNi4e7pwh6ricvARvWRXjPzvWMjKc43zeOBCQmDjWfffV00ePfkxeRjvPDtJ8dpr5Ko7Y6QEI3+NHRJFu3dIs2BQKBIEvFCfp0GKaNT5XyHlMUCcO087zh7r4x3rkQw6d4nnzadOjsHUdT+/D5VDRVwafKJNPWkhUETS48Ot8XR1NlTJvsvNNEUmf/wS4h6AKBIMuqEXTNp2BZtueZT2DbLqoq5XnDrx65gATIkoQsySiSi+26nOkZY8O6KkIBFdN0CAd8wNIUBE0uPNLTJn5NwbSc7GN+H7POOxWxf4FgdVFRaYszcfuNLdiOi2nZOK6DadnYjsvGhkjWG5YlCccBSbqcAqnIEpIElu0S8qskdRPDcmhpiABLUxA0ufAo6PdhmA4+9fKvK20yY9920S9GIFh9lLWHPp+eLZk4eW6Wy4dubKFvJMXZizGSaYtwwIcsg+OAOxGdURQZy3ZQZYlISCORMtm4LkK0yp8tCLrtpk1F/Vw372zi2VdPA94OYGNjhPazw1QFwXEd9JSFYbkz9m0X/WIEgtVH2Qr6My938NRLHWiqTCigkjKsWXu2fOLe6/IOQDt7YnztqSNIeMJpmDZBv0pct3BdF8d1sG0XF7jrvZv4xL3XTQlj3HbTpqILZKbwKPM+bRvruHbrmrwsl91b1Bnj54X6xQzFdH5ytIdX37hAfXWQO28Rvd8FgkqibAV938EuNFUmoHkfIaDJgMW+g11zbsJ16HgvG9dVcb4/jjkR0lhbE0SSUqQNG9Nys558ZiFYroKgQu+T+7mOHDky4/MbakPEk0bOweoYv3jnEn5NobbKT0I3+Lv9JwCEqAsEFULZCnpcNwgF8s3XfPK8erYMjCRpWhshFPDRPRAnkTIJBXxsrw3x+Y+/p9gmLyuTwzYnOoeRFYn6aCCbKQOGyJQ
RCCqIshX0SFAjZVgTnrmHYTrz6tmS8WJrqwPUVgcA8rzaubISs0kmh20My6ahNkjQ78v+TDCgzpopIxAIyoeyzXK565ZWDMshZVg4rvenYTnz6tlSjK6HKzmbZPP6KA/csY2HP3oDrU1RmNTmRk9ZM2bKCASC8qJsBf2+vW08cFsbiiITixsYps1VG6Ncv23drM/t7Inx1Asn+f6/nsnmdy+062FuNokseZkwkaCPQ8en9pQpJXfe0koybZPQDRzXIaEbJNP2jJkyAoGgvCjbkAvA9dvW0dU3TiToIxT0kdTNWUvxJ5fVJ3WTuG4uuHy/XKYPZeLk+w92ZTNlfuMDIstFIKgkylrQF5JrXez87MnZJLBypw/t2dUiBFwgqGDKNuQCnnccCvryHpvNO17Ic2ZCTB8SCAQrhbIW9IXM5iz2PE8xfUggEKwUyjrkMjnXOhMPn6kUfyHPmQ0xfUggEKwEytpDX4h3LDxqgUBQqRTNQ3/88cf5q7/6K55//nna2uZWel8MFuIdC49aIBBUIkXx0N966y2OHj1Kc3NzMV5OIBAIBAtg0YJuGAZf+MIXeOSRR4pgjkAgEAgWyqJDLt/4xje45557aGlZWH5ze3v7gp43W7fB9nMJDnckGNdtqoIKN7aF2bEpvKD3WqnMdg1WC+I6iGuQYbVfh0UJ+i9+8Qva29v5gz/4gwW/xo4dO/D7/fN6zpEjR9g9w8j7A0e7+cmJE4T8fhrXqugpi5+cSLN1y5aKKayZ7RqsFsR1ENcgw2q5Dul0elpHeFGCfvjwYc6cOcOv/MqvANDX18fHP/5xvvzlL7Nnz57FvPSi2H+wi5BfmWgRy6ytYg8c7c4riReDHwQCQTmyKEF/6KGHeOihh7J/37t3L9/61reWNculEENjOrVV+V7/dK1iDxzt5u/2nyDkF4MfBAJBeVPWeejTUV8dRE9ZeY9N1yo215vPDH4I+RX2H+xaJmsFAoGgOBS1UvTll18u5sstmBu3N/APL5+GEZ2gX0HzKdgO/MYHWqf8bN9wElyHkfE0PlWhOqKtmMEPK3FwhkAgWLmUdel/Lpk4eN9wEtOyWVutkUjbJFIWpu1y3y9vnRJC6eyJYVo2juMS0BRsx+uLHg6oJR/8MLnNb2ZwhqhqFQgE01ERgp4bB8d1cByXS2MG111Rz6amKPGkQcp0pjzv0PFetq6v5p0LMQzLwadImI5LLGHy4K9uX7A9z7zcwb6DXcR1g0hQ465bWuc8uDrXtmK2+V0JrOQdx0q2TSCYKxURQ8+Ng1u2522risSpC94YuOna4w6MJLlyUx3XXVGPpiqkDIeAptDaWLXgA9FnXu7gqZc6SBkWoYBKyrB46qUOnnm5Y16vU+w2v6VmJY/qW8m2CQTzoSI89NysFp/qhU58ioSe9g5Gp2uPmxlOsakpyqYmzxtbyJDoXPYd7EJTZQKad2m9IdYW+w52zctLL9bgjJXiea7kHcdKtk0gmA8V4aHnZrVURzRsxyVtOgT9yowDJ5ZiOEUsnsKybcYmXsu0bDSfTFw35vU6K32AdWYu69e/9wZPvXBy1tdcyTuOlWybQDAfKkLQcwcg+30y6bTBeNLk0miKf3mtC9e2C3paxW6l29kTQ5IkLMtFBlzXRU9bJFMWkeD8vP5i2LZUA6wXslAUe7BIMVnJtgkE86EiQi7NDVVsba7m7c5hzvWO4wCyBJpPxrZd/vVYL9HIMT5x73VTnlvMVrqHjveypTlKx/lRbMdFUSRs28W0XT5aIGVyNhZr21INsF5IiGIpBosUi5Vsm0AwH8reQ894i/XRIHfc3Io08Yk0n4IsyfhUBUWWePFwd97zDhzt5rNPHOChL7/EZ584wIGj3QVefX4MjCS59sq17Nxaj6rImJaDqshsWV897yyXYrBUnudCQhSTdxym5eDXFL7/r2fmFLJZSsTQE0GlUPYe+mRv0XG9VcqyHTRVAUBRJAzTzj5nqcr9MweZ2zfXs31zPbD
4Q9bZmKkPzVJ5ngs9sM3sOHJz7KMR/4rIsRdDTwSVQNl76JO9RUWScPHi1xls20XzKdm/L1W5/2wHmc+83MF/ePRFfuNz+/gPj74471TGyWQWpoRu5C1Mmd3GUnmeiz2wXarYvkCw2il7D32yt7h5fRWnusdwHHBcB9t2sR2XD9142fOeT/Ou+ZAR0Nw0wdtu2sTm9dFsfrqmynn56cCCwzFz6Sq5FJ7nTJ9zLixVbF8gWO2UvaBPDitsa61nZCzFcNwgbTgossT7rmvKOxCtrw6S0I2sEML0zbvmw0w538XKT89lqRamubCYhaJYOfYCgSCfshf0XG/xTPcoQzEdV5K5ormGzc3VaKpC31CSJ/7xKKbt0FAb4sbtDfzg0DnAIBjwBmAk03bB5l1zZbbeK3HdIBTIv9wLyU/PZakWptlYbLGSyCqZykopAJsv5Wp3pVL2gg5kb6C+oQSj42nqqj2v9Z1zIzSvjdAzGCcWT5NMG7x8+AJuznMlQJEl9lzbSHNDFU+9cHLam3Omm3e2VL5IUCNlWBOeuYdhOvPOT8/lzltaJw5zi7cwzUYxmoYtNmRTaZRrI7ZytbuSqQhBh8uCatoOIb+KJEkAdJwfpSai0X1pnGTKnvI8VYGAX+XVo728evTyoVxtxEffUCJ7c852884WF77rltaJmLmF5pMxTAfDcvjIra0L/syZOHlulstvfGBppy0Vq0xeZJVcplxbD5Sr3ZVMxQh6RlDDAR+GaaP5FHw+GT1tUlOlZcVckiAnAQbTBjdlTnm9kbjJ0Xf6aawPs3l9dNabd7a4cCZOntuF8SO3zr8L42T27GopmoDPZRSfONAsPuV6TcvV7kqmYgQ9I6gtDRFOnhsBvNRFTVWyTbqmw5raWRfwRD1zc852884lLnzf3rZZBbxUMcm55uaLA83iU67XtFztrmTKPg89QyY3uqtnlN6hOKe6RzlzMUY4pNBYH0bK/KA706tMJXNzTq66HBlL8YuTA5y96DWpAhad812sZlqZKtiPfeGH/IdHX+Rz3zwwazXmXHPzl6Kh2WqnXK9pudpdyZSth14oPFAdVPnX7jEkXPw+GVmSGBhJc8OVDfQ3hLgwkJyi5z7FC7tMx9MvvsOPXj/Pe3esI657nn7asnn77BCuCzu21ufF0x+4Y1v2uZmOhJO97em88GLEJDOetiK5WJaL41qc6R5F88l5ZwKTmWsKpDjQLD4r/ZpOvl83rotwvj/OwEgSv6ZgWt6kr5Vm92qkLAW9/VyCn5yYGh5ITKQGZnK9AVKGxWvt/fzN52/nkW8f4EjHUPbfZAlkRUF17WnDLgCXRnW+f6CLhtoADbVhRsZThAM+trREqa0KZH8uV3inO0R999XreP3t/oKHq8WISWY87bhuocgSmqJgWjbn++K8Z0dk2sVhPimQ4kCz+KzUazr5Pr44MM7LP7/AVZtqaFoTyYYWRWbLyqAsQy6HOxIFwwMJ3csgySWT693ZE0NWVJrXhrmyJcqVG6KsXxtm59Z6vvZf3s/2jdUzvqcEDI6mqKv2Yzsum5urGU8Y/Oj18zz/kzMcOt5L+5nB7M9PV96+/2DXtGXvuWGdkbEUx08P8tqxHvqHknMOuwyN6QQDKqblIMteoElVvcPhmRaH3BbEjuuQ0A2SaZs7b2md0/sKKpPJ9/HwWJqQX2E4lhZtG1YgZemhj+s2jWvzTQ8GVCRZwjCdgrneh473EounCQd82b4uEhJjcYNDx3v5i0/dyt2f/v6M7+u4cKJrhJBf4dipS4wnTHyqTMCvkjZszl6MceBoN3t2tUzrbQ+N6VyztT7vccOy+Wl7L2trggyM6NRVawzF0t4OQoa6qH/O+b0ZT9unyti2g6JIWJZD0O+b8cBqthRIUUCyOpl8HydSJsGASiInM2wlZ7astvu2LAW9Kqigp6wp4YG1NQFG4waFcr3PD8QxLJtwTiMvnyqTSFtzuhkzsfe0YaPIcGk0RSigoioyju0iyxJ
VQTXbR2W6DID66iBJ3cw+PjKeov3MEOGAj83NUQKawrHTgwQ0hbU1IVrWRaitChBPGlPCJYVu1kyxkU8Bw3QxbQvbdtjaEp21GnO6FEhRQLJ6mXwfhwO+KaG5lZrZcuBoN//nR6ewbIdoWMO0bJ59dfpzpEqgLEMuN7aFC4YHPnbn1SiSzXjSZCiWZjxpokg29+1to6E2hKYqmOblYLlpOWiKnL0ZW9fNfFMGNQW/phD0q97Bq6Zg2S6KIrMmGqCmOpA9RJwuA+DOW1rzHj/bHUOSYEtzFFmSaG6oojqksSYaZOcVa7Ix+sleUN+IUTAjprmhio/duZ26aAhVlQhoKltbamjbWLfgG3kldUec3LHywFtikPNSMvk+rqv2k0zb1EX9KzqzpbMnxv/5UQcSLtGIhmk5nO8bx3acig4PlaWHvmNTmK1btkwJD/z9/naS6fyfTabhoS/+kM/+9s2cvjBKz2Dca60rgZ628Ckyr7/dx6tvXKC+OkjIn5zyGuCFPjRNIehXcFyXoN9HVdCX56kkdCPvENGvKbSf9Q5hr9pYmxXU5oaqrGdt2g5Xb6mntvry4Wp1WCOWyO/xMtkLOnFBp7omXDAj5oE7thWl2CizA3j1jQvUVQXYsK4qa2cpttmFOlb+uF1nw8sdJRkgshqYnIHT3FDFzTubslkuKzWz5dDxXmzbpTqsISFlw6xDoyl8qjLLsxdOqUM8ZSnoUDg88Nh3jxT82d7hNJvXR3nwQ9v5wWudvDNReFRX7efiQMJLvTJtOntipAyIhlU2rIuiGxa9lxLoaQvXAcO00dM2G9aFaV4b4czFMQr1UckNUbz7msZsJkCG3IyGp144ycVL4xw/PUgiZRIO+PBrMmpKJp40pi1SiiVtGpuWbrBx7meoqwqQTJmcPDfCtk211FYHSrLNLtSx0rTMRXWsFMzOSs3AmYmBkSTVYc8zz4i5zycTixtce+XaJXnPlRCaLFtBXwib10f55Ed2Zf/+2ScO4NcUEikvxS8UUEmkLGIJi8S5YXyqQkCTMC0Jx3XZvL4aPWVx8VKSj925nV/a1ZzdJdiWTcpw+B//cBSQ2LQuzPXbGoGZ88k3rovw8s8vEPIr3mGTbnBp1OZDN28iZTrTekHRkJIXi4fixjJzwywbGqs42TWMhMuF/nF8qpxdYJbTIynUsdIns6iOlYJ8Su1hFouG2hCmaXO+Pw5452XJlImqyEsWHloJvW3KMoa+EL763deB/Bjs213DjIylUGQJRZbzWgS4jovtOMTiJq7r4vep6GmbcFDjqk01nO+Ps2dXC1/+5B4+eNMmYkkLFwgFVEzT4lT3GG+dvZR9vem85/P9ca7aWEM4qF1+/Y019A7N7Glv3xBc0iq93ElQtVUBtrXWEQr4GB5PZatggaJUts6VSFDDMPMLBkyHKR0rMwVdX//eGyWfV1pOFKtSeSVw884mFEVm47oIPlVmLGHgAvd/4MolE9eFzNotNhXloTfV+ekdLhAAB1492ktX70v0jqSzMdi0kUY3XCTZIhzQSKUvl4y6eG11Dddr5nX9VWvZ2Ojlqjuum/dLmhwK0HwqpmVz6sIY12xZy8h4irPdMUzb4akXTuZ5PQMjSZrWRmhuqMq+3tCYztFTl7hx+7ppt26NtRpXX7101YWTsxtqqwL4FJlrQ2uz1bBPvXByyT2SXI9x/dowb3V6ZxKZLCbb9jpZ5v58qbe95cpK8DCLRW7s3+dTuPbKtUu+21gJvW0qStCf/K8f5KEv/nCKqKuKhG27dPUnqQr5ssIbCqjEdYtkyibkd7An2jD6fTK242LZDorkFeZkxBwu/5I6e2L89T+3c2nUy2yJJ01CAXWiHNrCMG1+1t5L71ACVZHZsbVuisDk3gQj4ym6++Oc6xvDp8iYtpPNKoGpX6yljG3OpdnYYitbZ9red/bE+MHBTo6eukR1SGNzczUb11URTxoMxVIkUxaRoMZNVwby4uflJEorLbxRad0Tlzv
2v3FdhP/zo47sYWx9NICiyMs6uKWiBB08Ub/n099HkshWSsLltrm5laSqIqPIYDswGjeQJO/fA5rKmpogQU1leEwnkbK4eGmcodEUsYSBqsjcekMz3/jeG3T2jGVfzwUSKS9soygytuUwOJoi6FeJBH30DiaJRgLZlL/N66NZ4Ywl0pzrHUOWvHh9OKhysmuYba111FYFluyLNV3L3Ln0F1mMRzKTJw1eKOfiQJyqkLeFfefcCNta67h6cz2RkJbdJRw5kn8QXi6itBJ3EivBwyxXOntivP52Pxsbq7I6kUhZSxriKUTFCTp4oRLbcXEcN1sQlOmBnqkkNS0bPW0hSxKKT6K1sZq+oTim7RIOqPh9MgndwHbg/bvWc+zMELbtEp1YeV95o5ueQc/zNiY1gkmkLGQJ2jZESVtuduCGYdp098e5Zmt9VmAywvnks8dxHKiKaDTKMsrEYtTdH6e2Kj+rpLMnxitvxvjJqTcW5dnN1jJ3Ng9nMaPkZvKkwZs+deZiDNtxURWJmoh/yrUrRLmI0krcSYjRgAsn9/fZvNYLn8aTRvZQdrlYlKCPjIzwh3/4h5w/fx5N09i0aRNf+MIXqKurK5Z9c+KRbx/gjY4hXLyeK1UhmbGkmzfIAiDsV9ANz4O2bBvHcXFc2Nlax/bN9cSTtQzFdMYSZl5++/n+ODu3rskTic6eMQzTQVMlCuG4UBMNEE9YeQM3EilzisBsXh9lXX2Ia7bWI0sSI2MpTp4bwadKxHUje+CZySp59tXT6IbD5qbFeXa5LXOBiT+NbLVrLtN58gvtEjiTJ32ud4yzF72DOEUGx3EZjKW8Hjrrq2cU53IRpZW4k1jpXR9XMivl97koQZckiU984hPcdNNNADz22GP85V/+JV/60peKYtxcmNxB0QXGkg6aChPajYQX37r2yrWc7x+n51KC4TELRYY10QCxhMHx04OsbwgTDPj47MduynuPr3/vjSm/rOqwhiyBZU9tsC7h5bx2nBvl3Vc35g3c8ClyQYHJ9SxrqwNs21TL2YsxpIkYeuaLlTmIdNLyjPH1QkyO2fYOxVlXly+OhVrmzsWTny8zedKvvdmLIkvIkoTluEiSi+TCWCI9qziXiyit1J1EOeacrwTm+vtc6nOTRQl6TU1NVswBdu3axdNPP71oo+bDGxNiLuU4yq7rifn9t7V5Hu/EYeNLPztHKiftzXZgcFSnPuoJ7dtnh9ixZc0Ub7Q67COpq3m/rPpogOqIxuj41BxoL1fa9SpRVZm2TTV0XRxjXDfZ1baWD71385Rf4mTP0qfKNDdEpnjeAyNJFEWiqz/N+eFewgEf6xvCJKfxBDKfpXcojmW7XNES5YoNtcSTBpbtMjqWoi56+aYr1DJ3Pp78XJnJk/6nlzuQZAlJklBwsR0JSfJ2XHPZiZSDKJXLTkIwN+by+1yOc5OixdAdx+Hpp59m7969xXrJOTHTAKKkbmLaDie7hhkYTuaJeQbbgYGRFAMjKTRVotM/yi86LuV5o72DcdbVpdnaUsOlWJKOc6PoaYs10SCO7TCW9LYCEmSzXGzHZU1NkEhIIzmSZPf2dTOuxnP1LH2KTPvZQRzbpTaiYph2diHKXf19ikzP0Dgd52L4NQUJL3RxsmuEgKaysbGaK1qinOwawa9NrXbNZabhFwv1OGb6vNFIYKJYSEKSZTRFAlwiQW3FC/VcKZedhGBuzOX3uRznJpLrTo40L4w/+7M/o7+/n8cffxxZnr1eKZ1O097evuD3+5fDQxzt1MkdF5px0jMf6NrWEINjXmFQ78jMc0XBG3jhuFBXJVOVU5GoGw6KDDVhmTO9BqoiURWScV0J3XDwSQ5xA1TZ2ylYjteaN+SXAImqoMKNbWF2bAov+PNmeP71Ybr60/h9MqoClg1p06E6pJBMOdiOiyJDynJJ6A6K4i0CuuEQ8IGLhE+R2LUljOu6nLtkgOu1JM61s/1cgsMdCcZ1m0TKxq9J1IYvF03ohoM
sQUONRkCTCPgkUqZLynB57/YIjbXaDJ9iZg68FePH7eOe7bJXPGTb8Ms7qthzzeoQvL4RgxMXdGJJm2hIYfuG4KKu6XJT7vYvBc/9dJjqoIyUE05wXZcx3eHX3jP/c8cdO3bg9+c7WkXx0B977DHOnTvHt771rTmJ+WxGzcaff+tFDp/WJzJBLq9HuSvT7rZ6/v3dO/nvT7/hPT4yxmyoioRhuciyD9UfYDiWIpnyFgRVUQiFAjQ3+Kc05JIVmW21QY6fHiJt2kgS2K5LIBCgpsqPnrL4yYk0W7dsWXTTrJ+ceoN3rZE4eaYP2ReguspHddjHm6cHWVsTJBT00TuYwHEdkFwkSSIY8GM5aWxXIhLykUpbrF+/nnjS4L3NWt7YPPDCNN5EKD+Na1UGhhIMjxv4NZm1tUH0lIVk22xtiVJXHcgLRcWTBjFb487d2yabnsdMnv3u3bDh5Q72HewirhtUhTTuuqU1m2+e+1xLH+XX79hdUZ5tZ0+MH79zmuqaMI1N3vb9nQGTq68uvDU/cuQIu3fvzv79O88e48XD3dnD+NtvbOET9163Yu0vBgeOdvO9Hx7HcNW8Q/uVxDuDJ6fE2eNJg+Zmjd2zfF9ymckZXrSgf+1rX6O9vZ0nn3wSTVueFfhopyfmPlXBp3pj5pwJNZeAG9rqeeQ/7QHgPTuaiCcNunpmF3TDcpGAhG6RMhIYpjf1x3W9IqMLA3Ga1+Z72cGAysh4mo/efj1+zYuP/bS9l5Rhk0hZBAO+aWPO04naTGLXUBvi4sB4ng3dA3FkSSIU8CEh4Tje4oQL9sShbUDzJjqlDZuAphJPGvQNJam3HL7+vfz0x8kx88Y1VcA4cd1CVdPZ7J+fnxxYUKlzobFmX3uqh3V1YbY0e7n59+1tK9hwa/JzO2PDJc/fLjaL2Zp/59ljPH+wa+L7IWFZNs9PDPpeLlFf7pTMzKG9azusqZt6aL9SWI5zk0UJ+qlTp/j2t79Na2srH/3oRwFoaWnhiSeeKIpx02HaoPkub1sCmorjOpiWy3NfuSfvZzMXsSasMpqYPezi4k0QUhwJaSIE47re1KBY3GBkLE04MHXuZu5NrKdtApqM48JY3CBYp07JHvGa71+uKjNNr/n+TDNHN6+PZpt5ubbDmohCQjcYGNGpr/ZnO8t504psZDlzQGwjI6FpCo7j0rQmjGk7uLj4VJlolT/vfTIxcz1tMpYwMS0HVZUJBSSe/Oxt2c9wvj++oEyN3Gs1MpbifH8cCc9bme2gaLJYBDU5r1BrNlZadWYhFpMC9+Lh7qyzAyCrgGXz4uHuZRP0gZEkqiLldRBtbggvWQpfxgFxLDk7knKxh/ZLwXKcmyxK0K+88kreeeedYtkyZ3yK53nKOdbbtpttk5lL5iJeGtHRL4yQLnAwCl783AVUxfPyJUnCdb2Kz/oajWhVAMdxGU+aJPSph4g/PzmQ/RIG/SqGZaOpXgET5GePeM33TyFBtsXn+f44G9dF2H+wi83rq6f1bs73x7lqUw3nugezzbwaakHCzRY4VYV89I947XpVRWJsord6Q22Qj915NXt2tfDUCyfxKXLB96mvDjIcS5JI2yiShKpIpA0bWZbo7Illb8CFehy5gtU9EEdTZXw+2Svnn8WbW4zYrcTqzEIsJqXRMG18k2ojFMUralsufKrM8dODhAM+QgHv4L79zBA7r1izqNedrhYi44Ak4pdbfhRKv10JLHUGVllWiu7aHOTwaR0sG2WiT4vtuHzoxsKr8eTCHYDzfWMcfPPy5BIXCAdU1tYGuTSSYv0aL7SSWSQM06a2KkDbxlpicWPK3M1cb/XKDVGOnR7CcexsxWlu9sih473eWKxIfvP9k+eGGRxNcXKide+VG6q5ZsvaPMEaGEnStCaCZI6zfr3XWXF4TOfYqUHaNlQzFEuhpy0UCSxJQpZl6qu9Oap2zlo2kzDeuL2B/++HJyc8eAVF9ha3to01eUI7X48j4x2fvRjjQt8
4W5qjJFImIb+KaTqEA748OwqxGLFbidWZhVjM1lzzKViWXdDZeSbnXCISzD+XKCpuThpxJhSaf9yVZa7x/plqITJzdHMplH5biHLYsc2HshT0X72xnsZ1at6N8KFZDn5yheB83xhHOy4he0ko+BQZx3G9cW8urKsPUR3R6BtK4OKCC8m0xfo1Ee6/7aqCv/DcL+GGxmpSaYszPWMwsQXMHbg8MJL0Zhyal5vvj4zrDI6mQQJFAtu2eeusV5C0qTGaFazM58hFUxV2ta31OiJOdJZ7/e0+HNuZcoC7fyKe+uapy9vhKzdE2dQUJal7Q6+7+sapjmgkdC/c4sgS21vruGJD7RShnavHkesdt22sof3MEO1nB9FUxTt4Bra0eK8zk0BPFjvdcJDnKHYrpZpvNhazNb/9xhYvZj7J2bliXWjKtKenXuoAKLqoZ6Zw9QwksvdYa3M1pp2/O86N98uyS9qw+P6BLk6eG+GTv3F93uedqRYiM0fXtR3CrjNt+u1kymXHNh/KUtDBO+CZTsAzq+7ht/u4eCmBbduEAj7qowGu3lzPsVMD6GnbO0h1wXQdVEWifyRJbVWAj93ZRnNDFT842Mk75z1R3XnFmoIFQRkmfwmvaq3nY3fvKPjzDbUhTMvmfJ93uOnzyVwandguupC2vNRDRZboOD9KfTSUFayMoOmGg+O6We9t8k346hsXCuaOX7wU94ZIqxKq7B0oHzs1SCptISsyI2NpDNNGlrxZq9VhP4ZpY9nuoioZJ3vHO69Yw9nuGHHdRPMpbGyMEI3489ocQOFtdu51DmrynL+AK7U6sxAL3ZpnvhOTnZ3X2vunTHsCa0mmPWWuc26IJZ40iFTnJ01k4v1eRTDZiuDTF2NThHWmWoiMo/S9Hx5nZDydt3OeiXLZsc2HshX0XD7yR9/Py0cH2N7qlc4rsoQkS6RNm3P941iWQ1z3Dgx9CriOlzduWi6qCh+7c3v2Rvjkv9k1Lzvm+iX0RDmR7cx2rkAGju0w0ZjqcnVkZqFK6Cb9IyaSb4wtzdGC3ltmG5rroespi5RhUxPRCAc1ggEfY3GDZMrkbM8YV26oJW3aVId9SJLE0FgKgEjIx2g8vagpRZO949qqANdv8zM4qvPh920t6I1Ot83+2J3b87otzvXLt1qqMws5Oy8e3jdl2pPmk5dk2tNcr3Mm3m9ZnmclSZ6n7jhMOeie7n7OhFX27GohaPfnpW/ORrns2OZD2Qr6x/5sP0Nj02etnOgaQZbAp3pDnb1SH5dzE93PHMcb16T5VGTHwbIcVEXh7//lRHYbV6wT8pm8TJ+qZGecFqK2KpAV88z2cHNzFKw4oYA6rZhmtqGTZ576VM+zAQhqKsE6L0PofN84jfUh9LTXTKw67H1xRsdT9A0lcXHx+2R+cbKfrr7xeW9TZ/KOp1sIi91yYDVXZ0aCGinDmvDMPQzTmTLtqRjM9Tpn4v25oXUvEUGaIqzT3c+ZGb6Hjvdy4vQw7wyenPY7MdkR8alytpq8uz9OImXiU+Rs6K8cKUtB/+o/dTNeeDBRHo6b8QJkDMNich8tywFMC8cBB4gEVWzL4e3OIdrPDvHE/z3Gv9l75aK2pHPxMl890s00Z0bZaTzzTdfLCF7uQvIbH2hl/8Gugp6OqiiEgj5a1kU42TUMgO14ufQBv0JDbRDHcfmH/3eKba21NNZ7h8Zz3aYuxDueaZtdiLnsHMqhz8tScNctrRMxcys77cmwHD5ya+uSvN9crnM23j+B63rdT69cXz0lFDbd/dzcUJV1dKqD8rQORqF4+eCojp6yGI2nCflVVEUimTYZHNXzsrnKibIU9LmIeQbLcVEhT8xlGXA9wc+0MldlbyDGwLDuDceQIJma/uBorhkDc/EyM56KbbvkHhvJwPXb1vHUCyd59Y0L1FUF2LCuitrqADD79nDPrpaCnmwhT+eaLXUkdTM7P7S7P05XbwpVkWisCxH0exkoznCSt88OMzyWJhzwFoB
oxD/rNnUh3vFs2+xc5nrAVWlZDXMlc2/m3rMfuXWJslzmSCYs9C8/PY9hee012pqradtUV3CxL3Q/545BHBudvvtooXh5E/B25xDhgA/T9rKstjbX4FPlso2jl6Wgz5eUcTkHV1O9/idInmC6EylWAb/CwLCOiyfmuOBKEpoqTzk4eubljhkzBnJDLJeGdfyahG88jU9VqA5rU7zMjKeiKBJaTmbC+65ryopUXVWAZMrk5LkRtm2qBRZ2oLdnVwsDw0n2HeyibziZXYyu37Yu60FHI358isyZi6M01l8Wcz1t4kyM5svkF5/sGmZjY1W2qf9MzNc7nmmbPZm5HHBVYlbDfJiu+raUZOL9uQttbrvo2cjEwUfGU9kOpCF/fmfU3J/LJRT0kUxbvO/6lmw6M0ydGVxOrApBzyAD0UiQlGGip20sx0WWwe9TkCQZF3uiK6H385LkYtk2w2NW3nDnyUOhczMGGupC2RBLcKLrYjLlEg56hTmDMZ1wQM3zMqfLTAiF/Nm484bGKk52DSPhcqF/nLBv7ul6uXT2xOjqG+c9OxqzoY+uvnGu37Zuige9fk0Yx7m8tRlLmMiyhIqMaTr4JgqnzvfF+cgSCMV02+xCu465HHBVYlZDpbDQUFhDbYiLl8Y53zeOabvUVqkkdZNEyswLm0x3hlNfHSSpm2WR+TQXyk7QO3ti8/r5+mo/TWsidA+MMRo3swOdM+zcWo8iSRw/4/VVz41jSy5Yloum5sfm4rpRMGNgLGnwN8+/TTJl4jg+z5MNqiR0z7MMaCqm4xJLGDz4q9vznl8oMyF3sEYmFHKhb5zh8RRrGpV5eZaZXUNX7xiaT2bbploiIS1P1B64Y1ve621cF8n3kNMmErCttRbbxisKCvgIB31LJojThY0y9I0YPPXCybxipUxIavIXcyVkNazWkM9ScfPOJr72VC8SXtjUNB1cJDaui+Qt1NOd4dx5Syuvv90/5fFyzXwqO0E/dLw3bxrRTKgyNK2JAJDQzYI/c6JzCFmSMQtMHnJcT9SvmhC/U+eH+aPHD6Cnba/sPqASmqhu1FMWuC6JlEnQr5A2LeJJE0Xx2tw6DqQMh6BfYV1daE5ZGpO9itqqAD5F5trQWq5ak5iXmGd2DbLspUMeO+0tYJuaotOK2mQPOej3sbExQtvGy60+M/ZNV5a9UOYifJ09MV47EWdTSzivWOnqLfX4VWXKF7PUeegLCfksZgEoddfF5WDz+igNtUESusmgrlPtU9jSHCValX+uM9MZTnND1axnO8W+v5eKshP0gZEkAZ800XdF9pIRXZeUYU/JErEceKtziGs21zNdKwtvYSjc3wU8D3775nreOnuJU91jSLj4fZA2vWHQjuOgqgpp0+HKlmqGxgz0tEXa9AqXHMtBkjxb91y3Hp8qT4nvTcdMmSHDvafn9BqQfzAb1y3siZjSqQuxbIXodKKW6yFnBCmeNPLsWRMNzDiibr7MVfgOHe8loElTipVOnR/lPTuapnwxS52HPt+Qz2Ji/svZdbHUu46tLTXEkwZrwibr13vFTPGkMeWeni6sM1u4Z7YRjCuJ+TUvXwE01IZYW+PLtoZ1cbEdZ6I3+lRcF9rPDhX8t7kQCqj86PXzHD89jO24XhGSLWWHaeiGg2HaRIIKG9dHuXJDFD1tYVlO9mdwwXUcXjvey+ET/YyMpeYUOsp4FZGQxuCozsmuIV473sPD//3HfOHpbn7nyy/x1AsnZ32toTE9m3teHdEYGUsTixtcGtX53ovvcOjNi9y8s2ne9kRCGve+/woOnxjILhiZbnchv5JtMzBfcoUvMzc1k6KZS2Zxz+AVKzWwpTk6JXw0k/3LJT4DI8l5tRue63UAT1SfeuEkX//eGzz1wkl++Pr5bNdFWZKzPXlePNxd1M+Uu8jnLjrzDY0uhpt3NhHXzWz1dKbaeC739FzIdYiKcX8vJWXnod+8s4m3Oi5wxYYwZ7pHMUwHVZGpjqgMjxW/6u3
IiX4CAV+e9597UAjwa++/gl+cHMiOgtNUb1BGprWAhLdb0NMm79nZiE+V5+xpZbyH7zx7jDMXx5Akr7rVBS4OJjl0rJu+ocSUatJcbyk3/W9gaDy/kAMYHDN45v+9w6cffPes16OQNzPffPHZmGusu6E2xJnYcN5j0+02Jl+XD79v67LHri+NJDh47CKW7Wabr+X26ZnMXK9D34jBj9/J9+QNw1mWrosr4aA5s1D/0wtHGBzVp1QbLzZUUuz7eykpO0G/ODDOwKhBPG3g11SiEY31ayJcGtUXLOiRgEo8NTUoL09MHrILxNdzicXTbGmJcvz0IGcvxtA0FUV20A0LRZW9Vr+ui+O4yEyfKzuZXBH68ZFuJImJ8mjAdbFdONefwHYlnnz2ePaAZ/IW/cbtDfzg0DnAYFy/HF6SJS/33rZdDrzZx6cXdPVmzhdfyHZ8rrHuzOI+OQQ0OYSyEtIVn3m5g5PnRnFsbzfpNV8bZjxh8l8eKFyuPtfrcOKCTnVNOE9UZXminUVOR+npWkwvhpVw0AyeqN96bZTdu2/IPlasUMl86iFKTVkJenYyiQsbG6uyecm37vba13bOYSpRIfS0hSKT115WkbxiJMeGtFn4QDVDd3+cnVesYcfWejrOjxIOqBM39MSQjAmPXpEkTl0Y4d3XNE1703f2xPjBa50cPz3IaDxNJOhlkVgTryFJLoosZQuiXLye4gPDcb41MI7fp1AV0mDifW3b4Xz/OA21QfqGLr9fRsy91/QOShcaC50uX/z9NzQsSEjnGuvevD7Ke7dHiNnajAda8/EiZ7oGi/H29h3swu9TUAM+0oY9ESaEkbgx7bWY63WIJW0am/JDOZvXV3O6ewxzji2m58uUVsgtUa9bKfmLTikPE4vVOmI+9RClpqwEfabJJA/du5OnX1zYsA17IjSSi6rK2BPDMMxZMmoSKU/wNVXhPTuauHlnE3/4V/+Kk43zZ97HpWfQ6yVTyNPq7Inx3R+coGcwTjJlktAtEroFXN7auS5Yk3YMLpC2wIwb2EGV0fE0jutm8+Bd4H3XN9PaVM2+A51e8VTOmYPrgiSzYC82NxvmwsA4hung93kFWRvWhefVJiAjFMmUxaURnVBAZWtLzbSFJo21s88vnasXOZMnf3FgfFHeXibVVZZk1KB3dOW4DskCO8MMc62ujYaUKbnU21vrkYDuS8k5t5ieK7nX6cqNNbx9dojjpwfZsbUeLSe7qNSHicUKlcynHqLUlJWg504muTSaZCiWwrJdZAm+8/03i/peadNrqTtZPAvR2TPGji312Rt58/oo9dEgI2MpdNPJ9mlxXNDTDr842UfT2qopntah473E4l5Jfe+lxLxtdvDaFfhUGVWWSKZtZMk72L04kGDnFWtoWhOiZzCJbbtePN71bKsNqxw7dQnDdAj6VZrWBkmlHP77029kF6nZUrm2rq/i0qhOXZVCMKByvnecd86NEPCrbGr0njvTdjxXKFon+nlkDrcWExqZa+ji0PFeYuMp2s8Moactgn6VDQ1hDh3v5fjpwUV5ewttjjX5zOLA0W6efPZ4nse7fUOQdwY8pyLXk/+9+29YkpBS7o4nAuzYsoazF2N0TMouevLZ40VtrjZfihkqma0eohDLNlAkh7IS9MwvKJa0iCUML6aMJ5RvnR2e9fnzpSrkI540C+aoT+b5A50A/LS9j9tvbGHnFWt4+fB5FEXKNh1yJ17m5PkYluNmsxUyX7qBkSSGZRMO+mZIpJwZF7LXxZ3oayBJUnYX8UvXt/DjIxcYHE1hOy6S7In5aNwCLGQJ0qbF0JhOQJVIWS7n+8d55pVTfPCmjVkPr5D39YND54iGfYQzo/gCPlKGxanzsaygv97ey4X+OK8c6UYG1tUFed8NG7h5Z9OSHLB19sQYGUtx9NQlqkMam5ur0dMW5/viNNQG8yqA3zx1ia7eMVRVJqDJGJbNO+dHSZsOw+OpRXl7xWiONZ3H+0vb/dz7/qu
XrYvklFbI1QGur/JaIWcazkHpDxNLGSqZrT3IUlFWgp75BcUSl+XOxRMve6EKOAMSEvJEMD3jzc6GnrZ4/mAX77uuCUmSsKyphkkSXBrRuXhpnGdfvZyh0lAb4kLfOOY0c0/nQsjvhVkya5Blu8TiBobpcL5vjLrqALe+ayMP3LEt6xG/cbI/29PGBRzbW4CSxkTsX3axLJd9r52jLhrkvr1tBeOTjuuSSF/OoqgOa6QNi0TKxHFdXm/vpbN3PJvOaZOfqZNMWbSurwa8gq+O86OkTQtVVQj45OwXITfObekx6poKd8bL9fivvXINXRfH+PmJAfw+hStaojStjeSFVQZGdGRZQps4SdRU7xziXN8YtuP9znyqQn21Rk1VcF7eXjGaY00XEz7ckeDf37d8XSTnuuMp9WFiKUMlM7UHEYI+QeYX8dh3jwCXBXYOOrsgYol0dqGYi5hnkCT42YlL0yb5uy7ohs2bpwZpqA3xg9c6+eRHdnHzziZOXxjNxtmnw++Tpx127dcUTMvFtOyszV52jMsbJwdYvzbMlRtq+fr33qB/KEldtR89bXvVrBM/n5uVmQ21S+A6bvaGzPW+dMNiLG5gO6CnbU6eG0ZTFeqifiIhH6blMjiqc6E/jiKBNDGjVJK8M4buS0l2XtnApRGdpG5yoX+c42eGUGQJVZVxHTfr3WSaiGXi3J2x4Wlj/ZNDA/XVQY6c8Mq8mxu8ZmK5uwDPObhc12A7LoZhYtpQHw0wPGZhWTb9wzqptIWiqvPy9hbbHGs6j7fv0vJmlEx3WHv15jqeeuFkdpeQm11VqsPEhYRKisF07UGWYqBILmUl6MvNQr1+28GrFHWmXwVs2/FyyQfG6Lgwyg8PnUOVJa67oo6dV6zh4gwxdHeG1WU8adJUH6FvOIEykaEW1FQURcGybfqGk1y9uZ6j7/TTN5LKe25AU7Bsh9wlMnN46kyEZzI3ZMb7khWZwVE9byfinTtY9A5aRIIav3PfTvbsauHVI934fDKm5SBJORk2rkso6E2Ij+teR8lMFo7ruoQCKvbEYpIynTn3hS90GGrYzpTVORPXb6wLMTymZxdEn6rgIKGqEuvqwgQ0laFYirRpkUg7/P6Hty+rWEzn8VYFi5uKOBuFDmuv3lw3JWW2Tzf50M2bOHxiYMUfJhab5RwokkvZCfr/fWXuJe+lZWaX3nZAT5kYlpvzmMuRjiHev0vLeoiFyH3OZKKRALdct57nfnya6nCAaFjLtr/tG06Q1E1+/nbfFDEHr82w3ycxWfMu/38pe0Nmwl8pI42MF2qSAJ8KtiPhOC4+VcavyZzvj/P1772BhNc8KddDd11vdmpSN9naUsPNO5v4yS8uoihe7D+gKV6Vo+sQ1405Z6w883IHr73Zi2HZaKpCy9oQfs3H8KiOqsiMjKempNm9a1sDf7f/BJGgSjAQQE9ZDI9BY31g4tr6iUb8OK7DyHh62YVpupjwL20PL6sdMPWwNrcvOVze+aRMhy9/cs+y21dqlnugSIayE/T+oSRhPziujG7Mz4Wea1OvYuDMwbQpwjyRDvOvR3u93OEFvG9jXYgH7tjG8dODJHQjK+YAqbRFKOCjd3j6QynbAU2VcR0Hw/ZCIrJ8Ob6emaCUEbNvP3s8e2gcCsiEAn7AxbQcaqv99A0luTgwzlAsRcCvENct70xi4v3cif/+6PB51q8Js3FdhNrqwIR3c/n2zHg3heK3x97p51x/gtfe3EckqLGxIUR71yiKBDIuhmlxqnuMaMhrpua4FEyzywhUbsx1/ZowqpJfcVmqopLpYsJBu3/ZbZnMSikwWimUaqBI2Qm6qkgYNkRCfnRjfqflkZCf4bF5jDtaZjKesIN3MLkQ7pwQ3ELenOvCpqYqBmNTvfMMG9dV0baxlqFYirhu0NUzhmHZqIrMB2/amL0hMymLlu2i+WRARVG87aXteKXtI2NpfLLE+f44miqzqbGa8/3jxJNmdpCI63reXG21H8d
x+bv9J9jYEOLo6WH0lImqysiShO3CR27NH8QRCvo4dnac0aS3FioyxOwUR0Z1NJ9EOBTAsh3GkwbgkkjZ3HxdM7gUTLODqTHXTGbJSikqKRQTPnKk9IJe6k6WK5FSDBQpO0G/enMdPz/Rh2HN33+NhrUVK+jSxH/mc/haiMyXvaA3d0MLXX3jMz4/kbK47so12Zai69dGClZMZtLnGmoD9A/rmJaDa1g4fh8SLoGQwnjSZk1NAE2VsyXnGxuriMUNfmlXc3YXkRsTThtJTveMc2VLNef64xiWgyqTt5hk4rdvnrrE6IQDKEveYW5mITRM709Vkb3D1YmagkyYpVCa3UzXsxyKSpaTr373dQ682YfteJXL111RR3WV56FXQl/xcqXsBP2jt2/jQs8gCdMbD2cUSAucjt7BmcVsvrRtqGFgJMlofPEn1272Pwsn5Jf53DcP0DeUxHEdmuojU0qtO3ti/Ky9h2R66nULamS7yD10785p32f/wS4UySWuW5iWg8+nTLQakAhoKi5QVx1k8/oo5/vG8amXD4ZM0yEa9kr1h8Z0AppC/3DSex1VxrAsTNPh+m2NXD+htfGkQSjH88vEb4+fHgTIhkS8aVNTL6Iiy9lDzgzz8R5LlSmxUvnqd1/n1aPeUIlM24gjHUPsbqunsT68LLnwgsKUnaBvXh/lznfXEbNrGRhJMjiqc+zU4Jyem5q5Jcu8+OB7NnHy3Ahra4OMJQwc1/MSNZ9M2nDmpM0f3tPK0VODnOufOU1xLsgSrKvzOlAqihemGI4lp5Ra/+JkP7YrZT3a7POBrS31OK5D33ByxjYAvUNxLMudaFnskkrb2LbX9fLzH78pb4bn1546wsBIkrhuYlpeOuCWZi/nvqsnRt9wEr+qoCoStu0QT1pEgvm35XSx2KExPSvi2b40Of+eMqyJcJDXx2TrujCO6wrvcZEceLPPC3HlnC3Ytsux08M88p9W3wHoSqLsBB2m9u+4+9PfX3YbfvjTc3zwPZs4ezGGC2iqhCzLuK474Wl6HnCm7H8yu9vqs8NxM+KZqTadiUysGLwDzMxrOy509no7kKAGQb+KaUMkqOSVWmcKHqpDGiPjqWwVa2ZIrp7yslVmqtiUJRnHtZCRvOwWSZo4NHXzhH/z+ijXba1n32vnkCQXdWKhOdszxnuuaaSxPkzPpQSO4iJLEo7rIuFOyd+dzpuurw4yNq6TTLtZUc8sUrvb6jk/kMweSN124zpCIb/wHgsw3xJ12/EK7XLJeOqC0lJ2gv6dZ4/xw592Y33vIppPoSa8vDm4ufhUmeaGCAOjOsmUQdq0s+l4iuxVmlqTbvKgBjdevT5b2JKb0zsT115Rx/m+OC6wpiaIldapqqriXN8Y40kzb+HQDRvXdQn4fVNKrTMFD8mUmdenxnJdErpBMm3TUBvkZ2/1MDBy+byhodbPzisaAGisD030orezLQYcF6Jh/5Sc8DM946yrDWDaZMMqPgUOnxhgXX2I669ay6kLMfS0SdDv45ot9fSP6LO2xAXv4Pfbz4wSUWTGk2Y2fr67rT7PU/zqd19n/2vnsvHePdc2snn9zLHz1cJCStRdvLOe3IN7z9EoPGSmmJR6OtJKp6wEPTNWC0Dzed6hnl6mPMQCvP5WD/0jUw9ZXdfr4Bjyu1iT/lk34OS5IS/uPEHGm3315130Dk99vaAms6nRm/35ZscgfYMJIgEXWTcYT3pxJHlS+9+U6ZA207i4eelkkaBGLJHKHhrmktANfvvuHfzdvvY8MQcYGEnTfnoAgB1b16D5ZNrPDGG74FNkaqt9rImGpoRHhsZ0aqsDyNLlOLrjOgyN6ey8Yg3xpMGmpstfyHjSoGlthEho5pa44IWRjrzZweEzaUIBlXBAZUNDhOqqQHbie6F476tHe4HX5zTQo9KZb4n6Q1/8YcHXcYE91zYuoaUro6/9SqesBP3Fw90osoQsuRMCUbwGLhLg12RMy2XDuiouDoxNO4c0QyE
xz7yWIkNymoSagZF0wRDCZ3/7Zj73xL8ST13+XJGAzK3v2uSVr4c0ZCROnhshljRYW+eFQnxq4Ra/LhCLGzSvCWcF7q5bWvnbibh6LqoMLjJ7drVkWytM93lv3tlE31CCTY3VuK6LJEkYlkNLQ2RKeESVZU53x7xQlKJQHw2gKhL11cFpS8jn+gXt7IlxqidNdUgjGvHT0uDlsMeTRnaXMF28dzEDPXIpl+HB0zHfEvVCDkeGpV4gV8J0pJVOWc0UNUw774s5X2baEbp4X3QJr59Hc0PVjD+fYXIsESZCIrOsNZPnHWa8j1vftYn7b2vj7j2buXV3C1dvWZs3h3JTU5TbbtrEtpYgX/7kHoJ+FdzpDZUlb4huJqRz3942ZC4fHkp4O4BolX/OfSYyYaItLVHGdW+H0LapBp8q581yPHC0m1g8lW3Va9k2PZfiDI+lufOW1kXN+Mxcr0TKpirswzBtTp4bYWQslbdLWMp4byZ9M6Ebed0PDxwt7tzOpSQS1DAm9QVajhL1hTDfmayrkUV76J2dnXzmM59hdHSUmpoaHnvsMVpbW4tg2lQ0n4Jl2XMS2kJkcrx9ipefPLnS1HHBr0mEAj42NFYxMKxj2c68UiPnylyn6mQaVk0u2IiGvJDN7Te2ZMNQuWRCDC5Tb/r6muCUSsyUYWW/xJl4fK4Qum5+Bsnm9VE++ZFdeTHNSLU2peKyPhqkOmwzNGZ4mTCqRFVIzXqxs01cn47M9QoHFCzr8mi17oE4PlXO7hIKtVDItBtYLMWaiFNKSlWivhBE8dLsLNpD/9M//VMeeOABXnjhBR544AH+5E/+pBh2FeT2G1uwHRfT9uKwszHZM1NkCb9PJuD3EQn5CWpynqf6W7+6jQ+/70q2NHvjtNo21nhVj4pE0C9TFfLh1xR++87tzCQHqVn6C+xuq5/y2HTeRzjoI66bxJNG3kTz7Ru8uPgn7r2OuyeqQ7OfGyb6sHtj7ybf9Hfd0ophOaQMC8f1/jQsJ1vWf8OEfa57+X+5j+eyeX2UB+7YxsMfvYEH7tiWJ85DYzrBgEpNVZCtzVHaNtaytSXKAotg88hcrzVRFcOyMUwbVZWIxdN5u4Q91zZmd1/eSD63aPHezOfLZaUOD56O+/a28cBtbQQ0lWTKW+QfuG36CsemOv+8Hi8mN+9sKvhdmLzbXc0sStCHhoZ4++23ueuuuwC46667ePvttxkeLv6wCbgsXj7FG4Ab9M++wciIugQ8+5V7eOD2q7JiFgr6iEyI9Mfu3M59e9vybpqrWuto21iTHaSce7Pnit5kcj3fQhTK1W2oDZHU8xPlk7rJlubCYYnG2steyifuvY4P72lFlnJSGieEq2VtaMpNP9uX+JH/tIfdbfV5i93kzJG5UF/t9QzPpVh9UDLXKxJQ2NZah+ZTGEuYRCP+vLDNpx98N+/f1YQy0RBMkSXev6upKPHepfx8y8l9e9v4m8/fzj986S7+5vO3z5iy+OR//eAU8W6q8/Pkf/3gUpu5qBDdakFyZ+rFOgvt7e380R/9Efv3788+9qEPfYivfOUrXHPNNTM+N51O097evtC3zuORp2aPWW5dp/Lgr3he2YG3YhzuSKCbLkGfxI1tYfZcc/mm6BsxOHFBJ5a0iYYUtm8I5glohu/+vz7O9FsF3+Mbz3VTKLRXG4JP/drU7XjfiMFrJ+IENImATyJluqQMl/dujxR870L8y+EhjnbqGLYnwjVh2NlaNa39S037uQQ/OjqGpkr4fZA2vYZkH9hVzY5Ni+sQWIzrtViW8vMJBLOxY8cO/P78xbXkWS6FjJqNI0eOsHv37uzfn8/5/wCPfPsAb3QMZacZ3TDJu5z04wW5cw52zPQ6f797Nw998Yd5WQGzeTJXX305Ht08S47t5Gswmz2lYPdu2LrlchZIQ31xs0CuvjrGP71wBDVYM+v1WgqW+vPNlUL3wmpktVyHmZzhRQl6U1MT/f392LaNoijYts3AwABNTaW
Naa2U8uP5bkMXekC4klnKPiib10e59doou3ffsCSvPxdEnxfBSmJRMfT6+nq2b9/Ovn37ANi3bx/bt2+nrq6uKMYJBAKBYO4sOuTyyCOP8JnPfIZvfvObVFdX89hjjxXDLoFAIBDMk0UL+tatW/nHf/zHYtgiEAgEgkVQVpWiAoFAIJiekmW5ZLIlDWNhwyHS6ZU5eWg5EdfAQ1wHcQ0yrIbrkNHMQhnni8pDXwzj4+N0dHSU4q0FAoGg7Glra6OqqirvsZIJuuM4JBIJfD4fUqEOVwKBQCCYguu6mKZJOBxGlvOj5iUTdIFAIBAUF3EoKhAIBBWCEHSBQCCoEISgCwQCQYUgBF0gEAgqBCHoAoFAUCEIQRcIBIIKQQi6QCAQVAhlJeidnZ3cf//93HHHHdx///10dXWV2qRlZ2RkhP/4H/8jd9xxB3fffTe/+7u/u2Qj/8qBxx9/nKuuumrVVh2n02n+9E//lNtvv527776bz3/+86U2adl55ZVX+LVf+zU+/OEPc8899/Diiy+W2qTS4ZYRDz74oPvcc8+5ruu6zz33nPvggw+W2KLlZ2RkxP3pT3+a/ft/+2//zf3sZz9bQotKR3t7u/vxj3/cvfXWW9133nmn1OaUhEcffdT94he/6DqO47qu6166dKnEFi0vjuO473rXu7K//xMnTri7du1ybdsusWWloWw89OUeSL1Sqamp4aabbsr+fdeuXfT09JTQotJgGAZf+MIXeOSRR0ptSslIJBI899xzfOpTn8q2z1izZk2JrVp+ZFlmfHwc8HpENTQ0TCmJXy2UfKboXOnt7WXdunUoigKAoig0NDTQ29u7aickOY7D008/zd69e0ttyrLzjW98g3vuuYeWltU7/u3ChQvU1NTw+OOP87Of/YxwOMynPvUp3vWud5XatGVDkiS+/vWv85//838mFAqRSCR48sknS21WyVidy1iF8OijjxIKhfh3/+7fldqUZeUXv/gF7e3tPPDAA6U2paTYts2FCxe4+uqreeaZZ/iDP/gDfu/3fo94PF5q05YNy7L49re/zTe/+U1eeeUV/uf//J88/PDDJBKJUptWEspG0HMHUgMrZiB1qXjsscc4d+4cX//611fd9vLw4cOcOXOGX/mVX2Hv3r309fXx8Y9/nAMHDpTatGWlqakJVVWzYcjrrruO2tpaOjs7S2zZ8nHixAkGBgbYvXs3ALt37yYYDHLmzJkSW1YaykYJxEDqy3zta1+jvb2dJ554Ak3TSm3OsvPQQw9x4MABXn75ZV5++WUaGxv567/+a/bs2VNq05aVuro6brrpJg4ePAh4WWBDQ0Ns2rSpxJYtH42NjfT19XH27FkAzpw5w9DQEBs3biyxZaWhrNrnnjlzhs985jOMjY1lB1Jv2bKl1GYtK6dOneKuu+6itbWVQCAAQEtLC0888USJLSsde/fu5Vvf+hZtbW2lNmXZuXDhAp/73OcYHR1FVVUefvhhfvmXf7nUZi0r//zP/8z/+l//K3sw/Pu///t84AMfKLFVpaGsBF0gEAgE01M2IReBQCAQzIwQdIFAIKgQhKALBAJBhSAEXSAQCCoEIegCgUBQIQhBFwgEggpBCLpAIBBUCELQBQKBoEL4/wE2vvQqOWUBfQAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "samples = jit(mcmc.sample_chain(mcmc.metropolis(\n", " ppl.log_prob(model),\n", " mcmc.random_walk()), 1000))(random.PRNGKey(0), jnp.ones(2))\n", "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0vTY-MiTGuQa" }, "source": [ "#### Hamiltonian Monte Carlo" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": {}, "colab_type": "code", "execution": { "iopub.execute_input": "2021-01-28T12:27:30.743193Z", "iopub.status.busy": "2021-01-28T12:27:30.742596Z", "iopub.status.idle": "2021-01-28T12:27:31.970140Z", "shell.execute_reply": "2021-01-28T12:27:31.970527Z" }, "id": "2CWSqdO7F3Ix" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAW0AAAD7CAYAAAChScXIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAACPyElEQVR4nOz9eZRk913fD7/uWntVV+/bLD2jWTUjDR5btmRBwIj4iQQPGAgmJjhk85PEJw8Jh5PwI4TkByEcAwF+OSTHOJwTAs8RdgiRfYzEETjGgKQxskeW1CPNPj1L71t17XX3549v3eqq6qqu6m2melTvc+Rxd1fd+73fe+/7+/l+lvdH8jzPo4suuuiii30B+UEPoIsuuuiii/bRJe0uuuiii32ELml30UUXXewjdEm7iy666GIfoUvaXXTRRRf7COpeHtx1XfL5PJqmIUnSXp6qiy666OKhged5WJZFJBJBlmtt6z0l7Xw+z7Vr1/byFF100UUXDy2OHz9OLBar+V1L0p6enubTn/505edsNksul+P1119veUJN0yon1nW9rUFeunSJM2fOtPXZ/YaH+drg4b6+h/na4OG+vv14baZpcu3atQqHVqMlaY+Pj/OlL32p8vMv/dIv4ThOWyf2XSK6rhMIBNod75Y+u9/wMF8bPNzX9zBfGzzc17dfr62RW3lLgUjTNPnyl7/MD/3QD+3aoLrooosuumgfWyLtr371qwwNDfHoo4/u1Xi66KKLLrrYBNJWtEf+8T/+x3z7t387n/zkJ9v6vGEYXLp0aduD66KLLrp4L+PMmTMbXDttZ48sLCzwjW98g1/5lV/ZlRM3w8WLFzl//vyWz7Ef8DBfGzzc1/cwXxs83Ne3H69tM4O3bdJ+4YUX+Bt/42+QTCZ3bWD7FVOzaS5MzrGYKjCYDPPk2REmRhMPelhddNHFewBt+7RfeOGFbgASQdgvfO0GuYJJf0+IXMHkha/dYGo2/aCH1kUXXbwH0Lal/fLLL+/lOPYNLkzOEQ1pRMMi79z/98LkXNfa7qKLLvYcXe2RLWIxVSAcqk14D4c0FlOFBzSiLrro4r2ELmlvEYPJMIWiVfO7QtFiMBl+QCPqoosu3kvokvYW8eTZEXJFi1zBxPU8cgWTXNHiybMjD3poXXTRxXsAXdLeIiZGE3zsOx8hGtZZXisSD
et87Dsf6fqzu+iii/uCPVX5e1gxMZroknQXXXTxQNC1tLvooosu9hG6pN1FF110sY/QJe0uuuiii32ELml30UUXXewjdEm7iy666GIf4T2ZPdIVfOqiiy72K95zpO0LPkVDWo3g0/3Ote4uHF100cV28J5zj1QLPsmSRDSsEw1pXJicu29j6CoFdtFFF9vFe87SXkwV6O8J1fxup4JPW7Wau0qBXXTRxXbxnrO0d1vwaTtWc1cpsIsuutgu3nOkvduCT9txt3SVArvooovt4j3nHvEFn6rdGd/zwUPbcktMzab5+qU5PM8jGtIZH4qSjAVbWs1Pnh3hha/dAISFXSha5IoW3/PBQ9u+ri5aoxv87eJhwHuOtGF3BJ98t4imyHh4mJbDldurnDzci6bIm1rNu7lwdNEeOiVrqIsudor3JGnvBny3yJGxBFfupNBVCU2RuTWdZmww2tJq7ioF3l90g79dPCzokvY24WehyJLEyUNJphdz5EoWEl7HWG9dd8A69iJrqIsuHgTaIm3DMPiP//E/cuHCBQKBAOfOneMXf/EX93psHY3BZJhcwSQa1knGgyTjwcrPnUCMXXdALarvl49u8LeL/Yi2SPtXf/VXCQQCvPzyy0iSxPLy8l6Pq+PR6cHErjugFp1+v7rool20TPnL5/N88Ytf5Cd/8ieRJAmA/v7+PR9Yp6PT2451c8Fr0en3q4su2kVLS/vevXv09PTwW7/1W/z1X/81kUiEn/zJn+T973///RjflnC/fbidHEzsugM2opPvVxddtAvJ8zxvsw+88847/OAP/iC/9mu/xvd93/fx1ltv8U/+yT/hz/7sz4hGo5se3DAMLl26tKsDbob5lMlrl3MEdYmgJlGyPEqmx1Onogwn9dYHeMjQnY8uutj/OHPmDIFAoOZ3LS3tkZERVFXle7/3ewF4/PHHSSaTTE1Ncfbs2W2fuBkuXrzI+fPn2/qsj6nZNP/r65OkCzIJOUBvb5SxcmAw7eg8d/7klo63V9jOte0Ep0+v7zzG7sPO435f3/3Ew3xt8HBf3368ts0M3pak3dvbywc/+EFeffVVnn76aaamplhZWeHQoc4I4PhZEms5g3hEE0Uud1KcPJQkEQu8Z3240HUHdNHFw4i2skf+7//7/+Znf/Zn+cxnPoOqqvzKr/wK8Xh8r8fWFvwsiZ5oANNy0DUFgOnFHJq6eWViF1100cV+Q1ukfeDAAX7/939/r8eyLfhFE+NDUa7cXgVAVSXSOaOb0tVFF108dNj3Kn++Yl4yFuTk4V50TSGTt0hEA92Uri666OKhw74vY68umkhEA2iKTK5o7ZiwuyXgXXTRRSdi31vae1E00W0H1kUXXXQq9r2lDbufJdEtAe+iiy46Ffve0t4LdEvAu+iii07FQ2Fpb4bt+Kb3qgR8PmXy/MtXun7yLrroYtt4qC3t7fqmnzw7wtxKnouXF7hwaY6LlxeYW8lvu4+kP5bXLue6fvIuajA1m+b5l6/wm59/g+dfvtJ9HrpoiYfG0m5kUe/ENy0hUTJt0nkTz4Vc0WJmMbtty/jC5BxBXer6ybuooKt53sV28FBY2s0s6pvTa9vyTV+YnCMUUJBlmaFkmANDUQKazBe+cn3bltBiqkBQk7Y8li4eXlQbFbIkFvRoSOPC5NyDHloXHYyHgrSbPfyFkk2haNV8th3f9GKqwEq6hK7K6JqCJEmEQxq24277hRpMhilZtYKKD4tUqr/F/+LXV7tb/C2gG/DuYjt4KEi72cMfCWnkiha5gonreeQKJrmi1dI3PZgMk8mbaOr69FiWSyKib/uFevLsCCXT2/JYOh3Vu5x4SO766rcAv5q3GtULedff3UUjPBSk3ezhPzKW4InTQ0zNZvjLb00zNZvhidNDLf2FT54dQVEkCiULD49MwWAhVWBprcjCSmFbL8/EaIKnTkUfus4p1bscqbvF3xKePDvS1KjoFnh10QwPRSCyWf+/0xO9vP7uAhOjc
R492kehaPH6uwuMDcY2JcuJ0QQff+Y4X/jKdZbXihQNm2hIQ1MVeuOBmmDRVlIKh5Odo+3tY6fl+t0u59uHX81bPf/f88FDTIwmeP7lK90Cry4aYl+RdjOCafbw7yR75Olz44wNxvjcC5OkcwaJaIDxwWil67pvSW4W/a8fb0Ix93aCtojdyF7Yb23NOk1Tplk1b3cx7KIZ9g1ptyKYRg//l/7y5o4e/InRBEN9YR492ocsrWd++MfYbFGAWkKfWcry1VurvHHn6xwd73ngZOGPc6fWXPUux6va4neiJO5+SrHbb4thF/cP+4a0t0Mwu/Hgb3aMzayh6vGmsiXuzmfxgHzZh+mTxcxilhdfvc1KpkhfPMRzHz7M0+fG2x5fu2hkYbay5tqxSqt3OfdmXcbG9MoWv9OwnzRlmrn8OnEx7OL+Yt+Q9na2izt98Kdm06SyJd68tkQ8rHN4LE5AVSrHuDA51xahTy/k0FUFV5MoGHbl85//0yvcnMkQDigkYwHyRZPfffEyi6sFSpa7a1v4ZhZmQFMoFK2G49+KVervci725znfYT77auwnl8Nm/u4u3tvYN9kjrdKjGmEnsq0+aWmKzOPH+gF4+/oylu1WjrFZ9L96vPmShabJ2C5EgiI1MRzSeHdqlXBAIRLSkSWZSEhHkeF//8XNXc0aaJbHjkTT8T+MhR/beYYeJCZGE3zioyf5Fz/6Pj7x0ZNdwu4C2EeW9v3eLtZvpXvjoYpV7b88rayhyngDKoWiheN4jA9GAUEWtuMRCtbeAtNysGy35Ra+HdeF/5mvvXGP3liQA8MxkrGgGFNIY3mt2HT8O40HdCL2o8uh0wKnXTx47BtLeztW805yXdutVmtmDVWPNxrW8YD+hEoiFqhYtEN9YYolu+Z4RcMhUkfk9edt57qqP9MbC1IoWVy5vUoqWwLWLUx/x+C7dC5MzjE1m953Vmk72O4z9KCqPbu52l00QluW9kc+8hF0XScQCADw0z/903z7t3/7ng6sEbba7KBVdsdmFsxmAch2rZ/q8U7Npvmjly+yvFasWLQzi1l+98XLgEkoqAoClyQOlK3x+vO2c13++ao/c2AoxpU7KSQ87s1nKy3ZvueDh5r6rp84PcTr7y4A+8cqbQdbeYaq56a62vN+ZZvsp8BpF/cPbbtH/vN//s8cP358L8ey62gWeLo5vcb8Sn7TINtmBTvbSRubGE3wXY8lOH/+fTW/A2qyR37kfePcns+SK5hNybKdgFr1Z5LxICcPJbm3kGU1W+Kx8EDLIo67C7n3fCCsmjQza9J9J839FDjt4v5h3/i0t4Nm1nKhZDPUG97UgtmLgp1GePrc+IYUv3pLvp4s20llrP9MMh5EU2UeCw/wiY+uZ3hsRgy73cZtv+FBk2Y3V7uLRmibtH/6p38az/M4f/48P/VTP0U8Ht/Lce0KDg5F+cJXruE4HvGITl8iiKLIREJa2/7qRgU7qiIxeWOZfMkiEtSQZY+78zm+9sa9Xcm1bkWW7QTU2g26dYmhOernJpUpcWsmjeW4PP/ylV0JCm7matuPgdMu9h6S53leqw/Nzc0xMjKCaZr80i/9Evl8nl/7tV9reXDDMLh06dKuDHSrmE+ZvHY5h+u5ZPIOBcNFkSW+/UyMpbRN0XQJ6etxWP/n73ps85fwy6+vcnvBQJLAtDwKhovtgq7CUI+KYYFpezxzLs6ZQ5E9vb7L94qkCw6JsMKpAyGGk/q2PvPa5RxBXSKoSZQsj5Lp8dSp6IbPtnvMhwXVc2PbLneWhAzBoUEdVZE3naetHr/Z3L+X5ruLjThz5kwlluijLUt7ZETIh+q6zic+8Qn+6T/9pzs+cTNcvHiR8+fPb+n4jfD8y1c4NB6psSBzBRMtrPNDHxqp+KV9C0YuWm0FmL5+603urSxQNBw0VcIzxIvsIaMHIyTiKvmiyZU5mb/3g7XXsVvX5uO5bX6m1rpL8re+/QB3F3IspgqMbRJYnZpN8xdXb
xDviTA8Iubt6qLF6dNi3nb7+joBp0+Lufrq6zfp64lxZDxRSZvMFUzSzvZFwJo9o/XHbOc+7xQP473zsR+vbTODtyVpFwoFHMchFovheR4vvfQSp06d2vVB7jZa+WrbDbLVb19TWYNoWMe2DVzPw0NCL89iJmcS6lUJBVVWMsX7cZlbRqNskdffXWhrwXovZjP4rqrLN+5x4uhgQw2a7eJB+8y72J9oSdorKyv883/+z3EcB9d1OXr0KP/u3/27+zG2HaGVr7adIFsjgltYzVM0bEb6I0iSxM2ZNI7joCoSlu0AUCzZ9MVDTY/5IIsldkK872WSSYSbl/xvF914QhfbQUvSPnDgAF/84hfvw1B2F7sRxGlEcAeHYrx1Y5lCySIc0ohHNJbXHFQFVEUiXzQpGA4/8szhDcebT5n8xdUHqzK3E+LtVJK5HwvhqQMhri6KYqPdCgp2A41dbAcPbcrfTlwg/kt/c3qNfNGiYNhEghrjg1FGBqKspEvkSzbpnEl/IkQypjO3UoSyfsiPPLMxe2RqNs2ffHMNV9JrtLmhvWKf3cJOiLcTSeZ+ya0OJ3VOn97dvPX7IQr1oHd2Xew+HlrShu27QPyKwMVUEQmPcFDDtByu3ElxcCjKY8cGKqJK/svwzzZ5Gfxz5EsOg/3rxzp5KEkiFmir2Ge3sBPi3SnJ7AWB3E8/+17krbf7jG5n3vaTfngX7eOhIO2dkEGzl/7FV29zcDjK3fkslu2iaTKW7XB3IcsPf/fxLb3A/jkiQQXb9tA1BYDpxRyaKrdV7LNb2Cnxbpe49opAHnY/+07m7b0YOH4vYN+Stk/Ut2bSLKzmOTgUY2Qgysxill9/fpah3ghHxhItCbzZS7+SKfLo0T7CQY3phRz5kkU4qNWo/LUL/xz9CZWVvAhWqqpEOmeQK1ptF/vsFvbSumuGC5NzOK7L1GymUpTUmwjsmEA61c++W9jrwLF/ny/fWOXq8u4UDHWxt9iXpF1tfeQKJhJwYzrNrZk1UjkLVZGQZYnBZKilVdLspe+LhygULZKxYE1ebvXnAF55c7pl55nKOYIKAwO9TC/kWMsZ9EQDFau3k4hnL6ziWzNp5ldyeIjsmuW1IncXIJUxdjTWTvSz7yb2MnD8oAWxutge9o00azWqrY+CYSPLEtmCyXK6RECVUCSJhdUCK5kiM4s5fuMP3mgqq9mskcFzHz7ctEGAj1fenOZ3X7xMvmgS1BXuLWb5f77wJr/8P/665lxPnh1hfqXA9dkSV26vYloOI30RPvWxsy2bKTwI7EUDhHzRwrJd0jkTx/UI6gqeB3cXsjuSGt1Jo4v9gJ1I5LZ6rqrvs/SQNLp4L2Bfkna11nUkqJHKGqiKhOuCLMsggSLBO7dW8fAqDWcbaRE3e+mfPjfekgz+11dvUDJtltMlZpfyeJ5HQFO4cju14VweQi3AsByW00VuzaZ56bUppmbTHUc87WqJV2NqNs2fv53mNz/feIEMB1XyRaEdLksSjuMhSxJBTdkxSTzMHV52sqC3eq62c5+7ePDYl+6R6m3f+FCUqbk0qiKjKBKW4wIgKzKO4yEhEQ3pm/oCm/l4N/P9Ts2mmV3OEdQVzPI5DdMhGFAwLa9isUyMJrgwOcdIXwTPzLJSUImFdDzP49Z0ekNH+U7AVv3E/ja7aLpMjDR2pxwd7+H2XAbP87BsD02ViYYVYmH9oSKJ3Y4F7GXguPo+54o2kzeWSecMEtFAxZjoovOwLy3tausjEQ3QFw/huh6RoIrjeHge5AoWpuWQK1qMD4mmArtpRVyYnCMUUHE9cF0PRZaQJImi4RAKqDXn8i2a5YyDrsromoKuK1iOu+PtqN9ZpZmFux1s1brzt9khXW7qTnny7IgI5IY0xgYiJGMBZFmmLxF8aIKGe9VpZq92Ev59nlnMcm/ZJF80kWXoTQS6HXI6GPuStOu3faeP9HLiUJLTE31EQyqyBJIEuipX3BKwu8G9xVSBE4d6sB0XD
3BcD9dzcRyPYwcTNefy/ZIly0VTxZRblkskqO1oIdlLkmi0rQYaLhDtbLMnRhN8/JljeEAmb6KpMgeHoiiK/MB897uN/dYM2b/PqxkD14VISOfURB9jA7GOHvd7HfvSPQIbt31Ts2k+98Ikqqow3hMmHtWYXSo0bLG1GxBbS5XHj0m8O5UiVzBRZImRviB98VDNufwMB0WWMG0HCfHvkfHEjhaSvczDbTS/zTJK/G12NRpd19PnxhkbjD20FXp7nTO+F8VJE6MJhvrC9IYMxsb6K783bIdLlx7O+7TfsW9Jux7+w/fo0b6KEls8HGjYYmunmJpNk8qWePPaEvGwzhOPDlEybO7O5xhMhoiG9Zpz+RbN/1hbYWbNIh7WOXEoueOF5H4Wlmy2QPiLUtF0cT1v07S73fLdd2J59l7mjO9ldeNgMszN9Grl51SmxLu3VggHupWUnYiHhrSh/RZbO0H1y/P4sX6mZjK8fX2Zc8cG+KlPvK/pQz0xmuD7PthH78h6UKk3ru9oIdlOwHC7RNeO1G194+K9esE7tTx7L3PG93JX9eTZEd65dq/Sl/TWTBrPgyPjiYqbZ7fO1cXO8VCR9v0otKh/eXrjoQpxtvNAN7M0t0OoW7neRkT3+y9dpi8RxHLcludsR+q2vnHxXqFTy7P3UgBqs0Vzp7uOidEET52KknZEJo/luJw52lcpKqs+VxcPHvuWtJs9qHutmrYXLontWo5bud56orMcl9nlHOmcwbedHGx5zk6qPNzJPdhrt8pepW42WzQ1Rd6VXcdwcr1bzvMvX2krRvFeQCe64fYlabciue0KGrVzcwaTYWaWsqymjRoNjbGBWMvj//nbaf7q+hsbjr8Ty7Hd660nuumFHOGAiuW4bW2Bd3tB3MnLsF3f8dRsmt976V0yORPTcbk3n+X6vRSffPb0A3sR252HZotmQFd2fdfRSQv0g0SnuuH2ZcrfbqRWVec3/5c/fJPff+lyW6lzB4eiXL2zRr5oEgoo5IsmV++scbCcC97sXH6grtHx70dlWn05dL5kgSQqSts9527lC+80VXG7VYIvvTrF/EoegEhA2CvzK3leenVqW9exU2xlHpqlYVq2u+vPTqdV6D4odGoK5760tHfqoqhfQb91ZZF8yaKvJ9jS6ry7kOPEwR5WM2VLO6RzYCjA3YVc0/P5Nz+T9njn5gr5ktjWvvTaFJ/+4XP3Ramu2noybYd01qBkirZpqWyJZCx437bAO/VJb9fqv3o3RSigVqRxdU3Bw+Pq3dQOr2h72A3ffDuiUNvZ0XRShe6DQqfK/u5L0q5/UFOZErdm0liOy/Mvt5aXrH5ZUpkSi6kijuNw8fIi508NkowFm96cxVSBkYEoY4Pr7hDX8za9kYupAqoicW/JpCeuEw6qmKbDm9eWmJpN35ftqE90L702xVvXl4mFNXRNxnU9Lk+tcGgkjiLL92ULvBsvw7ZJxWvxc5vYDV/nVuZhs2Ydr7+7UPlu9bPTqdv7/YJOlf3dl6RdTXKG7fDurRU8D84c7WsrK8J/WVKZElfupJBlkCSZkmlz5fYqJw/3oilyw5uznRs5mAxz8fICqiJVrDxJkoiHdS5MzvGJj57csb+4HRKZGE2QjAX5wKmhyoI1vSiCkatpo6I6uNfY6cuwXcI8cSjJ5I1lJEkSTS0sl4Jhc/aR/pbfrT//bpDhVuahmVV+dyHX9Nl5/uUrDyzLphMDeFtFp/r29yVpV2+PL12aIxzQODIuCCmVLVWyIibG4ly8vMCfX7zHuWMDPPvhCSZGE5WXZXoxh67KJGMBFlNFNEVGU2RuTacZG4w2vDnbuZFPnh3hzy/eQy4rDlq2i2m7HD/UU7GqdrId3QqJVFt3yXiQZDyI63ksrxXv20u1k5dhJ4T57FMTrKyVSOcM8kULXVUY7Y/y7FMTWxr/bqUcbmUeWuXJNzrvg9re79ai9qCJ/35ko20HWwpE/tZv/RYnTpzg2rVrezWetuEHx
Y6MJfi2k4OVnFI/KyJfsrh6R/gqY2FRMOAHefxAVjpnoKoSiiwTDakkE0Es18Ny3KYP2HaCNBOjCc4dGwAkCoaNrimcPJQkoCq7stXaSsBkJ/rMu4WdBLp2EhyaGE3w48+e4vypIY6O93D+1BA//uypbbk1diP4t9k81AuBaaq85fv2oO71biUK7IWuzlaxV2JdO0HblvY777zDm2++ydjY2F6Op234HWNuz6XRVYVTE70cHI6TL4nONWZZkEnXFDzPo2DYlQfnybMjBHSFbMFiLWcwmAzx+HFB/O0UymzVKp6aTYMERdMlGpEZHYygqbunhbIVi6pTtnzb3Vns1HrcjQDbbvo6G42nkaW6vFZEQmKYre3wHsS93g0Lv1MLqDoBbZG2aZr8wi/8Av/pP/0nPvnJT+71mFrC7xgTDigMJEMspoq8cWURz/PQFJmCYaGpwm8JYNnrinrVnc8/eGaISzdXMCynJnVsNx/q6hdwYkgnZ1Ipe9+tgNBWSKRTt3ztYivXupPtdf13E8p6scmTZ0dEvvedFKbjoisy8ajOJ3fpuWlEWCOI5zha1h9v5749qHu9G4tap2ZudAIkz/Naxs9/9Vd/ldHRUX7sx36Mj3zkI3z2s5/l+PHjLQ9uGAaXLl3alYFW47//2SJF0yWkC1I2LJdMUajnnTwQJJN3WEybFAwP1xM+oLF+lcGeAAspi6GkVvluruQwt2pRshxiQZWgLjOc1Dh1IMRwUt9kFO3hz99O14wVqPz8XY/tzssznzJ57XKOoC4R1CRWczZLazaJqMJIUt+1a2lnHJfvFUkXHBJhpa3zbvU79ddasjxKpsdTp6I132v3c9s5x3zK5P+8lSZfcrEdF1WRiQRlvvvxxK7M8xe/vko8JCOVhc8APM8jU3T5gQ/17vj4e42dzL2P+/XebPV5vd84c+YMgUCg5nctLe1vfetbXLp0iZ/+6Z/e1RM3w8WLFzl//vymVtJv/+mf0d8bQJbEDY0BoZLJUqpIT7IPWypSXFrB9UBVJDRVYTHtEQ7rDA/GODwarygBAvT3F3nr+jKPnRiqbCOvLlqcPr11S7h+3EXXYuKQEN6ZnZ1ldHS0EvhrV6ejHYvx9On17vQ5M8+JiV5GBqI7upatjOngUJTXLl/m0PgQwajD1EyGq3O5mgBwo2P8xdUbxHsiDI+0P+/+tS6mCow1mY/nX77CofFIjbWXK5iknfVy7WZo9N2bt6dJO0meO3+S51++wskj2zt2O7i6fGWDpZormIyN6ZzfheM3gv/e7RbauUeboXdkfYfqv5Ny0drW7rTRtW332btf2MzgbUna3/jGN7h58ybf/d3fDcD8/Dz/8B/+Q375l3+Zp59+endHWobvUnBcl5W1ElfvpLgwOcfHnznG0+fG6YuHyBdNIiHxUBcNi8VUkYCqVIplNE1hMBbAMF0s28FTwLBdHh9L1OZ4Z0t8450FLMflW9cWyRdtLNtBVxVMy+Fn/t4TWx53tS9yMVUkqCs1ed1bTW9rJxLv+0aff/lKRR4W9sYX2GhMX/jKdSKai+W4XL2TQleVmgBwo5dtu37LdvzSO9leN/puUJNqOhFtJb96L4XAOhWb3aN201P30rWzn33mLUn7U5/6FJ/61KcqP2/FPbJdXJicw3Fd7s5n0VWFRFSnULT4wleuMTYY47kPH+Z3X7wMmISCKivpEq7jcepYL7IkGgzoqoxhugz1CnJ0PZdU1thQGXjp5gol0yYcUplfKSABkZCG47pcvLLIK29O8/S58bbHXf8gHByOcnc+SyIaqDQY3soLuNWHqxmh3Jxe4/mXr+xK+lSjMdmOS852mF7IoatKwwDwZumH1WNtRaztvPQ7qRRs9N2S5TFW1YmoHZ/t/RACe9DYbB4b/Q1oe072sipzP/vMO1J7ZDFVYGWtVHn5JSTCQQ3H8bgwOcfT58b5iedOEQnppLIGIPFtJwY4OBwHIBTQcF2RD+2jWLLpi4dq0qyu3V0jE
tQY6Y+QK9gosoSiyJiWiyLLBHWFF1+9vaVx16eCjfRHGeoVW+lM0d2yjsNW08sapXnNLeVYTBV3LX2q0ZgSEZ284YkS/QYB4Ebj3U5KWqtUMD9V7ub0GpM3V5hZym7QJ2l1jEbaJiXTq5BOpbfiUpa3ry/xV2/OMHlzZYP+zE7TEzst1awem81js7+99OrUludkL/qgdkLq63ax5eKar371q3sxjhoMJsNcvZMiEV23ZCzbJR5Z79z99LnxigVcLyV5/GAPb1xZJKBLuJ5LsWRTMBx+5JnDwPoK7q+26azB7bksqio+b1oelu0QDirMr7a/8mqqzLeuLGI5gqzGh6JoisyRMfECXuzPb9knudVIfKOt9d2FLAeHY7u2FWw0pr6eIEurGTRFxjQdJEnCtF2OjDVvqbYdN8BmOw9Yt+ImxhIEdYW781kMU4yj3UrBRpbuU6eilbmaGE3wxOkhvvCVaziORyKi05cI8vq7C4wNxiqf28/WnI/NLOlW96LR3y7dWuGJR4drzrHZnOxVKf5+dkF1pKX95NkRVEUUE3ieh2k5mLbbtHN3vWXUGw9yZCxBXyJIKmsQCen8xHOnNrg5/NU2GQ8SC2u4Lti2hyRJREIqridh2U5bK/vUbJrltSIFQ+SJG6bN5I1l5lcKO2pc266inW+NfOkvbxLQRKd3v2BjqDfCSH+tFbgT8njy7AjzKwUuXl7gwuQsFy8vUCw5fMeZGEfGE2TLFszxQz2VfPRGc1C967k9m2FqNkO+aHFhcq7pnG+286i3bMcGY5w92l9ZNKvJtNkxqucR4Pu/4yif+OjJDVkFdxdynD3az9Pnxnjs2ABjgxub4e5naw5a72qazeOtmTRfvzTH5M1lJm8sk8qWKn8DtjQne6W0t5+VDDuyjH1iVHTu/sJXrpHJm8QjOsN94aaduxtZRv+fH3xsg2/tNz9fq2VdvdqenEjyzXcXkSWIhFU8JFzP49hovC2L9MLkHCN9EfoTIaYXcxWt7b5EcEcPQjv+zXprxLca/IfQ34lEwzqpbInphRxrOYOeaICp2fS2xlfpcl/OwvHw6I9r/L1nztVYZ9EWLdX838+v5BlMhgiHtE2tqc12Hu1atk0bCqjNGwrUo51z7WdrDlrHUxrN49xyjoXVPLoqFBRNy6nR8zlxKEmuTNrhkMbcUo67C1mGeiMNxd62ultp5kdvhL30me8lOpK0Yeudu5vdgFbbK58QC4ZNT1THdj1s2yMUUHh0IsGB4fiWMg5kSSIZFyX1fmrfTtHq4Wr1cvnkkc4b3JnLIEsSiizRGw9sa6vpL1DHDiQrv8sVTC7fW+a5Nsa71fFXYzMivDA515YrqWlDAa15Q4ETdZpS7bitthJQvJ86G+0SWyvCbOiKm89xcChGOKhx5U4KXZVq9Hz8BfDC5Bw3p9dYTBU5OBxjpD/acLHeajFVo3f9xKC54bP7GR1L2rD9lbD6oVxYKaBrEkupYsX61VWJz70wyVBfuOYFqbZIfeQKZlvb2Qcp49jq5fLJ43MvTOK6EIvqjA9GScZF2f5Wfdv150tlS9ybzzK3VGhLGrfZ8fxdQL5kEQ6oNXPpoxURtmPZNjvGl/7yJolYbT2BP4/1pN2uFd3OM3w/O+pshdja6QtaP4+DyRAjA1FkSeLkoSTTizlyJQsJr4aM/fdtqDe86WK9lerTZov/5XtpntutCewAdDRpbwf1D+WlG8uk8wa6piLhsZIuYtkusbDOo2UpV391r5d8vT2TIVMwOXd8oKUb4UFuhdu1+ob6wjx6tK+msGg7vu3q86WyJa7cXkUCYiF5W4Eiv4Wbn+IZDqoUihb5ktVw3psR4VYs20bH2Hwe89s+Vyv4HXXCAY1IQMWy3UpHnU//7XNbPt5m2AqxtfNM189jteHjq0g20/Npx/Uxs5hldilPrmAhySJLKUHjQr1mx7s367QxM/sHHRmI3AnqAxeG7WA7LqbloGsKpuVilQOb9YEN/0W0bJe3r
y8D8Pix/krz1M0CknsV2Ggn3andYOVuBcaqz3dvPosEeEj0x9WGgaJW1/Dk2RHuzueQEBk4luVSMh081+M3/mD30ry2cl3ttDHbrbQ8v6OO47ospoospooUSjZv31jayeU0RLPgYbqwkdi280xvZQ5bPY9Ts2m+8JXrBDSZA0NRhpJhZFkmFFC2pGCZCCvNJ2Qf4qGztOtXW8fxkJBwXRE4c10PJHCc9RzueldCMr7eKKAau1GptxVspRqymdU3NZvmpVenuHo3hVkWxjo23lMpcd/ObqD6fKvZEr2xIAeGYhRzq0DtfFZfg6pIDfXNRVArRL5oUTBsZEk0idB1Bct2a7q03F3INfX77jQ9bLN5XN2DtoC+G28pVUTGw0VCV2VURcKyXdL5xjuNnaDZbqIZsW31md7KDqSVJX9hcg7bcUlEdSTWG4ispEto2sbxNjveqQOhDZ/dz9j3pF0fVNHKqYL+Q6nIEgFNwfVErreqyMiyjCKvuwjqrU2/PdjkjeWKH3xsMHLf82u3EqBr9HL5vtL5lTySJFEoCVJ84+oSh9IlHjs2sO0tffX5fBIolttkVs+nfw2tytuPjvdUjjN5YxnL9kSlq+sxNZshoMt84SvXOHu0vykh+5W0U7OZyn3rTQS25LOvn0d/l3D5xipXl7fur2+G6gVmsDfEvfksAKqqgSeC2H2J4K6XVd8PYmuX6FsR/GKqQCKiY1luhbA1VSaTN3ns2EDbx1udu7Fr19YJ2Nek3UijxHU9eqIBjo4nCIc04hGdpbUiA8kwsZBGtmCxvFYgHtdxPa+htampMpM3lkU1X1DFtES5+1bbUu0UN6fXKtZnJKgxPhglEQu0vXhcmJwjkzORJZl03kCRJSIBDccTlYu7QUDVJOB5Xtk3nWMwGeL5l69waybN4dE479xcqSlvT2VLzCzm+I0/eIMPnRnh4FC00utwNVsikzOQJInBZAjTcri7kCFUFZxstIDdmkkzv5IjoKmV+3ZnLoNhug1G3hrVxBrfxF+/ncyP6gX5+MEkc8sFHMehWLKJhDRiYZ1Th5O7bihsldj2OqtlM4IfTIaxbIe75QVN02QKJQtFkTZ1WdUfby92SQ8S+5q0m2mUpHIlLCfG8lqR0xN93J3PYjsuecMmqCscHolzcCTO8lqx8fbNq6QfU5OO7N2/a5uaTbOYKiIhin1ml3JMzWXoiwc4PdHX1jEWUwVMxxUPuiw69ICHY7kVSYCdvoDVJHBtysZwsxwcijEyIFK4FlbzBHRZZIQExeOWLZjky+Tk67G8/u5Cxf1RKFooikx/T4iQLr7jOF7FxeWjPmiVL1rIklTTbd2yHfJ1fs52UU2smTWp4UKxXZdMTdu3WJDRgQipdBHTdhntj1aqaXvjuy8V2i6xtePa2ksIgyDPweFYuU2ciarIfPyZY/syv3q3sK9Ju16jBCAc1LAdj2QsyCc+KkrGW1kL/hbY/3sqa3D6SB+zi/nKNvvwWBzLqbXY9tIKuTA5x8HhKNfvrZEriMIPVZZIZQxW0qW2fJ2DyTD35rMYlkNQF/PjuB6KLNdIAuwUPgnMzs4S7+mvFcwainF3XvTi9Mvb03mTRCSAhEQ0pNc0qf3ER09yc3qNhdUCiiRVemqqilyT9QIb3VrhoEq2YGJajgho2i6uR2Wx2CrayW7YrlpcvW/52HgPl0yb/oDGo0f7OqIQp13X1l6h2iDQVIXHjg209Y5t1sCi1Wf3QwPifU3a9Rola9kiy+kSrgt/8toUB4eiPH1ufNMtWCNLybcOq90huYJJtMrq2Y6FNZ8y21baW0wVGOmPMrOUxzAdXNdDU2UCusJwX7gtK/nJsyNcv5dibiVfzpYB2/GIhfWmkgA7QbrgMDxSJ5g1EKVkOvQmgrx5bYl4WCcS0ghoMqbtcGRcXEM1GR4d7yGoK6xmjMqieXQ8zmrGIFcwm1bSNfrecF+4RhZ3K2gnlXIzYt+MEOp9y5oqM9wXob8n1HwHeJ/hX
1u9a2sz5cbdxlYDoY3ey3emc5w+vdHI2Stdk73GvibtJ8+OiGrGooVpOSykikhAUFfwoCzfyqbSqg3lVMvWYSISaJqfulULa2o2zWuXcxwaj7T1gPiE4boeI/0RJKRK2mK7udUTowk++expPv+nV3jrxjKKJDOYDDE2EG0qCVA93q1aIImwUhMEBkFyR8d7+MRHT1aO+fVLc0iSxMnDvZWGzNVk6G+LJ0bjtfP/xCHuLuSaVtI9cXqI+ZWN39vuddb76xvJ6m6nJN4nonrf8ic7TH7Vv7Zq11Yr5cYHjUbvZVCXGr6X+1VTu6NJu9kLVf370YEwU7MZVtZKyDIEdRVJkoiFNXIFi99+4RJ3F3JNSaeRpeRbh5v149uqJsKFyTmCutT2A+ITxlZU8xphYjTB//UTH9wSCU/Npvn9ly6TzhmYtsO9+Sw37q217Fx+6kCIq4vruhL1i51PVtXX1igY3CqroFkl3d2F3JYKXlpZWtXjuDfrMja2UUdlOyXx1VWBnUwOu/UM3k+0amDR6rOduhhVo2NJu9kL9cTpIV5/d6FKHEnFsjzSWRNVlQhoKgFdJluwkCUwLXuD8E9NiqAqN7UOfUveV5CDdYGjrZatL6YKBLVan+xmD4hPGC+9NlVxK1Sr5jXydW5GzFshiJdem2J2OUckqBEJaViWy+xyjpdem+LTP3yu6feGkzqnTzcnzerxBco52M1cAZuNd7OXrd3rnJpN87kXJivCWeND0YrV34hYm8nqNltgNiuJ3y/YzjP4oNGqgUWrz3bqYlSNjiXtZluXF1+9TW88UJuLGw+g60o5aOWQK29PPSAcXA90vfTaFIbp1CwEy2tFJCSGqbWUTk/0bmqFbbVsfTAZ5mZ6teZ3rR6QidEEn/7h9lTzduKfqyf7t68vCY2WqiwMz/O4eie16XH8MQMbFjuo7VhSr0S4Fez0ZfPnKp0ziEe0GiW6RLT9lMq9HGOnYCvPYCeg0XtZ3cCi1Wc7dTGqRseSdjNran4lT7ZgEtAUwgGRi3tjeg3TtAAJTVNwHBfLstFUhUeP9FW++/o785w50lezEIwg/HT1rpBW/q6tVH6BeEDeuXavEkjbygPSjvW4Xf9cI7JP50UubECvejykpodoebwXvnaDgN7aXdDquP5cF0sWt+eyKLJEvNyAQFHkLbdwS0QDlTgBwPRCDk2Rt0SsrXaEsL8IoRk63ZXjo9F7Wd3AotVnO3UxqkbHknYzS0WSJGSJGiswn7IJBjROHkpy/V6aAjayItPXE6y0IPM1CXzdhVSmVKVABv/y77yv5mZ96S9vNlWeq7dMv/87jjKzmOVzL0yykinSFw/x3IcP1wRAJ0YTPHUqStpp7ievxlYDgc0WuVsz6U0zVhqRfV88SCpbIqRrldS5omFz5kjr4qJmi8dWO5bUz4VPjIoicXchi+U4BAM66bzI+W6Wu9toHv25Gh+McqW8e1BVibWcsWVibXa9W/Wvd7F7qF9gLl682PZn9wM6lrSbbV2GekPrubiaEBeybIdILMihkQSHRhKkMiUu317Fcb2aQNdIX5hvXVkkXxI/JyI6AU1BkqQaV8LUbJqFlQLfurJAvmSjKDKRoIph2qxmivz+S5cZ7gtXLKvf/t9vM7eSJxHRScYC5Itmw8yV4aTOc220G9uKq8MnpVszae7NZzkylqjoefuC9IPJUNPjNCL7UxNJvnl5EYC8YaMrIh3t2Q9PtBx7s8UDaBg7aMeqrSbGqRvLhAMa4YBw33zw0X5yBZO7C7kN39vM6vc7Fvnyoemyb3ur7prd8K930cVW0LEqf/7WpV5h7OwjAxwaiaNrCoWSja4p9CfCRIPr+cHJeJBDwzF6ooHKd584PYRli/LtQklY3amsQa5ocmQ8Uck79V90XZPIFixc18O2HYqGTTpv4gHpnFHT/mhuOY/rekRCOrIkEwnphANbawpcjXqlwmYtlqrbQR0/2EO+ZHHp1jIrGdHE9/rdNTzX49KtF
d65uSLcQKHWLbF0VeEDp4c4f2qIo2MJzp8aalvbuZnS2omDyRr1t5mlLJM3Vypd4jdT8atWpsuXRLxC00SVJTS32JvNIx6VsSRiASZG4zxyoIdPfezstvzr+7mlWBf7D21Z2v/sn/0zpqenkWWZcDjMv/23/5ZTp07t9diaWir1ObzzKwU8vBp/saLINS/h8y9fYbgvTF9PkK+XSUtTZMIhnWQsiOt5LKYKvPTqFDOLonu554EsA5KE48JAT5BM3iSo18pYGpaNqtaqjoWCKiuZ7XWtaTcVqX5rfvaRfm5Np7l+d40TB5OYtkM0rKGrigi23Ulx/FAPhTZaYm23wGCz4/ljvjWTZmE1X1Pu3k7OejSsEwmKwCFApLxQX7+7yr2FPJ/65T+rcU1taNaQKXFvIctqtsTjjwxU+mjuxHWxl8Gs/Vitd7/xXpyjtkj7M5/5DLGYqCr7yle+ws/+7M/ywgsv7OnAmqFR8ODHq6QcW+VVy5LE+GCsUupcMGygXBChyLx5fYlYWMPzXNGd3YVwQAEkYhGdVNZAryPogKZi2g4LqwUs2xWWoAJ9ie1ZW+1mHtSTUjIW5NtOit1FMh6kLy7+JlXpcdyeyXD+1NCm87ldAptPmVxdnqNQsllKFQkHVY6O99Qcb2JUdCwZTIa2nLMOMDoY4d1bK3geTIzFuXpnhXemUvTFA/TUuaZqmjVkSly5k0LCozcWrKSt7bT6ba+CWfu1Wu9+4r06R22Rtk/YALlcDklqM5Vgj9AorezJsyMVrZFGqH6Bx4eiXLm9imU7hINapdItoCvEywSiayoeUDJsioZDLCzEqPysg2qrvicWYHophyTZ6JpMybTJ2C4f/dDhbV1fu9Zbqya3gtREoE3TZDzPI9ugQnA3fK/VFZ+HR+M11Yj1x95qUUM1MRZSBQ4OxVhIFXn7xjJrWZNYSKW3vEBFQjpg8uKrt/nUx85W5vHeQhap3I74wHBsV6vf9sJ3vV+r9e4n3qtzJHme15Z23b/5N/+GV199Fc/z+J3f+R2OHTvW8juGYXDp0qUdD7Ie8ymT1y7nCOoSQU2iZHmUTI+nTkUZTjZWRav/zmrOZmnNJhFVGEnqnDoQ4utXc8iSx/SyhYdHvujiuB62C7GQgq5KfMeZGP1xjcv3iqQLDomwQq7ksJa3WU7bGJZHQJPoT6iM9wX4rscSNWOo/t6pA6FNx9vqs5vNw+V7RYqmGP9y2qZkuSiyxEivxvc90bvj+a8fm3++kL4eJvF/rp4DgD9/O932Zxudu/qa//pqHln26Imo6Ko4XslySBdcTo6HUGTA85haMIkGZQZ6NKJBsevwPI9M0eUHPrSz+dgLfPHrq8RDco2B1MnjfRB4L8zRmTNnCARqi7Tazh75pV/6JQC++MUv8iu/8iv8t//233Z04ma4ePEi58+f3/Qzz798hUPjkcrKmsqUuDWT5uW3SnzoTLKpX+v06XX/19mxjf6vtCP62w0Oukwv5EhlSxRLNsGAyt9433jN56v76f3m59/gkYlQjQqd34n9/Pn3AcIS/d+vvc6h8SGGR4T1fHXR4vTp5lu5dpqRVl/TWJVP73TV1vGRiZ37qn1Mzab5i6s3iPdEaq6j5MkEtRyjo6NN58BH78j62PydhNzm2OrvfXzmLiXTxvZU+mJhiqZNLpsnHNQ5cXS8ct1PDChoqryhafPYmN6w0rEe7TyXu4mry42bTLc73q3ifl/fbqDdOdqP17aZwbvllL8f+IEf4Od//udJpVIkk8kdD247qN5e+75KTZUqoj7N/FqttrEHh6J84SvXcByPeETn6FgCRZE3TbXzO75btsPYwLobqd4H3Uh7JJ03GnaF3wqaXdNe+VqbbUmXUkU8tXbT1iyLYidjq3etHDuY4K1rSxRLFq4nAoue53HyUJJ01qik8+m6QjIWYIT9UeyyX6v17ifeq3PUkrTz+TyZTIaREeEH/epXv0oikaCnp2evx9YUvi/XclwuXlmkZ
NpoqkwyHmxIhkFN5i/fmmVhRbQROz3Ry4/+TbES+8ShKTIr6dK64PomRRv1gkqu6zG3LPKEQwGVKb+L+7GBSirb1y/NUSgYFOxlxgejIMGduQyuy4au8Lvlj9sLX+utmTS5grneTWcoSiIaIBxUKRVqM3g2e4G2O7Z6P/6h4QQlw+beQp5U1kACHjveTzwS4MqdFLoqE4toZPIWEtKmeiedhP1arXc/8V6do5akXSwW+cmf/EmKxSKyLJNIJPjsZz/7QIORT54d4fdfuszscg7DtCuNUAsli7vzGaYXsxUyvHZ3lcmbK8gIQnWBN68tMbecp78nXCmS8Ytuzj7Sz9gxYTE3K9poJKhUNGymZjJYjks8rPPYsX4KhsUv/+7rmLaD54IsuZXUO0UGWZKIRddziKGzgyhTsyJVT0JYNr5mx8HhGEfHe0goXtsVn63O064OdaFokYgG+YnvPVPJTMkVTKZmM+iqjK6JdMeeaIDhPqEMuFnAupPQLc5pjffiHLUk7f7+fv7n//yf92MsbWNiNEFfIkg6Z5AvCgtqMBlCkSWu3V0jGlIrZHh3PodIuJZqcqnnVvJoqsIjB3oAsByXcEBleiFXUXtrltFw9U6KcECtpNE5roftuCysFjg8EufIWAIkuDufpVjuLh4N6yysmoTDNgFNYX6lQG88KKzuMjpdBe7C5JzQGl/IYVluOW3O5K3ry5QMh5Bc5Ic+2l4RTjNsRS610eLgk3o6ZxAri0H5zRY6dX53mmv8XsxVfi+jY8vYW8FyXL7t5CDprFG2XCVUVSJfNAkH1QoZFg0LD/CqFI80RSLveJj2epFMJKhhWHalyg5aVLaVD1c0bJbToojG8zw8vIol7edzO65HPBKgWCxSLDnIkujXeGg4Vik5b3m+DsBiqsDIQJRwUGN6UQRq80WLoK4yMZZg6k5uyy6eesJJZUs70qH2Sb1advXIeIJkLCiCzB02vzvNNa7//sxSll9/fo7BZKgiL9wl8IcLHVvG3gp++bCvH6FrCtm8RSSk15BhKKAhAVJVV17L8VAVqaZIZnwwStGwK8L8fu52I0nHEweTFA0b03JI5wyAchm7hoSErgr/uKbJyJWGuhAKyGiqjK4qREIqdxeyzCxmW56vU1A952cf6ScZC9IbDzKQDCNLEiFdblhu3wzVZfg+Yb15balmMYWt70AmRhN86mNnOXagh4nReCWvvhPnt13Jgna+n84Z3J0X+ej5cpn+C1+7salEQBf7D/vW0q72bSZigUqFmy+J6QfEDg5HmbxpgAe27eACluUw0hepKZLZSo++Zz88wXK6SCZnUjIddFUmHFI4MtbDzJKQ9wRhOfuCVKblUCg5GI6J7QgNkFzR4q0by6xmSpx9ZKDjgyj1/uS1nIEiS9t28TTKRImHdaZmMpViGdjeDmQzN8qDdCfUn/vm9BoTY7Xn3socVmfTTC/k0FWlUum7H+IkXWwd+4K0m71kzV7KscEYL702xevvzANw8lAPqYxJKmugKhKPHEgwOhAjlTF49/YqxZJFQFc5cTDJs09NtHzAJ0ZF78ULk3NYl+bQFLmyBY9FdG5Np4mENDzg2IEegkGV2zMZciWPZFxMuarIDCRDm1YOdhrq57wnGqA3Hti2i6dRZeThsThvX19umYXSDvE2cqO88uY0X/jKdWzHJRHRsWyHF76Wvy+lz41cIYupIkFdqWk+vJU5rM6m8Xs5WpZb0WXpVD9+F9tHx5N2O4GpRjBMhzNH+iovfl9iXbTI73vXE9OZXswiyxLHDvZUmrG28wL7567vd6gpMmOD0Q2tzc6fGkJyS6h6CMt2K0HMcFAjkzf3hTVUT5TvPzlYs6spmi7yFvJkG5XhB1SFc8cGNu3PuV0/8NRsmi985RoSkIjqWJbL3fksB4dj92X+GzaRHo5ydz4r0ia3kWtcs/sJqBSKFh6SCIbT+XGSLraOjift7egL+Ep9liMsjvHBaI2f0D/e5I3likUyu5jn7
CP9LY9dj1bZDP6/U7NpXvvWFEsrGSJBlXhEJxTQsGyXeETvKGuokRULbCDK199d4InTQ9xdyLGYKhDSGxciNcN2FQa3qzlxYXKuUjglsS6iNb2Y4175GvbSXdKwiXR/FMPc2DlpO8+fb20fHIqSiK378R/2YpP3GjqetLcqLjQ1m64o9fntyOolSf3j+dtJPFpqM2+GVrmivmUYC8ukCh4l08G0SiSiLrIsM1wuAuoEbMhGWMzy68/PYtliATwynqjJK7+7kKvkPV+8eHFLZLfd4ojtdtFeTBWIR/SanY7juswt5zkwGNtzpbhmAl9HxhI7yh2vfv7qF9xOj5N0sXV0PGlvtUHqhcm5ilKfL0laMm2++e4CiWiASFAT6n4BjVzBYmWthKpKJMuds5sdezMfaiv/qm8ZujGNRCLJpVsrOOUWXqcO96Io8p5kNWwn4FZtxaYyJe4u5JCAfFG0W/Mb4CZjwV3xl26nOGK7TXMHk2Esy6kUTGmqzHK6hKrIHBlL7HmR0/0ou34vFpu819DxpL3VB/3t60ssrhXIFSxURSKoyxRKDo7r0RMLoGsSl26uoCoyoYBC0bAwTCgYNjOL2YYNYjfzocJGt0Gzll6ZNTg4HCcW0bk3L8T4xwZje7Id367ftyYbYTGHXu4Sk8oaYhFUlUoB0l77S5stOu0+E/XfPzgUZX4lz8GhKCvpEpm8yOQ5+0hfTTB1r4J379Wy6y52Fx1P2lt50F95c5rb81lkCSIhlULJZi0nKhITUR1VllnNGOiajGm6KCGNob4IeGBYDqsZo2HLqc18qEBL/6pvGfpIxoJoisxj4YE9K6nert93QzZCQGQj9CWCmLaLpkrkiuae+0tbLTqtnolG36/2wWuawmPHBkhlSmhqbbnCXi5GXUu4i52i40l7K3jx1dskIppoxitLaIqM7YhCjd54qOLHXMsa9MQCfOjMukvClxFt9EK9fX2JxVSBouEQCqgcO5DgwHC8Yo218q/6lmHRdGsaDe9lgGi7ft8N2QglUVF68nAveEIwSiq7Edq1EnfqpoGNi04r8tusS3r1QumTO7y3lOK62L/o+IrIRlVzzaq8VjJFemIB+ntCKLKM5bjIEkiSEIsC0cHFcb0NLcOaWVe+9V4yHYK6jGk7vHVjhet3VhlMhttq7OpbhiFdrmlSvJcW13YbzvpjjYb1cq65xMHhGImoKGAaG4zyL//O+/jER09uqcy6nftXjepmvj62WnTSzverr/d+3ZsuutgJOt7S3so2vy8eIl80iYR0Qr3i0nJFE9f11ntClhq3DKu2rqotw4tXFnFdF9NxsWyZYEBBluDmbIaf+L4zAG1ZahOjCb7rscSGhgB7hZ0EvXwrdmo2zUuvTnH1bop7CzlOHEpumdB2w03jY7tFJ62+33VZdLGf0PGkvZVt/nMfPlxu6moSCqpoqoTrwcRIDEWRyeRNFEXik8+eYmww1rTE2dfKzhVN0jkDWRJFMJbtVrIoNFWpvOhPnB7ixVdvs5IpVrqBNyKB+ZTJ8y9fuS/l0zsNelX7hJ94dLhC+lvFrrhpdlp00qFuj646XxfbQceT9lYspqfPjQPUEOhHP3iIkuWymCrw2LGBmhej0Qvy+T+7wpXbq0iShOU4eB64Hli2Syys47ii7+Jwrzj/1Gya199dYGI0zqNH+ygULV5/d4GxwdiGwJjf+PZ+dY7eiQW5W01Tt2sx73TR6fRMjfvRSby7KDyc6HjSbmQxza3k6XdcfvPzb2x4GJ8+N14h761iajbNW9dXUGUJTZUpmRaKDI4ryuKjIUHYhunw3IcPMzWb5nMvTJLOGSSiAcYHo5XUsXpya9RurNHnOgXbtZDrsRtummZoRUqd7PbY607i92NR2A106sLSqeOCfRCIrA8UWbaLhMgM2Upgqx1cmJxDlkBRJCRJQlUUFEVGlUGSoWSKjubvOzEAwK8//wa3ZtYoGhaZgsmVOylSmVJDcltMFQhqtd1+6j83N
Zvm+Zev8Juff4PnX77yQCU1NVXmW1cW+fqlOSZvLJPKlnaktrfbgb7tBjg7BTsNtLbCTiVf7wc69R526rh8dLylDbUW0/MvX6npqr1dC6XRSrqYKjDQE2JprQgq6LqMVXBAkjg8FGN8KMrd+RzpnMlnX5gkElRFuzHbrVjb04s5NFXeQG6DyTA306s1v6smwU6yjKZm0yyvFSkYFqGAimHaTN5YZrQ/yo9/8NCWrZC9sHj32lLda+w00NoKu7VT2kt06j3s1HH52BekXQ3/YUxlS0wv5CoFINUPfys0Isjfe+ldUlmDfMnC8zwcxwVJIqgreEA4qApFuKEYK+kSxZJFNm8iyaILfECVKRQlPNdruP1/8uwI71y71zRj5aXX6kSuhtZFrvwA6W5v15od88LkHCN9EfoTIaYXxRxHghp9CeH62aw69H6hU0hpu/dlrwOle70o7AY65R7Wo1PH5aPj3SP1GEyGmVvOceX2KqblEA6KApCF1Xzb25f6raNlu8yv5LEsh6AuFPg8IKAp9PWE+fQPP87ZRwY4e7SfscEYC6t5LMfD9cBxPDwPSpbLWl5Y242s44nRBE+dijZ0E0zNpnnz2hKe5xEOqpWGuabtsJgqtL1d24p7ZbNj+lt3v0PNh86McHgsztW7KX7jD96oLC4Pctu93Tz03cROttF7nR/+5NkRcuXuNZ3aGakT7mEjdOq4fLS0tFOpFP/qX/0r7t69i67rHDp0iF/4hV+gt7f3foxvA548O8KvPz+HhPC7ZvMW6bxBUFP43AuTDcvQ61G/kk4v5ggFVGzH48TBJNOLORRZJhENVI73m59/Q1j4mRKrWQPPWz+eLIlgpYS86fmHkzrPna+txnv+5St8/dIcpuVgqAoBfb1h8NRMhvOnhppu1156dYpkPMhiqoCmiBZnfnf5Zu4V3zL8ut+8oYFQUr2VlsqUePfWCuGAhum4eJ7XUDjqRH+bN3EX0AkpfTvdRu9loLTTs2egM+7hfhqXj5aWtiRJ/KN/9I94+eWX+fKXv8yBAwf4tV/7tfsxtoaYGE0wmAwRDop2V+m8QSKi09cTZC1ntGXp+CtpKlti8sYyt+cyLK8VkWWpYl0+9fgoQ33hiiW8sFLgaxfvceHSHJ5bezzPE31+ZVnach50rmDieR6RkMpKRogY+Q2CMwWz4muvD1oZtsOb15cqVt6tmTSzy+sWsOW4zCzm+I0/WLe668/pNyFOZUrA+haw3kq7NZPG8+DIeIJoUKsRjoIHY4X4pGQ5Lq+/M8+lWysENKXhZ/cqwLvXwcSdYmJUSL7+ix9tv4L1fqJTq1E7dVw+WlraPT09fPCDH6z8fO7cOf7gD/5gTwdVj3q/YW88yFBvmKnZDJGghq4pmJZDQFMqRPWhMyNN/YtPnh3h91+6zOxyjnBARddkDNOhUBJEXq1g5xNdbyLA3YUMnkelRbCEKJGXZeH7VpX2vU3VVlo0pGNaDn1xiaJhoygiO+bcsYHyIrXRP3l7JkM8rFckVBdTBVzX4+LlRY4dTDC7lEdTZDyoWN0BXdlwTl2VmF7MkYyvX3O9lWY5LmeO9pGMBcGDK3dSDYWjVudutHX/djN9qr5DUf3uYi8DvPvBb9zpuB9pmdspauvkdFHJ86o3+pvDdV3+wT/4B3zkIx/hk5/8ZMvPG4bBpUuXdjTA+ZTJa5dzBHWJoCZRsjxSORs8SOUdQjo4rkTJFOZvQBNVkIcGA5RMj6dORRlObgxSfvmvV5hL2TiuhwSYtouqyAQ1ieFevfLdy/eKFE2XkC5z6U4By/bIFMW5FEm4RlwPFAWODut8/DsG27quL359lXhIRpIkciWHe0smikzDsTeag+uzJSaGdCRgetmiYDi4nvCzy5JEKCALWVVFoj+uMJeyWc3ZDCVU+uMqSNKm56zGn7+drswBQK5oV+bu5HiIUwdCDee42f3b7L5sBfXjAio/f9djibY/s13s5bV1sTvY7/fozJkzBAKBmt9tK
XvkF3/xFwmHw/zdv/t3d3ziZrh48SLnz5+v/Pz8y1fQgh5T8zmKhkUooHFwOM5AMsz0Qo61nEFvPEDRtFEkkQftuC4FW2MpXeCPLmQ5MpbgyFiiZoX9q+tv8MiREHL5O6lMiXsLQuP66OHxymevf/4NJkbE51YKy5iWQ7/rMr2QQZZlHMdDUyWOjCb4Jz/0+Karc/W1XV2+wsxSlpW0yFiJxTQsy0VRJI4e3rhLOH163VodS4bp7SuhKTJTsxkScYeo67GYKhBQZSzbxUVGDwQY7Q8zu1wgEtHJGwVkNchKQeLkoSQDA3BrOo3luPT2JkGC68suaafWGukdWbdWfYs23rOxLVj9vfPv36HxSI01miuYpJ1a//528FfX1++ND1+t0dd4aecz7aDRtcHG+9JJRRhbQfX1dXJhyVbx/MtXCOp5jh5eL7jbredvL7GZwds2aX/mM5/hzp07fPazn0WW71/SyaWby9ycXkPXFIIBFct2uDy1imm5fOpjZys6IfPLOQKagizL6JqCZbuUTAfbcZlfyRHQ5Zqu2/Vb22Q8iKZu1Liu/tz4UFSUuAMHhuIoskymYHLu2ADPfri2i3v1g6+pMngwv7jK1eUrPHl2hINDUb76zXuEAwqhoEqxZGNYDj/xN081rOis36752/50ziAW0cCGaFgjHNRYThWwbY+Th5KVRgYA/T1hHNdFwuPegljMxgajPHF6iNffXagR0qpvntyuvkr99X/tjXuEdBVJlnBdj0hQY3QwUmn9thO0457YaxdGtbjWhck5vvSXN3dMdA+SNDupXmA30E5R225jr+9fW6T967/+61y6dInPfe5z6Pr93VLMrxRQFBmtLKWqqQqO6zG/IibdK3uYNVXB9cA0bcJBjaJhQ1noKaCprKYNJkbjlcj+waEoX/jKtUqj175EsGHXmvrP9cYDrGZM+hIhjo73NLwh/oPvuC7TCznmVvKoisxYr7TuX9YUThzsYTUjLO1ISOfAUKDSCqsVfL/z516YZC1n0BMN8Pj4AMlYkJmlLHfns2iqTK5kockiMHnysMj48bvmPBYe4Hs+eGjTRsg+IX31m9OVuIFpOXz1m9Mb9FXqrz8a0gjpKgtlf3tAV1hZK3FnPsPZo30bvrPVB72dKP/9yARo1dloK9f1oEmz0wtLtopWRW27jftx/1qS9vXr1/nt3/5tDh8+zI/+6I8CMD4+zn/5L/9lVwbQCq4nsiEcx0UuW2uyJOF6Li+9NkUmZ2I5Ln2JEEXDJpM3KJRMTMsDySMe1dE0WRThlFdYX+Tp4HCMlbUSS2tFppdyHBiM8tKrUyAJgShNFfrX/ufSeRO1JPPxZ45tqm9yYXIOx3W5O58lkzcJBVQc1+PuksXhAy7RkMalWyscP9gDGaPyvVBQ3ZIFMDGa4FMfO1vjusgVTBRZ5uPPHK/0d5QkqZKeB9R0zWmnEfJLr00xu5wjEtREBajlMruc46XXpvj0D59reP3+iy9J4p7ZtkgVDAVUTMvj9lyWqdl0ZVFo90GvJ/fqbvCN0truR+pb05TM16bKmjXtv8APmjQ7vbBkq2hV1LbbuB/3ryVpHzt2jKtXr+7KybaDkb4oq+kClrNOpEEdQiGdN68tEQtphIOiJZaHh6rKFEq20NQOihLsmcUcjuuRzt7h1ERvzcSGg1q54k+0J7t0axnPgzNH+7g1naZgWJw50s/YsRgg/GGtrOHFVIGVtRK6quC6HpoqI0sSpgnTCzkePdqHaTlcurlCJKhVCmou3Vzh7CNbS3ZuRUq+pakp8oauOb7gVb5oYVoOyViw0izidjlHHODqnRSyDKmsgWU7aKpCMCBz9U6q6fX7L77reQQ0RVSZuh6qotAbD2JYbuVBrn/Qq9MVq7OAmrUQq88WaZQpsJeE14zoXn9nnjNH+rb0Aj9o0nzYMmL8ora0o9+XfPX7cf86vozd18iOhtZ9vwXDYSgZYrkcXJIQXddj6KLFWMBDVSQWVvIYllvuG
alh2A6Xb6+yljU4UybH6YUcuqqgqTIzS7nKwzmzmMdyXEIBtZISB+3dgMFkmKt3UiSietmd44IHmiKRL1kUihZBXcWwbPGFcv6OJK3/fx/tuA02I6VmpA5UfOIDPUGW1kTa4EAyhCJJZKuq50zLIZMz0DQxT47rspaxiEcbB5erX/xIUGNlrUQooKIqCkO9YUzLIRFZ31VUP+ipbIkrt1c3pCv617CZFfOg5E6bER2w5TzuB02anV5Ysh3UF7XtJe7H/et40m6kkf0jzxzmm1cWOTymc+3OGiCqIz2E7sdwbwTbcXFcQdiyLKEoMoPxAK7rsZgqUiha681ry5Y6CElWJEEepuVSNCxURUFWPOaWihUdjlfenObpc+MNX+Inz45wYXKOQtEiFtaEABWgq8I1kSta9CWC9MQDzC7mK8c8PBbHckQ6od815s3rS8TDOhNj8W2TUCNSf/7lK5XAo2k5DPaGRbVnusRgMsy54yJH/JU3p1leK2JaLrLhEAooYgcBBPWNxSxTs2lSmVJl3H09ATw8TMujNx7EtBxM22G4P1x5kKsf9Ot3U6JLui184JbtVvzrrayYByV36gdy/fH4RHfiYLLynPlo9QI/aNLcD5WUnYz7cf86nrShsUb23YUcuYJZyZDIlyw0RRZb/KDC3fkijivyqIO6gqbIhAIarueSLdqVLizhgEqhaOEh0ZcIiqwTS/w9EdEpGTaGafHOzVUCmoKqSGiqxO++eJnF1QK357MNLbuPP3OML3zlGrbtMpAMYdsu2XyRI+MJnn1qgguTc+QKJvGIxsJqgaVUkdnlHCcPJSvkMLOYIxYWltrVOylOHu6tCRDuBLdm0uQKJqlsiXzJJhENMNwXJpO3GBuM8uxTE7zy5jS/++JlJEkq53N75Es2bgDiYY2+RC2BzqdM/uKqILXHjvVzeybD1GyWibEYyykDw3JJRFSG+8Moslyx5P0HPZ0zmF8tiCIlSfi/q/3rrayYvd6abtYseLPdjD+Odl7gTiDNTi4s6XTcj/u3L0i7EfwXPRrSKh1jckWLewtZLk+tomsKiiThuB65ok3JdLBsF/DoSwTXt9tla/vgUJRgUOXdWyukcyY9UZ2grhKLeKxkiviOmIHeCCFdJV80+eNXb/OhM8MNLbtPfPTkhpZmCSXFc8+cq1zDZ//oLW7PZYTbQZEwTIerd1J8/k+v0JcIYTku4YCKVHYD+f7wnZLQ1GyahdU8EiLVUVVEWzXDtBnoCVcs+c+9MEk4oOC6OoZlY9seli26+Rw/mGRsMFZz3Mv3isR7RE52FNGz0yfZJ//fI03dPNWZMLIkocgSyVioHLR0Kv71VlZMu1vT7aZkbbYoNCO67bzAXdLc39jr+7dvSbt6RXv7+hKLqSIesJYpoagSmqoQCgirGsB2PDzPxbRcgrq4bD8fe2o2zUuvTXH1TgrX9UTQzANdU3hsLMErb80SDWk4jkeo/N1QUGV+dXPtifqbd/HixZrxG7aLpilIgKoq9PaEcB2Xd6dWeeLMMLmCxcpaiYCuEItoWCV3V/xjFybnODgU4+5CDssSbdQUWcKwXMYHo5Vc4/nVAoPJILIis7zmEAzIhIIKRcOhYNiksqWa7kHpgsPwSOP5aPUgT4wmGOoLMzYY4eqdFIos4SHuhe9fbzfo6p+3kWW7E7/3dvyVXQLuYrexb0nbx535NDdn0gR0hd54gJU1D8f0KMhCwL9oOtiOiO7pmsqZoz30xoMbXAzVGhZff3uW1ayBn5KvqjLFkg1ITC9m0VQFTZWIhvQt+yyrrbyFlQKJqIZpCQs2kzPRNYlM3uSr37iHiEpKuJ5HybRJxgK74h9bTBUYGYgSDmoV15IkS5RMG02VScREgY1lO6xlDXrjIfp7QmRyJoWSha4ple5B1cU4isyW56MaPimePNxb0UrXFLniX4etBV01RSagKzUFLzvxez9of3MXXcA+Ie1G21kQ/sIrt1MENAVFlljNGOi6jG25G
KZITVMUmYAG4aDOM08cBIRvttrFUP0ipzIlTNvF9TyK5QYLtm1TNB1CuoKiCHLLFFz+X08cIFO0SOcMVtKlSrf3jz9zvOF1VPt8+3tCKDLMLxeIhDUCmoJh2SynLaioBoqS9HzJQgIUReaJ00Nbttzq509TZApFi2Q8WMmKuXh5gf5EqIbMjo7GuXovTUATmTvRkIosSxwdi9NX91mAzJq3HivYQQf1apfX/EoBPBr2A22E6gpF/1jVC0uhZHN4NF7znXb93p3gb+6ii44n7VfenK6pSLQshxe+lq8o1gl3h4xhORRKNuXsOiQ8euM6M0sWRcMlGZcbKvj5pda9sSAHhmJML+aIhUWO92q6hO166JrQuFZkhaJhEwqoHBuPEA4HODnRxxe+ch3bcSsSsdXd2KsJ88btNQ6N9VdILhbWKv52XZUpmQ54IMsQCakUDQcQqYCxkEYspPF/vnmPt24si5x1Ra4UAjUjtEbugJV0CQ+PEdbJNVMwefxYbY74sUO9GJaLJEkbMncauYUcT9rVDup+RlC19d+uK6OZRb1UlTnkYyu7ga67o4sHjY4m7anZNL/30mVRkg7kSxb3FjJIkkS+ZNMTDSBLkC8K3Q4kQXASQrnu7nyWcFAlFtHQFJk3ry2iayK4dWQszvV7KUb6IvTGghRKFlfupMgXTQAM0yWgK5w4mOTK3RR4Hk+eHa2MrdpaP3u0r4YEcgWz0smlmjDfKjncmc8SDoquMJqmMtQbZCVtUDJdJCCgy3geyJKMIntImtDrdss+9vmVPJmcyeGxeE0hUDNCa0Rew4gClmh4veDg3PEBsQhUoVC0eOxYrRYLrGfu1BNfIqzsmNSqv//8y1fQlO31A20WNAwH1R3tBrro4kGjo0n7pVenSOcMgroQgioaFoWSg6II10G+ZFWsUaDiVvAF3TwgGg4QCigYlkMmb6EqDv2JIJM3VnAcl/nlvCjiKFpoMmSLtvDZSiLYeOVOCs91CejrU5XKlirqeLdm0hw72EO0atz+drueMCNBBVmiUqwTCQpGPnYgzNlH+pm8scz0YhbXcXE8D9txkSQPDwlZFpU3oYCK5bjMLuYJB4S1O7OYr1RS1veUrN5FVBcILa8VK2XsFybnSGUMFlbzHByKMTIQbRrEuzA5x62ZdMPPnjpQS5I7xU5S+JoFDX29mK6Lo308TKp/DwM6mrSv3k2JAg5JQpJEdgMSuC4EdZmi6Wz4jlf5H9ECLJM3UJUg2YJFWFdIFyyW1oqUDLtsLRcZSAbxPFjJmngemJaLpsmsZkqosshRHuyNMLOUrRGAOnO0l5U1gzevCu0O14NIUKM3HmBsMLaBdPoTKss5j3TOwPU8euMBltaKHBgOVH6eXc6haxoBXcYwbWxH5Jk/eqSXueUCcvkcfkNjJLEDgXVCq3aJVO8iTh5K1jQ7qP7c4dE4AV3m7nyOkulwdLynhsza+WyzJgibYTNC2El1WSX3O2+s68YoQjem6+JoHw9awKqLjeho0gaIhDUyeQtwcV3Pd1ijqgqK5YBMpYgGhBuhGqblki1Y2I5LJmejazJBXaFQsnBdkCSP5bUSAV2t9H2UZQnPdSkZwrJVFIUzR3p5+a/vspwu4joetuPw5rUV4hGNtWwJ03IYHYiQL5osrRUrwVKRheEyvZhjadUkGAiia0KIamwwxpNnRyqCR/7Pb11f5urdFL3xICAx3BdifqXAzFIO1xF55obtsrxWRFPligXtE1q1hX9gOFaWkxVyrJoqVyzo+p3A2ECMRCRANKxvcIm089nVLfb2nZpN89k/epv51TzFkoWHEFl68uwIzz410Xa2RjPif+L0UCUekigrOVbHG7pojQctYNXFRnQ0aZ84lGTyxjKJqF5OuROIhrSKNa0qMhIiWGaVU/vKvC7cJJJHvmghC0nrsiCShCLLuK6L4wjVQNNyqtqHCWve88R3JUniK9+4hySBKss4uHhI2LZDKis63tiOS9FwhMTqsJBYf
fLsCL/30rvMr+QJBVRkCWzXYzAe4vu/42jTh766+vOVN6f5/T+5QipbIqDJmJLLcrqEIgvxJdf1yBdNZpayKLLM6Yle/uf/uY7neURDQgP85OHeDXKsE6MJvvSXN9tyP0zNpvn6pbmaY1Y39N0uPv9nV5iaSwv1Rk/Md65g8ua1RQzT4WPf+UjLwOZmluDdhRxnj/Y3jDfspRbJw0Rme1FlutdzVn/8hGLu2rE7AR1N2s8+NSG2tjlDVMopIq0vGlIxHdFay3FcgrpI+bOK68QuyyDJEo7jIUmgSKL1locgB6FHImGX/+55HuGgSsGwsexacz1f7oU4NhAR30eQtyt5mJZDKKKjqwofOiOsaz9IOTGaqOQ3W46LrsqceqQfTZHbJo67Czlh6YREL8yZpSx5zwZPiGT1JULkiiaXp1IM9Yb4wleu4bkeui50r/2u6UfGEps2ePBR737ws3dW0yVURcisXrktcqk1RW7qqmjnxXz31iq6qmDabrlUXkaSXFIZo1Ku36ohbX265vRijnTO4HMvTBIpu3KqsZtl7e8F18FuCyDt9Zw1Ov470zlOn04/NPfk/rWg2QYmRhP8+LOnOH9qiKPjPXzX+YP82N88QW8ijGW7OK6H50HRcMhVEbYii4wLytrbuirTEw1wdEy8wCJNUOiIqDIEdAXPE6lzirRxHFK5D+RCqohlO9i2i+U4+Oa+47jo2vpUFooWmirz/MtXeOv6ErYrNMFLlmiKYNpO28SxmCpg2g5a+fieB+GAgiSJBcfPmAmWr8Hvd5kvz4emyNyaTpOrUu3zUd913W/S639uajbNF75yHQno7wlhOR6prIHjuE2P6X/P7/pe/WLWd0G3HU/sajyvUsjkeR6W7TJ5c5mvX5pr2Tnd74ieypS4cieFaTnEIhprORFYnVuuldHdCuFUd3H/87fTG8ZSvWDIklRumKxVMoceBrR6RraKvZ6zRscP6tJDdU86mrRrshVWCtycXmNupcD4UBRdlSskVQ9RmCK23EgQi+ioisxa3qInotPfEyQW1tFVGUWVywTn4XoutrvxWHY53mlawp3iemDZHiXTLQtIyUSCGq7nMbOU5eKVRV57a4aXv36btZzJvfkscyuiIYGvm62p7U29pspkciZ357MsrBbK7hAx3lBAIZUpkc2LKsWCYRMOaSLPPKCItmuuh+W4DS0ZPy86GtZZXisSDes1n7swOYftuOVUOY3BZAhNlckWrKbH9L/XzospZFpF+zMPD9t1sWyhia6pMpoiNyT7agwmwxSKVqWtmq4p2LZHTzQgSvXnc9sinPqFp2i6G8biLxjV2M8NAxqh1TOyVez1nDU6flCTHqp70rHukeqWXfMrOWRJYnmtwO25DJ7nsZIuAWVXhSLcIL5TQ1Hksg9ZIhxSUWQZx/U4Mhrnyu0UuZKFrspYjtAhiSZUsgWbXHGj78vXuPY2/EW4YPp7QpRMh1BQ4ZU3Z1jLGXiui+NByRJaJ5IErumQcSEea6yb3WwOlteKqOVmvYWSSb4kVhDV89A1mUzeqqQDRoIapuWgqTKW4XL2kf7K1naz0m9Yb4nlE+vEaILFVIFERMeyXHRNIRTQCOoqmbzJiYPJDT0RfbTrB/3hjzzCf/vSJQwLTNPBLkeRNUVidrnAcG8Ix3E3dSVVKwTGIlpF+vXIeIJENEDJdGry0dtN76sPwIV0eYPC4oPWvr5f2M1sm72es0bHL1keYw/RPelYS9t/aVbTBgFNJRLSMW2XUsmiZDo1WSJ2FWEDOK5HX0+Ikf4IuYKFLAkXyOxynmhYI6AKnWa/BVauaKPIEAvraKokgpHl/yRpnV81RSjQ+da9JEmEAhq6KnPtToqVdLHsrnExTGGyyz7pe2BYokDm9JG+im52qzkY6YvwvhODJGI6pi3876oCuq6wmCpRMm3wPEqmw/hQlGzBYGYpRypT4uLlBeZW8ptalpu5MgaTYfp6gpi2g2k5eHgUSha247CSLm34z
nxKLHq+9VuNRi/m0+fG+f5vP0I4oFWKolRFIqCrDCVDKLLMnfksN6fXmo7fzxIpmQ6357KksgajA5FK5evR8R4+8dGT/IsffV9L/3g12rEId9t18F7AXs9Zo+OXTO+huicda2n71prfpADAdT0MW1QqbgbX8UhlSvQlgriu8JnKHuiqwtJaEdP2ZVoRllhII19uWCAhEQkq6z7yqtXA9cSxFEUQt+fBcrpIvmDiuMKXLsvrC4htr/e1VBUJz/PWrd+4sAQ2C9jdnF4jX7QoGDaG6dKXCJDOmZjW+vj9a7CzJdK5Eo4rFq1omXCkhg6kdWyW0iWs2DwHh2OV/HTPg2RMFCzVf+fyvTTPQSVrJnMnhem46IpMPKrzyQaperfns3zozDDhkMafff0Ohu0QDijCBVOW0lUbBRqqjvHVb04TCqrkSyam7TA1m0GV5YaNmttFOxZhV4tk69jrOWt0/KdORVsefz9lAXUsafsvjb/l1zVF5E97UChZ5YyPxt/1gFxBWHoDyRBjgxHuLeTwXCFoJEkiQCnJYFkuti4I0Ci7ATyvsRXsuH7gsYqYnXU/uOOud4cHQfJqmbRdz0OVpYpl8T0fPNRUV8Xv4r2YKiLhEQ5qLK4WMCwHWfIqxAzrTR48Dy7dWGWkP8L7Tg6SjAVJZUrcmknzmd/7hkjRC6obOsjXuzJS2VIlPRDgidNDvHV9mYJhM9IXYWIszju3VmvK8UFYofdm14udKotF+SY1Wjw2uCCCKsW0zfxKAVWVcV1xN13PqzQBrkd10+GRvgiprMFa1mBmKc9P/ui3bfvFq88RL5oucoMc8W6hztax13O2mSRyI+y3LKCWpP2Zz3yGl19+mZmZGb785S9z/HhjBbvdxNRsmlS2xDffXcC0XQzTIhjQ0BThurBt0FUJ027M2h5+BoVFXzzAjXtpsgVhoeK5gEwwIGPaLqoqZFnVcuPbEwd6ePvGyqbjqz6rZZVFncr/U+/1cF3RyMD1IKgJvY/TE7289OoUf/XWLKoiEQ2qzK3kuTOfpS8e4KVXp0jGgxwcjnJ3PotVVh30XI9ap4NYGHrjAUJBjemFHJoqGu7KksiqURWJVLaErspkCyZBXalZGBZWCly9k6InGiAR1ZlZEgHT3liw0jg3oCt84NRQhVx7ogHyRbOmd6avPQKCjIf7wjxyoKcyzkb50fULhlhoijgeyK6HUraWPU9IGnz6b5/bcC+u3kkRDghBLxB+d8O0sV1vx2qI1Z3eHcfdIPPaiS90F1vHfisgakna3/3d380nP/lJfuzHfux+jKciX+q4LqoqYTlCZ8RzXVxZIRHTWU0bFY3sZtAUCdt2uTWbIRbRSER15pYLSEgoighWqp4oULFsl2//tjEef6Sfuws53ri23HKcsm/pSxK6KpV1QiQcz6v8TdNk+ntCRIMa8ajO+w/D6dMjlVZisgSe67KYKhIKamiqRLZg8vq78wR1FV0X0rKO6+I4IrjZCLmSXU4BdCmULMJBjbmVPJbtltMByzEBy2E1YzAxGuelV6cwLIfeeIBswSRfNLm7kCnng6scGIpVHt5Lt1Z44tHhyvnGh6JcnlqplOPXa4+0G4j0d1OWI1IhU9kSRdNFlSUS0QCOK/LgJQleeWuGZDzYmCzrjfhNPELNtsGNdj3zK+uL2+/80coG/fC9sMT20zb9YcH96KC+m2gZiHz/+9/PyMj9c+JfvlesBCDj4QAHBmMcGIxxcDjBhx8bJRLUOTISa+oaAUGomqrglXVKSoZLrmARj+gEAyqqIqOrCslYkOHeCN92YoBkLMg3ryy2Pc6ArhDQFYK6QjCglC1CMahQQOX0RC8fOX+A8YFoWdZV5vK9Ii+9OkU0pGE5LgGtXFgiiwVGVeRKW7SSYaPJEookUTQcNK25Hz9XMFnLmvT1BPEQ+eau65WtbZuemOiarqlCZCsc0rh6N0U0pDE2GOPU4V4iIR3bETnSJw8lAZi8sczkrRVSmRJzS+v5z
slYkEMjcRLRAFMzaaZmMxRKNpfvFSsBzHYCkU+eHWF+pcDkjWUM0yYcVMupmsKNlS+KoHPJEA0oGuV7nziYpGjYIlDqCZIvGjYnDiY3zFOzoKsgbJGPHo/oWLbL3YVcJXPlwuQcQV3a83zsdvPbu9hdtPu8dgo6zqftt6yqDkBWkw3ActqoKPf5+tnVkCQwy70MPUQA00M0o11eKxKL6Hzn+QMVkX3TcplZzDK9mOMb78y3HKNfzJIrWliW+FmRhZxqNKQx0h/lUx87Cwhp1qHeMOGQxtSdLFevL/HYsX4iQQ3H9UjnTWQJHNetFAz19wgrU1RRKpQMC9dtnm1iO2BYNmdGe0lEg0wvCIJVZImgrqLKYm22bJdIUKs8oP58+s0QPM8jkzfJ5A3emVrFc/1sDoWrd9cAKqp+iizz3IcP8/q7CzXXt1l38kb+4L5EkHTOwHLcsl86zPJaUcgLlMW63LL/3nLcDWl3z354guV0kUzOJG/Y6IrMcF+EZz88sWGemm2DX3z1ttBDj+pISGVXC6ykS5XFMqjVmu/1lthuWMj7bZv+sGC/dSSSPG8zm3UdH/nIR/jsZz+7JZ+2YRhcunRpSwP687fTFE2XhZSF5XhoilT5dyip4Tguk7dFap2LcFH7peWViyr/6/9OLqeTlYskkSUYSKiM9wnSKlkecysmRctDkSBveC3TqAMq2K7YqkgyqDKM9OlEAipPnYoynNQr1xLS1zc012dFgG8kqTK9bJEu2BiiWQ2aAgFNJhaS0VWZcFBmdsUkU3AraYj1xT8AigS6JhEOyBwe1EGSmFk2SeVEKmNAk8vFPB7RoEyu5OG4HgFNZqRXIxpUyJUc7i4arOVFOqWqiCCq43mEAwphDbJFkTkTCyl84HiEpbRN0RSum+WMQ8lyUWSJkaTKB47HuHyvSLrgkAgrnDoQYjipbxj7F7++SjwkV5oX54o279wtYdniHiiyWDgiQZlIQOHQoE6m6PIDH+qtHOOVd9J841qeouUR0iQ+cDzC049uJLn6c4GowLx0p0g0KGO7wq3m/75gejx6UGyb6++j//N3PZZgPmXy2uUcrueSyTsUDDEP334mRn9ca2seNhtf/fV2sfuYT5lt36f7iTNnzhAIBGp+d18s7UYnbob51AWuLuocirjcmcsgSxKaAoeGY5V2WzcXJlnLiZxgWQZdFkHJCtH6+dVlgnO9ddL2fy7ZMisFVWhaK4DiEZLLvm7dYS27uciM5QgXSSigkC/auEhkigo//Mwpnj43zitvTvPW7SUs2yYc1Dl+sAfVzXH2uFDxGx0ZQg8XuHh5ERDl3JIkUzBcSqZHLCKzVnRJxMIoikm2aKHIEo7p1iwoiixxaCSGUZapnVlzKZYsHBdUVSUW1oT7wBW1/St5ODqWoD8Z5t1bK8yn4VAownw6SyAYZDAIiysFHFe0aOtNBLEsVxSvxHQ+/PgohaLFStGi5Nn09Opcu7NGIBjA8/IEQmFm1iz+3unTPPfMOnG+8uY0/+vV25UOOM99+DBPnxvn6vKVSmpdKlNipZDDQ7SNUxWx2IQDKoblsJKzCQRVjownOH/+XOW4k/dW6e2JgASpjMErlwsUvSgf/54TNRZq9bl85Aomo8UMvYkAd+ez6KqCpol2bFENfuij5wH4nT96naHB/oolJhetik/7+Zev0JNUuTufJRgKEouL7796pcT4oM5wXz/DI+J7VxctTp9u7AtvNr6xMZ3z509u+Pxu4uLFi5w/f35Pz/Gg0O61PXcfxtIuNjN4O664ZjgpymTHBmIM90WJhHSGesOMDcYq2+4jo/GK9ey6YJZT8CRAV4WLQlNk1LIbWJWlDZZzOmeSL1pML+VYXisKX6hpkymYwp+sbp7f7AG2LcjddkQ+sSzD6+8u8L+/eo3fffEy4IlSctvhzWtLLKYtdFXh3PEBoa9xYwXPE5awIkk4rovrCSXAtZxJvmAyv1ognbcqZfOwLkMrAZGgcH8kY0EOD
cdYSRs4jkdQV+lLhJBl0dw2Xi7lt22XmzMZcgWTM0eEm+bdqRSRoMaZI/3oqkIiGiAa0lAVmZCuksmXKBo2uYLFOzdXsGy3ktt+eyZTKR+XJAkJiXhYr/H3vvLmNL/74mXyRZNkTGSe/O6Ll3nlzelKMcTMYpbLt1fJF00CukIyJsYgyxLpvMj80VWZgmGxvFas+HlffPU24YBIB13NCGGxgCZzeWp1gz+4WWHHcx8+jCLLHByOoaky6ZyJh+j16aePPXUq2rSUezFVYGWthK4qlabH4aBGvmiTzhlt+8K7xTpdtIOWlvZ/+A//gT/90z9leXmZv//3/z49PT28+OKLezqoZnmcz798hWhIY7gvws2ZNbIFW/itPd+nLOzpwWQY1/NwHJd7CzncJh6gVKaE60G6/GehWULlmJtBQljbwm0hyuhLpkMmb/D5r8yXt/Uylu2hqzKqIjG9ZDA8aFUWH0mSCOoKecOm2mXtF+7YwiEvxkXZHQQoksiA8VwPJAnTdjkyluDWTJqgrjCQDJHNW2TyBiXTQVFEezZFkggFVSzb5Z1bqzx5doRvOznIX35rmm87OYgsScLX7ohGDSXTpmBYZAvCyu9NBIRy4J0Uxw/1EA6qLKeLxEIaHh6W4yHZDicOJWv8vT6xRkLlDj4hHTB58dXb/PKnn+Zj3/kIn3thEsf1iEUDjA5EK6mHxaw4t+249CfDHBvvQVPXVRJXMkWSsQBLqVJ5XmRkRaJk2Bv8360KO1589Tar2VLNTsDHcFLnuSbW7mAyzNU7KRLRdQvZKhdWmXZto47NshK6xTpdtIOWpP1zP/dz/NzP/dz9GEtL3JpJkyuYFAwbSVaQZVswGRLhoIpp2rjA3FIWF9Hn0aM5Afv+Yd914iHyrEU5tSDlZnCriN51RfPZkK5w/d4apuXQE9PL/nYXF0Guhu0R0BT+5/+5juO4FEs2judtGJ+/a/Cqf5bK26Iyg0dCGoYpMiaOHxJElimYxMMaSynRHEFTZfJFC8PyiIU1ApqK44qGwGZZcVBTZPrioUqz2/HBKFfupIiEVEzLZWGlIIS5JIls3iIeFUJbt2cynD81RG8iyLtTqyytFbEsm+GQTNGwGRuIVa7HJ9ZqhIIqK5liJYC3kinSGwtWtLpjEZ1781nmVwscGU1wYDhGMhYsz/16f86+eIh8UTSa8CsnbdslFNAaEmQjg2BqNs3r7y4wMRqvdIHfSrMEv31ZoSjSLS3bxbRFo2ddrc36aZWV0C3W6aIVOs490gyvvDnNpVvLXLubYmYxRyZrYDvCp+u7MhRFRsajYLoYltsymOij/nNlA7dFAfh6A2HPE+dezZTIlzMzcgUL1/MIqAogLEXHFQtPNm8IiVN3I2E3G5OPeFgnqCn8rScn+MCpIQ4Nx3Ecj2hY59yxAZQ6kpBkqdzhXSIe0dezUnSZtZxRcQ/42/JELMDBIZGmONIv/MQDPUGCuoplOyynChRN0b39ybMjPP5IP4bpkIwF6I0puK7L1TtrHBxa75rZFw/VNLEAKJZswkFN5KwvZbFtl9tzGS5MznF3PkMyFuTIWIKjYz0cGUtUCBtqie+5Dx+mYDjYjkOuaJLOGeSLFqMD4bbTtnYqFzoxmuDjzxzDAzJ5E02VOTgUZaA3XMnr7ro7utgtdFzKXyNMzab5nS9domQ45fS9dUqzHA/V8zA8cD23XC7tbcgg2SqcsstlM10nRRKfCwYUNAWKJXG2nphKNm+RzpkVLW4QK+T8Sp5cySJYLsv3y9E3gxCJElKzjiv81bmCiaLIfOpjZ2v6OL55fYlISCWdNbEcBwmxqNllDfFEJEA6b6BIEj3RQMU3OzYYq2zLxwZj/PB3H+fC5FwlwBl1RTqgYdrkCzZPnBlmYjTBhck5ThzsYTVjsFQskojrHBgSnXt8PPfhw2UfvylK1Us2BcPh6FgEx3W5O58lFFAxy/nlkzeXURSpJq0QGqdjPX1unMXVAn/41Rtla1umN
xEU6XqqskHvpBF2o7ji6XPjNXNYrXzYdXd0sZvYF6QturKbKKLHL3VuQiRJQlWhZFDxX3v1PoaNP7ZEUFcqUqj1CGgywYAKnotpe8K/jkj9K5kesixju26NJe0iur57QMFwWlryPiJBVcitlkltpD9CNKxXCKA6R1hVJJbXhK9ekWUCmuiz5iKswHhE5/CIyMSpDqbVb8v9FmMlw6JgOCSiAQaTIUzbIVuwePYpkQe9mCowMhBlbDDG7KzJ6Gg/rudxa0Y0EPDJ6tknD/GNy4uV7JEfeeYw37yyyPxyvhLA01SZTN4kX7JZTRuVBameDOuJr2S5fNf58UplZb5koSmiGrUdgtwtudBmro0uSXexm+hY0q4moouXF0QaH75Oai312o6H7XhomoRplUl748dQFFESb9sumioqGJvplwgtkca0GtBkHhnvYTFVoGgIFTpHEdkjhuliG3bD74FYTGRZWOjtLCBy+b98yUYvN1vIFS1SZUGnerGbd24Jf2oooBLSFSxHVAkeP9jD6Yl+MeaSxcxynl/+H683DLr5x9QUGTWso6kO6ZzoWKPKEpGgWtHg0FS54g/3MbecY2E1z2AyVKnsmy9aNbsCEK3UqgN4oYCGIssMJmWG+sJNF5R6+JayLEk1fu/ltWIbM7z/iiu6eG+jI0m7nojcctqE44qGutUWs1zlfmjFgpIkcWAwSqFkYZguqiqzmGr+Ytvl1DbXcTBsFxCCVT2xALGoxpU7Qgo1WC5nNxtVvjQcR1sfA4SFrKoyYz0hltMliqZNybT4y5U8F96eQ5IhoCkM9IQZH4pSMhwCqoxhOFiWi6LIxCOio/0nPnqykn4XDig16XdAxaL9+qW5sqUaZHZZuAgUSWItZ4AHJw72VMh4ea2IhMQw6415787nOFilXWLZLjOLOX7jD97gQ2dGODgU5e5CTvj3Cyau45CMhyoBvOG+8Jas3J1aytvJ2thuBWSz73U1R7poFx1J2vXlvAPJEDOLdkX6tJqbq7M4mlnNPmRJolASmhbBgIa8SRhWkqCvJ0SxZGEic3hQZEPICPnSyRur5bxkYclalrOhMrMRPDa6d1phNWuyljdRZFlY6bbI5y4aDpT93bIkkSuKfGbbcZEVmUQ58Fgo2qiKCFC++OptFFlUHqayBpqqoKkS/+ur1zk4HCca0vA8McuzywXiEZU78zlc18OyHPp7QqxmTdJZA4BMziRXNCkaNpZhMzamM5gMMTIgApF+70ZNFXriM0tZvvrNe5w42MPh0Ti2I1IIHVd0ARruC1MwbFLZEr/5+TfaIrZmlvLpid4aF81mRLiVrI3tSnn6Gie2IzJLLFtI8fopoPtFGrSLB4uOJO2/vjRXUanTVIXxgTDJRJC1rIEsSxSKNooiNEls28Mpa1VvRtqSJBrwprLCr/vIeIJvbaLmJ5db1kiIgF1IV0llS4SDKtmiJcroqxrS+p3f24Ffmr8VuC4V/ZHanYbYeaxmSkRCKkVDLB5SWfEvoKm4iN0AiECoaTmVPoyO62IUXFJZg7HBKK+9PctKxqicV5VhdCBGUFdYTBVIxkWF5PV7aziuh6YKrY6J0Th3pouV4Jtv+fq9G0umTdFwePv6Mh4ws5RnbDDG8YO9RMoCYUN9YTRFFgJZdYp6mxEbiOrUS7eEpO6Jg8k9JcLtaISIJsnXkIBEVLRwuzuf5eBwjBdfvc3EaLyrOdJFW+i4lL9X3klzZ16kgMkSlEyb69MZ8gWTvkSQwWSYaFgTjXQdkcqGJxrtbgZNkZFk4d5IRHUyeVG00Qy26zG7nGc1azC3nGd2OUsmb7KSLmHbHhIeWrks0/Nom7BlqZyGtwPU7DRc4c83LJfVjFmT210oOZRMi1hIoy8hsiMkSSpresiAyNDwEON/48piDWGLeYDZ5SzZoklfIoRliXjASlpodEtIREN6Tdfr6sq+XMmiaFispEuEggpeeQ4WVgukMmUdlv4oQ31h/sWPvo9kPMhwX3hD+
t2Lr95umJb30qtTFf/7E48Oc+ZIH4bl8NaN5T3r+r2d5rQXJudwHNHQwhel0lWFlbUSK5niQ98guIvdQ8eR9jeu5UXGgyRhO17FkswUbFbSJQoli0REp2QJP7NhicqzVpRp2qLTd8lwyOZN8iWLUKC9jYbrwWrGxHUdsgUTx/UwLA/DFp3Z23GLQLWF7LWdOdLymNU+/fI5NFW4TDzAtDxGByIcHe8BYKg3hOu6ZRVE8a/ruiiKtN5irQ62A+GgxrGDPZi2Q6FkCTeV52HaojdlKlNiIWXytTfucWFyjidOi6YJEpDOWyiK6CpvWULNUNdkphdFWmC1/7kZITYjNl9itp6cr95J7RkRbkfKczFVqMi++tA0mXTerBQ3beV4Xbx30XGkXbQ8QkFVNNAtl6X7nKQqMobpkM6b2FUPf6uGCNWQZdHGq2RYPHIgviXyzJfclhZ9s+MFdblSnm5Z3paCkZuhOqXQbz3mOL4sqwwSzCwVKsUuZx8Z4MShJJIktDrSOZNiOWi5GQxTpP0dHI7hITJIDMtBUWTevr7EK2/PMr9mUShaXLy8wFe/Oc3BoSgj/WFyRZOSKYheloUvXlMlciVrQ8GJpsp868oiX780x+SNZVLZEoWi1ZTYgIbkXP336s/vBhFuRyNkMBmmLxHEtN2K9nehaKEqck1xU7cIp4tW6DjSDmkSpiW6pAc0mYCmVsShApookbZsZ9Oil83guMIynVspcuX2GgPJYOsvbQF+yTmUy+Fl4cOWJQl/yJLEpkHQ7ULXxEH9whTH8dBUhRMHeyrFLk+eHcEwXYpFk7J0Sdla33wxCmgKy2tFxgZi/NQn3seP/61TGJaL67gUSjalkoVlQzQi/LG359L8/p9cYTlVIhpQkZAolBxUVaG/JygWLqgRX5qaTbO8VqRgWKiKhGHaXLy8wMUri0gSTN5cZmYxW0NsJw4lG5LziYPJPSNCP9ukmYBUIzx5dgRFEZWSfj66B3z8mWM8fW58y8fbDUzNilz63/z8G/z52+lus4V9go4LRH7geIS/ulzA8zwcx8Px1jWwM+UO3e0pgLdGKmsQC2ub9pvcLiSgLxHEcV1Kpotb1okNaDK24+J5ErLk0UZBZM0xdU3GcVx0XcE0RdMEuVzK77jguQ6aJqNpwof/+CN9jAxEa9wCq5kSlitIU5aoNEx2m+xYIkGV3niQ7/+OoxUiuTA5x4lDPaymRRNdRZHRJQ/TdElGlbKAllDs6+8Ns7haFK6VoomMhiRL/Mu/876aas7PvTBJOmeI4KnrUTRsSqZDKKBx+kgfc8s57s7nKJkOR8d7KnnUjTJH/ADlXnb93sqxqtMKNU3hsWMDNa3OXqySrX3/ycE9Iezq7BtNEXGJ4b4w/T0hptKr3YyVfYKOI+2nH01wYHycF/7yFpmcCXgEdRlFljEsZ9cI20e2YLHDuOAGqLLwrb7/9BC3ptNY5W1BsVBAD4ZYTRcxLUG4UptFNiAs4v5kiKXVQrmbi+9m8Rjpj5LOGuRLdtld5BIJKmTyFnNLOcbKKYsXJudQFDGfqio0vF1PtBlTygqHbp3L5dzxAXrjwZpshsVUgZH+KGMDMaYXs6JBgy12QSC61MsSRIIambwhXCMSuJ4oxw9VaaT4KXRrOYN4RMO2RdFTKKASCWmV2MbYQIxEJEA0rPOJj64r7m2WY91JBNSI6DfLm68ueNop6tMUv3VlkXzJoq8niCxJhHR5gypiF52JjiNtgB/8yHFW00W+/Npt0ePRdFFk0Q5qNbN5cwK5LjDXDmRZQkHoemz1u82Od+pwkkLJwrRdBpMhCiUbyZFYyRkoikxYkcjk2xO1kiQYTIriE8sPvHoIcSjHQZIkCiWb0YEYHzg1yEsX7hAOKISCKqvpQrm7vcHzL1/h7etLGKZdbhYsgpbieKKyMxzSsW0HTV3voDKzlBMdXm6tVIixuhKyLxFkMSVahIV1WfhsgZ6oUO27MJlFU2QCqggwxyM6B4djF
YLwU+h6okL61W/3tZgqMNATEo0qymhXuc/HXhatbPfY1Zb1WtYkGlKJhER2T7Vs7W6Sdn2aouW4hAMq0wu5ShVpN2Nlf6AjSft3XniLL796u+IW8Tzhi15rQdiwdcKG9UCmIkvIbJ+4JUm4Eh4/NsDCaoF7C1ko+1MHkiFWcy6yLJoAFw0bSUJ0crc3P6emyoQCGkfHw0zeWEFVZGQ/UOspRIIKuq7xqY+drXFbrGaK5A2HRFTHdT1mFrPcns8SCsjoioRheZi2iyKtl+3LkoftuBimgyxLDCRDuJ7H5M0VRvoiDSshHznQQyZvkrNt9HI++IFB0cBCU2QiQZWSYWPYHsO9QYb7Iqyslbh+bw0QyoeHR+OMD0V569oihiWCdQXDplCyODAUJ5UtkYwFWwYTN3MB7GaudrXlqigSFy8v8OcX73Hu+ADPPjXR9Pj1lvVSqshqxiaoqySiQr7Wl63dTdSLYkWCGoZlky+txwO6GSv7Ax1J2n/6jWmRQ+1Ll5Zz5VrFHiXKQcA2PtsI7SjuNYJfkahpMv/xnz3NzGKW3/rDt9A10bHdtkUhRTIqY9ii5F1Coi8eFGTnWJsGAm3bZSVdZC1bwrJEMC8W1jEtF00VehuRkMbEaIL/359cJl/WJkllRdNgy3ZJZUtIkkQiopEv2Qz2RVhNlyiUbBwPxgYinD8+wFcuziBJErIsOtpn8mY5D100WkhnDaYXc6RzBrqm0JcIYjse508NsbS4SE9v7waFu4CuEg3pHBlLgARXbq8iAb2xILmCycJqnoAuEw5oSJKM7YiO6oosmhqkcyVee3uWoK4SDmp8/JljwEZL9+BQtKagpt4F0G4RTDvWs2+5WrbLtTtr6KpMLKRxazq96cJQ3xBCUxVs22ElXaqQdrFk0xcPbfjuTlBf6j8+GOXSrWXCAQ3X8yiaLnJXb2VfoCNJ27QcZLmx8FMzpb5KZ5cdujdaybE2goREKKigKjITowk+98IkkiQyLoQ0qrA+M3mLU0f6CAdVrt9bIxrWWMuaIpNks3N6YmGQEAuL5LqcPdpPMi62tf7LODWbZjFVxDQt8iUH23aQJCgaYNoyklSiJxZAkiTi4QCKLPovRsM6P/cPPsjzL19hrD8CCJ2XhdUCuYKN64kAqigwEsJVsYjG8pooUR9Mhjg63sPhnhjPPfO+mqFXl5lrqsytmXRFPPdAWZ/k4FCMu/OicjISUimZQmt7MBkmXzJJZYyKwJdPzMCGiscvfOU6vTGdpVSRfMkilTGIhdt3AWxWnl4P33J95+ZKpd2ah0ehtN4x51tXFvjjV2+TK5pEQzrf++HDGxpC9MV1FlaLGJaN67kV2dofeeZw6wdvC6gv9ddU0bW+vyfE8lqRkC53g5D7BB1J2rqmYJjCfaBIEm5VdxdZlggFlJpCEFkSZcxFY4uiHg2wVcKWJYiFhbUy0idyoe8upCkUbXJFcQ2aItp8lSyPI2MJDg5FuXRzhRurBZENw+a+eFkROeqibFyQw62ZNL35EtfurFE0bEYHotyZS3NwOMpb15dBKpf5OyLIGAvrrGYM1rIG4aBW6RBTbX0tpgpMjMW5eidF0bAxTAdVEZk1iizxzq0UAVVC11Vcz8N1hTsoX06te2c6x+nT6Q0vfnXmxGq2RG8syIGhWGXRGRmIUjIdltaKeJ6H63oMJEOEAior6YKQvFWENG04KAinUel3oWSRypQY6g0TDqikZYO1rFlzTzdzAWxWnn6iv/azvuWaL1mEy0ValuUSCYqOOX99aY6Z5Ty6KhMOioXo+T+7RiIs9MR9S7snFqJk2OQNISXgy9bupj8bGotifbIqYHvx4sUuYe8TdCRp/80PjPOlV24jeSDLXsW6Hk4GSeWE+H9QkyiVZVjDQYXeeJjVTIlcXc7uXkP420Vl5nMfPswrb06TL9gVw9nzhJCVnbfQFSqWYn9PkJV0EbncSKERYUuAqoIsyRUdluG+ACXLYTVT5O5CloCuMNIXx
nVc3ri6xLkT/UTLrcgcxwNX+KbzRVO0p5dkNFXizauLBDQFw3Y5d3yAqdl0hYhOHu7lL96YAYSUbTwsMkwcV1SBBgMi8KmqIq2wYNg1ZezNNKX93zdS5Ds63sPR8R5yBZOp2QzTi1lu5TJVE20RDemV/pQrmSKPHu2rOYfrejiuWwlkJmMB5lcLlAwbt1zMspnk6mbNEOpJu7J7UGRMWzSbMG2HI+MJCkWLuZUCuioT1MUrFtRlwKZkukiyQ3VDCEVV+f9+/6ldJ+p6dFuZPRzouOIagH/0scc5cSCBLAuhJFmWOHEgwQceHeH9p4Y4MBgjmQjRE9UZ7g0xkIwQj+g8dqwPXRWX1CyNTzTcbS/Hzy/q2QweIr3t/MlBnj43zouv3kbXlYp/3f++60FfXOXFV29z494at+eyonhIV2vGKgHxsEoiqhOLCB9vJKQxPhhjqDeMqshEg+L3I31hDgzGCAd1IiGdgK5w7c4ayViQ3niQwyNxhvvCSEi4LkTDAR4/1k8ooJLJm2QLFo8f60dTZF742g0ODkXJFUUDAUWWiIY0ArrCQG8YTVUquwFFER3eg5rCWsaoZHcENanieqgu3Hj+5SuVwo3Nqgn9v61miqzlaoPOpu2RKZhkcgZXp1YbVkiKoiKpnL3iocjCzxwKqm0VrWylPN23XI+MJcgWxHdOHEqiKTK5oohR+MVOPnRNxnJcfuK5U0RCOqmsQSSk8xPP7T1hd/HwoC1Le2pqip/5mZ9hbW2Nnp4ePvOZz3D48OE9Hdinf+TbKv7F6qKJj3/PiZqCjPrPjA1EuT2fEdV+rGtzyGV3waHhOLqmcO3u6qYFNYpMZdFwXFHZWC+XLSGyP1RV4ePfcwIQTWx1TUGVJYqmaI8mS+L8BdNDKuciz63kUGSZkC4TDgZZyxlIkodtg6oo9MQDWLZDabWIUq4ORIKiYTPcF2E1axAK1t6+vniAuZUCvfEAd+azWLZTrjrUCAVUTh3uJRkPkslbjA1GsR2P3qqA192FXGULrSqigrO/J0RIV/E8j3C5k3syFmA1I3pN2o7H+KBwC5Usj7FkuKV06WZ51R/7zkf4qd/4i4axC9sRlvRCqshH3n+A2/NZYL2oJhTQGBuIYFoe+ZJFJKhx/FCSsYFYTV53M2zWDGF17saGz0+MJvj03z5XE7zsjYuOQpdurlAy7bKFLWBaLtGQztPnxrsk3cW20RZp/7t/9+/4xCc+wfd///fzpS99iZ//+Z/n937v9/Z0YO0I0zf6zL/8xPv4P399m5cu3MFDELWuClfKxHCUoulg2Q7xiM5y2mh6/kPDMVYyBobp8OjBBNGwzmtvz1dJsYpAY6SsP+2PSzSxzaKpCsGAsEAt2xGVfbq8nousqtiOIwSvJBGsVBUJJSgxOhAVgla6xvtPxohHA1y9kwLgzJF+nv3wBJ97YZJ80az4RkG4YkbLrb9KpkOhZBMJaeSLFr3xQMWHnC9Z5Q40G/Of/S30waEov/viZVxHVHN6CPfDiUM9uK5EQFcw8g598QCJmJBQLZlepTP5ZtKlm23TJ0YTeEBAF7K7dp3fSJFlhpIBSpa74d5//JljleBkNem2W7q+2TO3uok4YKPr+d4PH+b5P7sG2OiajGmJBg8//F2H2xpLF100Q0vSXllZ4d133+W///f/DsD3fu/38ou/+Iusrq7S29u7p4NrxwfX6DP/6GOPc3Kir6Y0+AOnBilZLjen1yiUbAZ7I0SCOe4s5MoBT+GGsV2PgZ4QRdPhwGCsphXX//VfXtlAlPU/P/fhw/y3L13CsBx0RFDNLDfVHe5RGByKcuX2KomozmKqgOs55TZiCtmizaOHkxw71FtTjt1oDpo1y/2J545vsOJ8yzdXMEXmgCKTL1kcPdBT+Uy9G8A/hj+H/T0hAqrMoeFEhRDnVvKV7IPBZJinTkWZGE3wpb+8uaNGubom0uCU8v2oRjwa4PihnpoFphqt+km2wm75fX/wI8cBarJHf
vi7Dld+30UX20VL0p6bm2NoaAil3PlEURQGBweZm5vbc9LeCdrdgv7vr17bkJbV7MVqRpTV6Vn+Of/XV6+zsFpAVRTed7JfZG+sLpOMBTl5uJfphVy5gMRBVSWG+6J8X3lhaYdw6kl1s6yDegvyyFhCdCtX5E0DdPVzWJ/D/Mm68V28eBHYefuvv/mBcb786m2Qa1M8++IBTh5Koqky0bje8LudFGz7wY8c75J0F7sOyfM2z2y+dOkS//pf/2tefPHFyu+effZZfvVXf5VHH31004MbhsGlS5d2Z6Qdgkt38nzjWp5s0SEWUvjA8QhnDkVafm8+ZfLa5RxBXapkvpRMj6dORRlONiagvcR8yuTyvSLpgkMirHDqQGjXxrEb1/on31jhzakihi3iAQNxhRPjoQc+b110cT9x5swZAoFAze9aWtojIyMsLCzgOA6KouA4DouLi4yMtC9x2ejEzXDx4kXOnz/f9rHvN86fh7+37W9fIO0kWUwVGOuA5q3P7fLxqu/d6dPrVvl2rrX6Eai28B/UvHX6c7lTPMzXtx+vbTODtyVp9/X1cerUKf74j/+Y7//+7+eP//iPOXXqVEe7RjoVw0md5863zmJ4GLCbbopOcnl00cWDRlvZI//+3/97fuZnfob/+l//K/F4nM985jN7Pa4uuuiiiy4aoC3SPnr0KH/4h3+412PpoosuuuiiBTqyIrKLLrrooovG6JJ2F1100cU+wp4KRvnZhKbZunlBNQyjeaXifsfDfG3wcF/fw3xt8HBf3367Np8zG2Vkt8zT3gmy2SzXrl3bq8N30UUXXTzUOH78OLFYrOZ3e0raruuSz+fRNK3Sb7CLLrrooovN4XkelmURiUSQ5Vov9p6Sdhdd/P/bu3+Q5No4jOMXDRINaUs0OERBES1BQUtLp8EG07MJh5ykBqmhCCpqCCvIhihIaKgxmgIJaiiQIINCqOVQREhFYA2SYP8gyPsZIunhHd7e11tuj1yfybN9heOPg+L9IyK5+EMkEZGFcGgTEVkIhzYRkYVwaBMRWQiHNhGRhXBoExFZCIc2EZGFlMzQvrm5gc/ng8vlgs/nw+3treokacLhMDRNQ3Nzc9n9QzSTyWBgYAAulwt9fX0YGhrC09OT6ixpgsEgPB4PdF2HYRi4vLxUnVQUq6urZXd/apqG3t5eeL1eeL1eHB0dqU6SQ5QIv98votGoEEKIaDQq/H6/4iJ5EomESKVSoru7W1xdXanOkSqTyYiTk5P89cLCgpicnFRYJFc2m82/Pjg4ELquK6wpDtM0RSAQKLv7s9zez7eSeNL+3vjudrsBfG18v7i4KJsnto6Ojv+0ns1KHA4HOjs789dtbW1IpVIKi+T6ee7Dy8tL2R3H8PHxgVAohJmZGdUp9EtFPeXvt6y68Z3+lsvlsLW1BU3TVKdINTU1hePjYwghsL6+rjpHqpWVFXg8HjidTtUpRTE2NgYhBNrb2zE6Oorq6mrVSQUriSdtKg+zs7OoqqpCf3+/6hSp5ufncXh4iJGRESwuLqrOkeb8/BymacIwDNUpRbG5uYmdnR1sb29DCIFQKKQ6SYqSGNo/N74D+F8b30mtcDiMu7s7LC8v/+NUsnKh6zpOT0+RyWRUp0iRSCSQTCbR09MDTdPw+PiIQCCAeDyuOk2K7/lhs9lgGAbOzs4UF8lREp+unxvfAXDju8UsLS3BNE1EIhHYbDbVOdK8vr7i4eEhfx2LxWC32+FwONRFSTQ4OIh4PI5YLIZYLIa6ujpsbGygq6tLdVrB3t7e8Pz8DODrmNO9vT20tLQorpKjZI5mTSaTmJiYQDabzW98b2hoUJ0lxdzcHPb395FOp1FTUwOHw4Hd3V3VWVJcX1/D7Xajvr4elZWVAACn04lIJKK4rHDpdBrBYBDv7++oqKiA3W7H+Pg4WltbVacVhaZpWFtbQ1NTk+qUgt3f32N4eBifn5/I5XJobGzE9PQ0amtrVacVrGSGNhER/buS+
HqEiIh+h0ObiMhCOLSJiCyEQ5uIyEI4tImILIRDm4jIQji0iYgshEObiMhC/gDHFDItUVJ2JAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Run a 1000-step HMC chain targeting ppl.log_prob(model), starting at jnp.ones(2),\n", "# then scatter-plot the two coordinates of the samples.\n", "samples = jit(mcmc.sample_chain(mcmc.hmc(\n", " ppl.log_prob(model)), 1000))(random.PRNGKey(0), jnp.ones(2))\n", "plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5)\n", "plt.show()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "A Tour of Oryx", "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 0 }