# ---------------------------------------------------------------------------------------------------
# ECON 2010C Fall 2024 with David Lagakos
# Borui Niklas Zhu
# Value Function Iteration, especially for Problem Set 1
# Code Template only. There are often different (and better) ways to do things.
# ---------------------------------------------------------------------------------------------------

# %%
# Part I: User settings and housekeeping
# ---------------------------------------------------------------------------------------------------

# Housekeeping
import numpy as np # standard maths library
import matplotlib.pyplot as plt # standard plotting library
import os # anything related to the operating system

# Initialisation
    # I prefer to use dictionaries so it's easier to pass things to functions
    # and not accidentally overwrite anything

# Initialize meta settings
# `init` collects numerical/algorithmic settings (as opposed to model parameters).
init = {}
# number of grid points; used for both grids built below
init['N'] = 1001
# convergence tolerance: run_vfi keeps iterating while diff > conv_crit
init['conv_crit'] = 1e-5
# NOTE(review): run_vfi also reads init['conv_itermax'] and simulate_path reads
# init['simul_duration']; neither is set in the visible lines, so they have to
# be defined in the elided section below.
# ...

# Initialising parameters
# `params` collects the model parameters.
params = {}
# delta enters the resource constraint through the (1-delta)k term
params['delta'] = 0.1
# discount factor of the original problem, i.e. before detrending
params['beta_orig'] = 0.96
# NOTE(review): params['gamma'] and params['sigma'] are read two lines below and
# params['k_0'] is read by simulate_path; all of them must be set in the elided
# section, otherwise the next assignment raises a KeyError.
# ...
# new "effective" beta due to detrending appropriately
params['beta'] = (params['beta_orig']*params['gamma']**(1-params['sigma']))

# Initialising grids (will reuse them)
grids = {}
# capital grid: N evenly spaced points from 1e-5 to 5
grids['k'] = np.linspace(1.0e-5, 5, init['N'])
# second grid over the same range; labelled 'c', presumably a consumption grid.
# NOTE(review): no visible code reads grids['c'] - confirm it is needed.
grids['c'] = np.linspace(1.0e-5, 5, init['N'])

# %%
# Part II: Value function iteration (VFI)
# ---------------------------------------------------------------------------------------------------

# Define function for backward iteration
def iteration_step(V_prime, init, params, grids):
    """Perform a single backward iteration of the Bellman operator.

    Parameters
    ----------
    V_prime : array
        Current guess of the continuation value V(k'), given as a vector of
        values on the capital grid (k' lives on the same grid as k).
    init : dict
        Meta settings.
    params : dict
        Model parameters.
    grids : dict
        Grids; the explanation below works on the capital grid.

    Returns
    -------
    dict
        Keys 'V' (updated value function), 'k_prime' and 'c' (policy functions).
    """
    # Performs a single backward iteration, based on some given V(k')

    # This code uses vectorised operations to avoid a double loop.
        # The key object is the value function V(k), which is represented as a vector
        # of values V(k_grid) = V(k_1), V(k_2), ..., V(k_N), corresponding to the grid for k:
        # with k_grid = k_1, k_2, ..., k_N.
        # Note that even though we take V(k') as given and not V(k), k' is defined on the same grid as k.
    # Consider a particular k in k_grid, and consider the maximisation problem in the Bellman operator
    # V(k) = max_{c,k'} u(c) + beta V(k') s.t.
    # c + k' - (1-delta)k = k^theta
        # Since V(k') is given, it's more natural to reduce the choice set to be in terms of k' rather than c.
        # So use the resource constraint to get c(k') = k^theta + (1-delta)k - k'
        # For our given k, we can consider all possible k' in the k_grid and pick the best.
        # Make the k'-options a (1xN)-vector, and compute c(k_grid), a (1xN)-vector.
        # Plug into the utility function element-wise,
        # and add beta V(k'), where V(k') is made (1xN) too to match dimensions.
        # U(k, k') = u(c(k')) + beta V(k') for all k' in k_grid.
        # This U(k, k') should also be a (1xN)-vector
        # Our fixed k is reflected inside c(k').
        # Now find the best k' among the values of U(k, k'), within that row. That's k'(k).
        # The best U(k, k') is your new V(k).
    # Instead of looping over k, we can just stack all the different k as additional rows.
        # Everything above was within a (1xN)-vector,
        # so now stack the different k to make the whole thing a (NxN)-matrix.
        # The new k are reflected by c(k') being (NxN), and using a different k in each row.
        # Maximising for each row, within each row, will yield V(k_grid), k'(k_grid).
        # Think about how you can get c(k_grid) too.

    # NOTE(review): the elided part below has to assign V, k_prime and c;
    # none of them is defined in the visible code, so the return would
    # otherwise raise a NameError.
    # ...

    # return updated value function and policy function
    return {'V': V, 'k_prime': k_prime, 'c': c}


