How Does a Mouse Find Cheese? | Bellman Equation Made Simple

In this video, we explain how the Bellman equation works in a deterministic world, where every move leads to exactly one next state.

Here is a code snippet you can run to verify the state values in the 3×3 grid world. It requires NumPy.

def value_iteration(rewards, gamma=0.9, tolerance=1e-4, max_iterations=1000,
                    terminal_values=None):
    """Compute state values for a deterministic grid world by value iteration.

    Every non-terminal cell is updated with the Bellman optimality equation

        V(s) = R(s) + gamma * max_{s'} V(s')

    where ``s'`` ranges over the cells reachable with a single up, down,
    left or right move that stays inside the grid. Updates are synchronous:
    each sweep reads only the values produced by the previous sweep.

    Args:
        rewards: 2-D array-like of per-cell rewards ``R(s)``. Its shape
            defines the grid, so any rectangular size is supported.
        gamma: Discount factor applied to the best next-state value.
        tolerance: Stop once the largest change during a sweep is below this.
        max_iterations: Upper bound on the number of sweeps.
        terminal_values: Optional mapping ``{(row, col): value}`` of terminal
            cells (non-negative indices) whose value is fixed and never
            updated. Defaults to the 3x3 mouse example: the cat at ``(0, 2)``
            with ``-1`` and the cheese at ``(2, 2)`` with ``+1``.

    Returns:
        A float ``numpy.ndarray`` with the same shape as ``rewards``.
    """
    import numpy as np

    # Accept nested lists as well as arrays; 2-D indexing below needs an array.
    rewards = np.asarray(rewards, dtype=float)
    if terminal_values is None:
        terminal_values = {(0, 2): -1.0, (2, 2): 1.0}  # cat, cheese
    n_rows, n_cols = rewards.shape

    # Non-terminal cells start at 0; terminal cells are pinned to their values.
    V = np.zeros_like(rewards)
    for (i, j), value in terminal_values.items():
        V[i, j] = value

    moves = ((-1, 0), (1, 0), (0, -1), (0, 1))  # up, down, left, right

    for iteration in range(max_iterations):
        delta = 0.0        # largest change seen in this sweep
        V_prev = V.copy()  # snapshot so the sweep is synchronous

        for i in range(n_rows):
            for j in range(n_cols):
                if (i, j) in terminal_values:
                    continue

                neighbour_values = [
                    V_prev[i + di, j + dj]
                    for di, dj in moves
                    if 0 <= i + di < n_rows and 0 <= j + dj < n_cols
                ]
                # Only a 1x1 grid has no neighbours; treat its future value as 0.
                best_next_value = max(neighbour_values, default=0.0)
                V[i, j] = rewards[i, j] + gamma * best_next_value

                delta = max(delta, abs(V[i, j] - V_prev[i, j]))

        if delta < tolerance:
            print(f"Converged after {iteration + 1} iterations")
            break

    return V

import numpy as np

# Rewards matrix: 0 everywhere except the two terminal cells.
rewards = np.zeros((3, 3))
rewards[0, 2] = -1  # Cat state
rewards[2, 2] = 1   # Cheese state

# Run value iteration
V = value_iteration(rewards, gamma=0.9)

# Print with 3 decimals for readability (display only; V itself is not rounded)
np.set_printoptions(precision=3, suppress=True)
print("\nFinal Value Function:")
print(V)

Comments

Leave a comment