From Certain to Uncertain | Stochastic Bellman Equation Made Easy

In the video below, we go over how to calculate the value of a state when the actions are probabilistic.
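As a quick written recap of what the code below implements: with stochastic actions, the value of a state is the best expected discounted value achievable over all actions. Since the immediate reward is 0 for every non-terminal move in this grid world, the backup is

V(s) = max over actions a of  γ · Σ_s' P(s' | s, a) · V(s')

where the transition model P(s' | s, a) puts probability 0.5 on the intended move and splits the remaining 0.5 evenly across all four directions.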

If you are wondering how to get the values for all the states, here is the code snippet for it.

import numpy as np
import matplotlib.pyplot as plt
from typing import List, Tuple

class StochasticGridWorld:
    def __init__(self, size: int = 3, gamma: float = 0.9):
        self.size = size
        self.gamma = gamma
        # Initialize state values; the terminal cells hold their rewards
        self.values = np.zeros((size, size))
        self.values[0, 2] = -1  # Cat (penalty)
        self.values[2, 2] = 1   # Cheese (reward)
        
        # Track value history for convergence visualization
        self.value_history = {(i, j): [] for i in range(size) for j in range(size)}
        
        # Movement probabilities: the agent moves in the intended direction
        # with probability 0.5; the remaining 0.5 is split evenly across all
        # four directions (so the intended move effectively gets 0.625).
        self.p_intended = 0.5    # Probability of moving in intended direction
        self.p_random = 0.5 / 4  # Share of the remaining probability per direction
        
    def get_next_state(self, current_state: Tuple[int, int], 
                       action: Tuple[int, int]) -> Tuple[int, int]:
        """Calculate next state given current state and action"""
        next_i = current_state[0] + action[0]
        next_j = current_state[1] + action[1]
        
        # Stay within the grid; bumping into a wall leaves the agent in place
        if 0 <= next_i < self.size and 0 <= next_j < self.size:
            return (next_i, next_j)
        return current_state
    
    def get_possible_actions(self) -> List[Tuple[int, int]]:
        """Return all possible actions as (dx, dy)"""
        return [(0, 1), (0, -1), (1, 0), (-1, 0)]  # Right, Left, Down, Up
    
    def calculate_state_value(self, state: Tuple[int, int]) -> float:
        """Calculate value for a given state considering all actions"""
        if state == (0, 2) or state == (2, 2):  # Terminal states
            return self.values[state]
        
        max_value = float('-inf')
        actions = self.get_possible_actions()
        
        for action in actions:
            value = 0.0  # Immediate reward is 0 for every non-terminal transition
            # Intended movement
            next_state = self.get_next_state(state, action)
            value += self.p_intended * self.values[next_state]
            
            # Random movements
            for random_action in actions:
                random_next_state = self.get_next_state(state, random_action)
                value += self.p_random * self.values[random_next_state]
            
            value = self.gamma * value  # Apply discount factor
            max_value = max(max_value, value)
            
        return max_value
    
    def value_iteration(self, num_iterations: int = 100, 
                       threshold: float = 1e-4) -> np.ndarray:
        """Perform value iteration and store history"""
        for iteration in range(num_iterations):
            delta = 0
            new_values = np.copy(self.values)
            
            for i in range(self.size):
                for j in range(self.size):
                    if (i, j) not in [(0, 2), (2, 2)]:  # Skip terminal states
                        old_value = self.values[i, j]
                        new_values[i, j] = self.calculate_state_value((i, j))
                        delta = max(delta, abs(old_value - new_values[i, j]))
                        self.value_history[(i, j)].append(new_values[i, j])
            
            self.values = new_values
            
            # Check convergence
            if delta < threshold:
                print(f"Converged after {iteration + 1} iterations")
                break
        
        return self.values
    
    def plot_convergence(self):
        """Plot value convergence for each non-terminal state"""
        plt.figure(figsize=(12, 8))
        for state, history in self.value_history.items():
            if state not in [(0, 2), (2, 2)]:  # Skip terminal states
                plt.plot(history, label=f'State {state}')
        
        plt.title('Value Convergence Over Iterations')
        plt.xlabel('Iteration')
        plt.ylabel('State Value')
        plt.legend()
        plt.grid(True)
        plt.show()

# Run the simulation
grid_world = StochasticGridWorld()
final_values = grid_world.value_iteration(num_iterations=100)

print("\nFinal Values:")
print(np.round(final_values, 3))
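Since the class already tracks the value history, you can also plot the convergence curves with the plot_convergence helper defined above. The extra lines below are just an optional sanity-check sketch: the state (2, 1), the cell just left of the cheese, is picked arbitrarily, and one Bellman backup is redone by hand to compare against the converged value.

# Visualize how each non-terminal state's value converges
grid_world.plot_convergence()

# Optional sanity check: redo one Bellman backup by hand for state (2, 1)
# and compare it with the value produced by value iteration.
state = (2, 1)
actions = grid_world.get_possible_actions()

best = float('-inf')
for action in actions:
    intended = grid_world.get_next_state(state, action)
    v = grid_world.p_intended * final_values[intended]  # intended move
    for random_action in actions:  # random slip in any of the four directions
        v += grid_world.p_random * final_values[grid_world.get_next_state(state, random_action)]
    best = max(best, grid_world.gamma * v)

print(f"Manual backup for {state}: {best:.3f} (code gives {final_values[state]:.3f})")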