# Define routine to run value function iteration (VFI) until convergence: max_k |V_{b+1}(k) - V_b(k)| < epsilon
def run_vfi(init, params, grids, b_verbose = False):
    """Iterate on the value function until convergence or the iteration cap.

    Parameters
    ----------
    init : dict
        Meta settings; reads init['conv_itermax'] (iteration cap) and
        init['conv_crit'] (convergence tolerance).
    params : dict
        Model parameters, passed through to iteration_step.
    grids : dict
        Grids; grids['k'] is used for the initial guess and is passed through
        to iteration_step.
    b_verbose : bool
        If True, print the iteration counter and current diff in every pass.

    Returns
    -------
    dict
        Keys 'V' (final value function), 'k_prime' and 'c' (policy functions
        from the last iteration_step call) and 'b_conv_success'.
    """
    # NOTE(review): `iter` shadows the Python builtin of the same name.
    iter = 0
    # start with a diff far above the tolerance so the loop is entered
    diff = 999999
    # initial guess for V'
    V_prime = np.log(grids['k'] + 0.01)
    b_conv_success = False
    # keep iterating until converged, or give up once the iteration cap is hit
    while (iter < init['conv_itermax'] and diff > init['conv_crit']):
        res = iteration_step(V_prime, init, params, grids)
        # NOTE(review): the visible code never updates iter, diff or V_prime and
        # never assigns V; presumably that happens in the elided part below.
        # Without it the loop cannot terminate - confirm.
        # ...
        if b_verbose:
            print('iteration ' + str(iter) + '; diff = ' + str(np.round(diff, 5)))
    # NOTE(review): b_conv_success is still False at this point in the visible
    # code; presumably it is set in the elided part below - confirm.
    # ...

    # return final value function, policy function, and result whether V converged
    return {'V': V, 'k_prime' : res['k_prime'], 'c': res['c'], 'b_conv_success': b_conv_success}


# Run VFI and save value function + policy functions
# ...

# plot value + policy functions
# (title, series) for every panel, listed in subplot order
value_policy_panels = (
    ("Value function V(k)", V_optimal),
    ("k'(k)", k_prime_policy),
    ("c(k)", c_policy),
)

# Draw the same 2x2 figure twice: once over the full grid (first_point = 0) and
# once without the first few grid points (first_point = 5), so that the graph
# is less driven by the magnitudes at tiny k.
for first_point in (0, 5):
    fig, axs = plt.subplots(2, 2, figsize=(14, 5))
    for panel_no, (panel_title, panel_series) in enumerate(value_policy_panels, start=1):
        plt.subplot(2, 2, panel_no)
        plt.title(panel_title)
        plt.plot(grids['k'][first_point:], panel_series[first_point:])
    # only three panels are used, so hide the bottom-right axes
    axs[1, 1].axis('off')

# savefig acts on the current figure, i.e. the trimmed one created last
plt.savefig("PS_1_Q3_Figure_1.pdf")

# %%
# Part III: Simulating equilibrium paths
# ---------------------------------------------------------------------------------------------------

# A helper function

# k'(k) is by construction of this algorithm one of the choices on the k-grid.
# However, when we simulate a path, our starting value k_0 may not be.
# Then our policy functions will be useless for that first step.
# Two potential ways to remedy:
# (1) Interpolation
# (2) Find the closest k on the grid to use the policy function there.
# Let's do (2) here.

def find_closest_index(x, x_grid):
    """Return the index of the entry of ``x_grid`` that lies closest to ``x``.

    Parameters
    ----------
    x : float
        Value to locate.
    x_grid : array-like
        Candidate values; a NumPy array, list or tuple.

    Returns
    -------
    integer index
        Position in ``x_grid`` with the smallest absolute distance to ``x``.
        If several entries are equally close, the first one is returned.

    Raises
    ------
    ValueError
        If ``x_grid`` is empty (raised by ``np.argmin``).
    """
    # np.asarray makes the subtraction work for plain lists/tuples as well;
    # an existing array is passed through without copying
    distance = np.abs(x - np.asarray(x_grid))
    return np.argmin(distance)

# We know the policy functions c(k), k'(k), so starting from the initial value
# of the state variable, k_0, we can just iterate forward
# Reminder: To apply Blackwell's Theorem on VFI, we do backward iteration until we converge
# For simulation, we use the resulting policy functions to do forward iteration

# define routine to simulate by iterating forward, save results
def simulate_path(init, params, grids, vfi_res):
    """Simulate the stationary (detrended) economy forward from params['k_0'].

    Parameters
    ----------
    init : dict
        Meta settings; init['simul_duration'] is the number of periods T.
    params : dict
        Model parameters; params['k_0'] is the initial capital stock.
    grids : dict
        grids['k'] is the capital grid on which the policy functions live.
    vfi_res : dict
        VFI output; vfi_res['k_prime'] and vfi_res['c'] are the policy functions.

    Returns
    -------
    numpy.ndarray
        Array of shape (T, 5), one row per period, with columns
        (0) capital k, (1) log gdp, (2) rental rate = MPK,
        (3) wage rate = MPL, (4) i/y-ratio.

    Raises
    ------
    NotImplementedError
        As long as the forward-iteration step below has not been filled in.
    """
    T = init['simul_duration']
    k_prime_policy = vfi_res['k_prime']
    c_policy = vfi_res['c']
    # allocate storage to track five variables over time:
    # (0) the state variable k and the four variables we want to plot:
    # (1) log gdp, (2) rental rate = MPK, (3) wage rate = MPL, (4) i/y-ratio
    # we will have to undo the detrend later, but I want to save the
    # stationary trajectory so we can see the difference between the two
    data_simul = np.zeros((T, 5))
    # initialise: k_0 need not be on the grid, so snap to the nearest grid point
    k_prev = params['k_0']
    k_prev_index = find_closest_index(k_prev, grids['k'])
    # forward iteration
    for t in np.arange(T):
        # ...
        # A loop whose body is only a comment is a syntax error, so keep the
        # file parseable and fail loudly until the step is written.
        # Remove this raise once the step above is in place.
        raise NotImplementedError(
            "simulate_path: fill in the forward-iteration step "
            "(store the period-t variables in data_simul, then move k to k'(k))"
        )

    return data_simul

# Run simulation
data_simul = simulate_path(init, params, grids, vfi_res)
# time axis 0, 1, ..., simul_duration - 1, used as x-values of the plots
T = np.arange(init['simul_duration'])

# Plot trajectory

# (title, column of data_simul) for every panel, listed in subplot order
trajectory_panels = (
    ("Log GDP", 1),
    ("Rental rate", 2),
    ("Wage rate", 3),
    ("Investment-output ratio", 4),
    ("Capital stock", 0),
)

fig, axs = plt.subplots(3, 2, figsize=(14, 5))
for panel_no, (panel_title, panel_col) in enumerate(trajectory_panels, start=1):
    plt.subplot(3, 2, panel_no)
    plt.title(panel_title)
    plt.plot(T, data_simul[:, panel_col])

# five panels on a 3x2 layout: hide the unused bottom-right axes
axs[2, 1].axis('off')

plt.savefig("PS_1_Q3_Figure_2.pdf")

# So far we have used detrended, stationary variables. Convert back to the
# balanced growth path (BGP); see the PSet solution for the formulae.
# work on a copy so the stationary trajectory in data_simul stays untouched
data_simul_growth = data_simul.copy()
# ...

# Now plot transition paths of original, nonstationary system

# (title, column of data_simul_growth) for every panel, listed in subplot order
trajectory_panels = (
    ("Log GDP", 1),
    ("Rental rate", 2),
    ("Wage rate", 3),
    ("Investment-output ratio", 4),
    ("Capital stock", 0),
)

fig, axs = plt.subplots(3, 2, figsize=(14, 5))
for panel_no, (panel_title, panel_col) in enumerate(trajectory_panels, start=1):
    plt.subplot(3, 2, panel_no)
    plt.title(panel_title)
    plt.plot(T, data_simul_growth[:, panel_col])

# five panels on a 3x2 layout: hide the unused bottom-right axes
axs[2, 1].axis('off')

plt.savefig("PS_1_Q3_Figure_3.pdf")

# Now redo that with k_0 = 0.1

# change parameter
# new parameter dict that differs from params only in the starting capital
params_new = {**params, 'k_0': 0.1}
# run simulation path (policy functions stay the same!!!)
data_simul = simulate_path(init, params_new, grids, vfi_res)
T = np.arange(init['simul_duration'])
# undo stationarisation (copy-paste from above)
# ...
# plot, but let's also plot the curves for the higher initial starting value

# (title, column of the simulated data) for every panel, listed in subplot order
trajectory_panels = (
    ("Log GDP", 1),
    ("Rental rate", 2),
    ("Wage rate", 3),
    ("Investment-output ratio", 4),
    ("Capital stock", 0),
)

fig, axs = plt.subplots(3, 2, figsize=(14, 5))
for panel_no, (panel_title, panel_col) in enumerate(trajectory_panels, start=1):
    plt.subplot(3, 2, panel_no)
    plt.title(panel_title)
    plt.plot(T, data_simul_growth[:, panel_col], label="$k_0 = 0.5$")
    plt.plot(T, data_simul_growth_new[:, panel_col], label="$k_0 = 0.1$")
# the legend is attached to the current axes, i.e. the last panel drawn
plt.legend()
# five panels on a 3x2 layout: hide the unused bottom-right axes
axs[2, 1].axis('off')
# save figure
plt.savefig("PS_1_Q3_Figure_4.pdf")