Diffstat (limited to 'src/gridworld.py')
-rw-r--r--  src/gridworld.py  585
1 file changed, 585 insertions, 0 deletions
diff --git a/src/gridworld.py b/src/gridworld.py
new file mode 100644
index 0000000..6e1e16b
--- /dev/null
+++ b/src/gridworld.py
@@ -0,0 +1,585 @@
+# gridworld.py
+# ------------
+# Licensing Information: You are free to use or extend these projects for
+# educational purposes provided that (1) you do not distribute or publish
+# solutions, (2) you retain this notice, and (3) you provide clear
+# attribution to UC Berkeley, including a link to http://ai.berkeley.edu.
+#
+# Attribution Information: The Pacman AI projects were developed at UC Berkeley.
+# The core projects and autograders were primarily created by John DeNero
+# (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu).
+# Student side autograding was added by Brad Miller, Nick Hay, and
+# Pieter Abbeel (pabbeel@cs.berkeley.edu).
+
+
+import random
+import sys
+import mdp
+import environment
+import util
+import optparse
+
+class Gridworld(mdp.MarkovDecisionProcess):
+ """
+ Gridworld
+ """
+ def __init__(self, grid):
+ # layout
+        if isinstance(grid, list): grid = makeGrid(grid)
+ self.grid = grid
+
+ # parameters
+ self.livingReward = 0.0
+ self.noise = 0.2
+
+ def setLivingReward(self, reward):
+ """
+        The (typically negative) reward for leaving "normal" states.
+
+ Note that in the R+N text, this reward is on entering
+ a state and therefore is not clearly part of the state's
+ future rewards.
+ """
+ self.livingReward = reward
+
+ def setNoise(self, noise):
+ """
+ The probability of moving in an unintended direction.
+ """
+ self.noise = noise
+
+
+ def getPossibleActions(self, state):
+ """
+ Returns list of valid actions for 'state'.
+
+ Note that you can request moves into walls and
+ that "exit" states transition to the terminal
+ state under the special action "done".
+ """
+ if state == self.grid.terminalState:
+ return ()
+ x,y = state
+        # Exit cells hold a numeric reward; accept floats as well as ints,
+        # matching getReward and getTransitionStatesAndProbs below.
+        if isinstance(self.grid[x][y], (int, float)):
+            return ('exit',)
+ return ('north','west','south','east')
+
+ def getStates(self):
+ """
+ Return list of all states.
+ """
+ # The true terminal state.
+ states = [self.grid.terminalState]
+ for x in range(self.grid.width):
+ for y in range(self.grid.height):
+ if self.grid[x][y] != '#':
+ state = (x,y)
+ states.append(state)
+ return states
+
+ def getReward(self, state, action, nextState):
+ """
+ Get reward for state, action, nextState transition.
+
+ Note that the reward depends only on the state being
+ departed (as in the R+N book examples, which more or
+ less use this convention).
+ """
+ if state == self.grid.terminalState:
+ return 0.0
+ x, y = state
+ cell = self.grid[x][y]
+        if isinstance(cell, (int, float)):
+ return cell
+ return self.livingReward
+
+ def getStartState(self):
+ for x in range(self.grid.width):
+ for y in range(self.grid.height):
+ if self.grid[x][y] == 'S':
+ return (x, y)
+        raise Exception('Grid has no start state')
+
+ def isTerminal(self, state):
+ """
+ Only the TERMINAL_STATE state is *actually* a terminal state.
+ The other "exit" states are technically non-terminals with
+ a single action "exit" which leads to the true terminal state.
+ This convention is to make the grids line up with the examples
+ in the R+N textbook.
+ """
+ return state == self.grid.terminalState
+
+
+ def getTransitionStatesAndProbs(self, state, action):
+ """
+ Returns list of (nextState, prob) pairs
+ representing the states reachable
+ from 'state' by taking 'action' along
+ with their transition probabilities.
+ """
+
+ if action not in self.getPossibleActions(state):
+ raise "Illegal action!"
+
+ if self.isTerminal(state):
+ return []
+
+ x, y = state
+
+        if isinstance(self.grid[x][y], (int, float)):
+ termState = self.grid.terminalState
+ return [(termState, 1.0)]
+
+ successors = []
+
+        # Moves into walls or off the grid leave the agent where it is; note
+        # that __isAllowed takes its arguments in (y, x) order, and the
+        # "cond and value or default" idiom is safe here because coordinate
+        # tuples are always truthy.
+        northState = (self.__isAllowed(y+1,x) and (x,y+1)) or state
+        westState = (self.__isAllowed(y,x-1) and (x-1,y)) or state
+        southState = (self.__isAllowed(y-1,x) and (x,y-1)) or state
+        eastState = (self.__isAllowed(y,x+1) and (x+1,y)) or state
+
+ if action == 'north' or action == 'south':
+ if action == 'north':
+ successors.append((northState,1-self.noise))
+ else:
+ successors.append((southState,1-self.noise))
+
+ massLeft = self.noise
+ successors.append((westState,massLeft/2.0))
+ successors.append((eastState,massLeft/2.0))
+
+ if action == 'west' or action == 'east':
+ if action == 'west':
+ successors.append((westState,1-self.noise))
+ else:
+ successors.append((eastState,1-self.noise))
+
+ massLeft = self.noise
+ successors.append((northState,massLeft/2.0))
+ successors.append((southState,massLeft/2.0))
+
+ successors = self.__aggregate(successors)
+
+ return successors
+
+ def __aggregate(self, statesAndProbs):
+ counter = util.Counter()
+ for state, prob in statesAndProbs:
+ counter[state] += prob
+ newStatesAndProbs = []
+ for state, prob in counter.items():
+ newStatesAndProbs.append((state, prob))
+ return newStatesAndProbs
+
+ def __isAllowed(self, y, x):
+ if y < 0 or y >= self.grid.height: return False
+ if x < 0 or x >= self.grid.width: return False
+ return self.grid[x][y] != '#'
+
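+# A minimal sketch of how the MDP interface above behaves on the 4x3 book
+# grid (the _demoMdpInterface helper is hypothetical, not part of the
+# original project). With the default noise of 0.2, an intended move succeeds
+# with probability 0.8 and slips to each perpendicular direction with
+# probability 0.1; blocked slips collapse onto the current state, which is
+# why __aggregate sums duplicate successors.
+def _demoMdpInterface():
+    world = getBookGrid()               # defined later in this module
+    start = world.getStartState()       # (0, 0), the 'S' cell
+    print(world.getPossibleActions(start))   # ('north', 'west', 'south', 'east')
+    # Moving north from (0, 0): (0,1) with 0.8, (0,0) with 0.1, (1,0) with 0.1.
+    print(world.getTransitionStatesAndProbs(start, 'north'))
+    # Moving west from (0, 0): the intended move and the southward slip both
+    # stay at (0,0), so aggregation yields (0,0) with 0.9 and (0,1) with 0.1.
+    print(world.getTransitionStatesAndProbs(start, 'west'))
+    print(world.getReward(start, 'north', (0, 1)))   # livingReward, 0.0 by default
+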
+class GridworldEnvironment(environment.Environment):
+
+ def __init__(self, gridWorld):
+ self.gridWorld = gridWorld
+ self.reset()
+
+ def getCurrentState(self):
+ return self.state
+
+ def getPossibleActions(self, state):
+ return self.gridWorld.getPossibleActions(state)
+
+ def doAction(self, action):
+ state = self.getCurrentState()
+ (nextState, reward) = self.getRandomNextState(state, action)
+ self.state = nextState
+ return (nextState, reward)
+
+    def getRandomNextState(self, state, action, randObj=None):
+        if randObj is None:
+            rand = random.random()
+        else:
+            rand = randObj.random()
+        total = 0.0
+        successors = self.gridWorld.getTransitionStatesAndProbs(state, action)
+        for nextState, prob in successors:
+            total += prob
+            if total > 1.0:
+                raise Exception('Total transition probability more than one; sample failure.')
+            if rand < total:
+                reward = self.gridWorld.getReward(state, action, nextState)
+                return (nextState, reward)
+        raise Exception('Total transition probability less than one; sample failure.')
+
+ def reset(self):
+ self.state = self.gridWorld.getStartState()
+
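+# A short sketch of sampling through the environment wrapper (the
+# _demoEnvironmentSampling helper is hypothetical, not part of the original
+# project). Passing a seeded random.Random as randObj makes the draw
+# reproducible; doAction, by contrast, also advances the environment state.
+def _demoEnvironmentSampling():
+    env = GridworldEnvironment(getBookGrid())
+    state = env.getCurrentState()
+    rng = random.Random(0)
+    nextState, reward = env.getRandomNextState(state, 'north', randObj=rng)
+    print((nextState, reward))
+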
+class Grid:
+ """
+ A 2-dimensional array of immutables backed by a list of lists. Data is accessed
+ via grid[x][y] where (x,y) are cartesian coordinates with x horizontal,
+ y vertical and the origin (0,0) in the bottom left corner.
+
+ The __str__ method constructs an output that is oriented appropriately.
+ """
+ def __init__(self, width, height, initialValue=' '):
+ self.width = width
+ self.height = height
+ self.data = [[initialValue for y in range(height)] for x in range(width)]
+ self.terminalState = 'TERMINAL_STATE'
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def __setitem__(self, key, item):
+ self.data[key] = item
+
+    def __eq__(self, other):
+        if other is None: return False
+        return self.data == other.data
+
+    def __hash__(self):
+        # Lists are unhashable, so hash an immutable snapshot of the contents.
+        return hash(tuple(tuple(row) for row in self.data))
+
+ def copy(self):
+ g = Grid(self.width, self.height)
+ g.data = [x[:] for x in self.data]
+ return g
+
+ def deepCopy(self):
+ return self.copy()
+
+ def shallowCopy(self):
+ g = Grid(self.width, self.height)
+ g.data = self.data
+ return g
+
+ def _getLegacyText(self):
+ t = [[self.data[x][y] for x in range(self.width)] for y in range(self.height)]
+ t.reverse()
+ return t
+
+ def __str__(self):
+ return str(self._getLegacyText())
+
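+# A small sketch of the indexing convention (the _demoGridIndexing helper is
+# hypothetical, not part of the original project): grid[x][y] with the origin
+# at the bottom left, so increasing y moves up and __str__ prints the top row
+# first.
+def _demoGridIndexing():
+    g = Grid(2, 2)
+    g[0][1] = 'A'            # top-left cell
+    g[1][0] = 'B'            # bottom-right cell
+    print(g)                 # [['A', ' '], [' ', 'B']]
+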
+def makeGrid(gridString):
+ width, height = len(gridString[0]), len(gridString)
+ grid = Grid(width, height)
+ for ybar, line in enumerate(gridString):
+ y = height - ybar - 1
+ for x, el in enumerate(line):
+ grid[x][y] = el
+ return grid
+
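+# A quick sketch of makeGrid's row order (the _demoMakeGrid helper is
+# hypothetical, not part of the original project): rows are given
+# top-to-bottom, so the first row of the input lands at the highest y.
+def _demoMakeGrid():
+    g = makeGrid([['T', 'T'],
+                  ['B', 'B']])
+    print(g[0][1])           # 'T' (top row is y = height - 1)
+    print(g[0][0])           # 'B' (bottom row is y = 0)
+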
+def getCliffGrid():
+ grid = [[' ',' ',' ',' ',' '],
+ ['S',' ',' ',' ',10],
+ [-100,-100, -100, -100, -100]]
+ return Gridworld(makeGrid(grid))
+
+def getCliffGrid2():
+ grid = [[' ',' ',' ',' ',' '],
+ [8,'S',' ',' ',10],
+ [-100,-100, -100, -100, -100]]
+ return Gridworld(grid)
+
+def getDiscountGrid():
+ grid = [[' ',' ',' ',' ',' '],
+ [' ','#',' ',' ',' '],
+ [' ','#', 1,'#', 10],
+ ['S',' ',' ',' ',' '],
+ [-10,-10, -10, -10, -10]]
+ return Gridworld(grid)
+
+def getBridgeGrid():
+ grid = [[ '#',-100, -100, -100, -100, -100, '#'],
+ [ 1, 'S', ' ', ' ', ' ', ' ', 10],
+ [ '#',-100, -100, -100, -100, -100, '#']]
+ return Gridworld(grid)
+
+def getBookGrid():
+ grid = [[' ',' ',' ',+1],
+ [' ','#',' ',-1],
+ ['S',' ',' ',' ']]
+ return Gridworld(grid)
+
+def getMazeGrid():
+ grid = [[' ',' ',' ',+1],
+ ['#','#',' ','#'],
+ [' ','#',' ',' '],
+ [' ','#','#',' '],
+ ['S',' ',' ',' ']]
+ return Gridworld(grid)
+
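+# The layouts above are plain lists of rows: 'S' marks the start, '#' a wall,
+# a number an exit cell with that reward, and ' ' an ordinary cell. A custom
+# world needs nothing more than a factory like this hypothetical example,
+# which the getattr lookup in __main__ would expose as '-g TinyGrid':
+def getTinyGrid():
+    grid = [[' ',  1],
+            ['S', ' ']]
+    return Gridworld(grid)
+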
+
+
+def getUserAction(state, actionFunction):
+ """
+ Get an action from the user (rather than the agent).
+
+ Used for debugging and lecture demos.
+ """
+ import graphicsUtils
+ action = None
+ while True:
+ keys = graphicsUtils.wait_for_keys()
+ if 'Up' in keys: action = 'north'
+ if 'Down' in keys: action = 'south'
+ if 'Left' in keys: action = 'west'
+ if 'Right' in keys: action = 'east'
+ if 'q' in keys: sys.exit(0)
+        if action is None: continue
+ break
+ actions = actionFunction(state)
+ if action not in actions:
+ action = actions[0]
+ return action
+
+def printString(x): print x
+
+def runEpisode(agent, environment, discount, decision, display, message, pause, episode):
+ returns = 0
+ totalDiscount = 1.0
+ environment.reset()
+ if 'startEpisode' in dir(agent): agent.startEpisode()
+ message("BEGINNING EPISODE: "+str(episode)+"\n")
+ while True:
+
+ # DISPLAY CURRENT STATE
+ state = environment.getCurrentState()
+ display(state)
+ pause()
+
+ # END IF IN A TERMINAL STATE
+ actions = environment.getPossibleActions(state)
+ if len(actions) == 0:
+ message("EPISODE "+str(episode)+" COMPLETE: RETURN WAS "+str(returns)+"\n")
+ return returns
+
+ # GET ACTION (USUALLY FROM AGENT)
+ action = decision(state)
+        if action is None:
+            raise Exception('Error: Agent returned None action')
+
+ # EXECUTE ACTION
+ nextState, reward = environment.doAction(action)
+ message("Started in state: "+str(state)+
+ "\nTook action: "+str(action)+
+ "\nEnded in state: "+str(nextState)+
+ "\nGot reward: "+str(reward)+"\n")
+ # UPDATE LEARNER
+ if 'observeTransition' in dir(agent):
+ agent.observeTransition(state, action, nextState, reward)
+
+ returns += reward * totalDiscount
+ totalDiscount *= discount
+
+ if 'stopEpisode' in dir(agent):
+ agent.stopEpisode()
+
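+# A worked sketch of the discounted return accumulated by runEpisode (the
+# _demoDiscountedReturn helper is hypothetical, not part of the original
+# project): the return is r_0 + d*r_1 + d^2*r_2 + ..., so rewards (1, 1, 10)
+# with discount 0.9 give 1 + 0.9 + 8.1 = 10.0 (up to float rounding).
+def _demoDiscountedReturn(rewards=(1.0, 1.0, 10.0), discount=0.9):
+    returns, totalDiscount = 0.0, 1.0
+    for reward in rewards:
+        returns += reward * totalDiscount
+        totalDiscount *= discount
+    return returns
+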
+def parseOptions():
+ optParser = optparse.OptionParser()
+ optParser.add_option('-d', '--discount',action='store',
+ type='float',dest='discount',default=0.9,
+                         help='Discount factor on future rewards (default %default)')
+ optParser.add_option('-r', '--livingReward',action='store',
+ type='float',dest='livingReward',default=0.0,
+ metavar="R", help='Reward for living for a time step (default %default)')
+ optParser.add_option('-n', '--noise',action='store',
+ type='float',dest='noise',default=0.2,
+ metavar="P", help='How often action results in ' +
+ 'unintended direction (default %default)' )
+ optParser.add_option('-e', '--epsilon',action='store',
+ type='float',dest='epsilon',default=0.3,
+ metavar="E", help='Chance of taking a random action in q-learning (default %default)')
+ optParser.add_option('-l', '--learningRate',action='store',
+ type='float',dest='learningRate',default=0.5,
+ metavar="P", help='TD learning rate (default %default)' )
+ optParser.add_option('-i', '--iterations',action='store',
+ type='int',dest='iters',default=10,
+ metavar="K", help='Number of rounds of value iteration (default %default)')
+ optParser.add_option('-k', '--episodes',action='store',
+ type='int',dest='episodes',default=1,
+ metavar="K", help='Number of epsiodes of the MDP to run (default %default)')
+ optParser.add_option('-g', '--grid',action='store',
+ metavar="G", type='string',dest='grid',default="BookGrid",
+                         help='Grid to use (case sensitive; options are BookGrid, BridgeGrid, CliffGrid, CliffGrid2, DiscountGrid, MazeGrid, default %default)' )
+ optParser.add_option('-w', '--windowSize', metavar="X", type='int',dest='gridSize',default=150,
+ help='Request a window width of X pixels *per grid cell* (default %default)')
+ optParser.add_option('-a', '--agent',action='store', metavar="A",
+ type='string',dest='agent',default="random",
+ help='Agent type (options are \'random\', \'value\' and \'q\', default %default)')
+ optParser.add_option('-t', '--text',action='store_true',
+ dest='textDisplay',default=False,
+ help='Use text-only ASCII display')
+ optParser.add_option('-p', '--pause',action='store_true',
+ dest='pause',default=False,
+ help='Pause GUI after each time step when running the MDP')
+ optParser.add_option('-q', '--quiet',action='store_true',
+ dest='quiet',default=False,
+ help='Skip display of any learning episodes')
+ optParser.add_option('-s', '--speed',action='store', metavar="S", type=float,
+ dest='speed',default=1.0,
+ help='Speed of animation, S > 1.0 is faster, 0.0 < S < 1.0 is slower (default %default)')
+ optParser.add_option('-m', '--manual',action='store_true',
+ dest='manual',default=False,
+ help='Manually control agent')
+ optParser.add_option('-v', '--valueSteps',action='store_true' ,default=False,
+ help='Display each step of value iteration')
+
+ opts, args = optParser.parse_args()
+
+ if opts.manual and opts.agent != 'q':
+ print '## Disabling Agents in Manual Mode (-m) ##'
+ opts.agent = None
+
+    # MANAGE CONFLICTS
+    if opts.textDisplay or opts.quiet:
+        opts.pause = False
+
+ if opts.manual:
+ opts.pause = True
+
+ return opts
+
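+# Example invocations, using the options defined above:
+#   python gridworld.py -a value -i 100 -k 10      value iteration, then 10 episodes
+#   python gridworld.py -a q -k 50 -e 0.1 -l 0.7   q-learning for 50 episodes
+#   python gridworld.py -m -g MazeGrid             manual control on the maze grid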
+
+if __name__ == '__main__':
+
+ opts = parseOptions()
+
+ ###########################
+ # GET THE GRIDWORLD
+ ###########################
+
+ import gridworld
+ mdpFunction = getattr(gridworld, "get"+opts.grid)
+ mdp = mdpFunction()
+ mdp.setLivingReward(opts.livingReward)
+ mdp.setNoise(opts.noise)
+ env = gridworld.GridworldEnvironment(mdp)
+
+
+ ###########################
+ # GET THE DISPLAY ADAPTER
+ ###########################
+
+ import textGridworldDisplay
+ display = textGridworldDisplay.TextGridworldDisplay(mdp)
+ if not opts.textDisplay:
+ import graphicsGridworldDisplay
+ display = graphicsGridworldDisplay.GraphicsGridworldDisplay(mdp, opts.gridSize, opts.speed)
+ try:
+ display.start()
+ except KeyboardInterrupt:
+ sys.exit(0)
+
+ ###########################
+ # GET THE AGENT
+ ###########################
+
+ import valueIterationAgents, qlearningAgents
+ a = None
+ if opts.agent == 'value':
+ a = valueIterationAgents.ValueIterationAgent(mdp, opts.discount, opts.iters)
+ elif opts.agent == 'q':
+        actionFn = lambda state: mdp.getPossibleActions(state)
+ qLearnOpts = {'gamma': opts.discount,
+ 'alpha': opts.learningRate,
+ 'epsilon': opts.epsilon,
+ 'actionFn': actionFn}
+ a = qlearningAgents.QLearningAgent(**qLearnOpts)
+ elif opts.agent == 'random':
+        # No reason to use the random agent without episodes.
+ if opts.episodes == 0:
+ opts.episodes = 10
+ class RandomAgent:
+ def getAction(self, state):
+ return random.choice(mdp.getPossibleActions(state))
+ def getValue(self, state):
+ return 0.0
+ def getQValue(self, state, action):
+ return 0.0
+ def getPolicy(self, state):
+ "NOTE: 'random' is a special policy value; don't use it in your code."
+ return 'random'
+ def update(self, state, action, nextState, reward):
+ pass
+ a = RandomAgent()
+ else:
+        if not opts.manual: raise Exception('Unknown agent type: ' + opts.agent)
+
+
+ ###########################
+ # RUN EPISODES
+ ###########################
+ # DISPLAY Q/V VALUES BEFORE SIMULATION OF EPISODES
+ try:
+ if not opts.manual and opts.agent == 'value':
+ if opts.valueSteps:
+ for i in range(opts.iters):
+ tempAgent = valueIterationAgents.ValueIterationAgent(mdp, opts.discount, i)
+ display.displayValues(tempAgent, message = "VALUES AFTER "+str(i)+" ITERATIONS")
+ display.pause()
+
+ display.displayValues(a, message = "VALUES AFTER "+str(opts.iters)+" ITERATIONS")
+ display.pause()
+ display.displayQValues(a, message = "Q-VALUES AFTER "+str(opts.iters)+" ITERATIONS")
+ display.pause()
+ except KeyboardInterrupt:
+ sys.exit(0)
+
+
+
+ # FIGURE OUT WHAT TO DISPLAY EACH TIME STEP (IF ANYTHING)
+ displayCallback = lambda x: None
+ if not opts.quiet:
+        if opts.manual and opts.agent is None:
+ displayCallback = lambda state: display.displayNullValues(state)
+ else:
+ if opts.agent == 'random': displayCallback = lambda state: display.displayValues(a, state, "CURRENT VALUES")
+ if opts.agent == 'value': displayCallback = lambda state: display.displayValues(a, state, "CURRENT VALUES")
+ if opts.agent == 'q': displayCallback = lambda state: display.displayQValues(a, state, "CURRENT Q-VALUES")
+
+ messageCallback = lambda x: printString(x)
+ if opts.quiet:
+ messageCallback = lambda x: None
+
+ # FIGURE OUT WHETHER TO WAIT FOR A KEY PRESS AFTER EACH TIME STEP
+ pauseCallback = lambda : None
+ if opts.pause:
+ pauseCallback = lambda : display.pause()
+
+ # FIGURE OUT WHETHER THE USER WANTS MANUAL CONTROL (FOR DEBUGGING AND DEMOS)
+ if opts.manual:
+ decisionCallback = lambda state : getUserAction(state, mdp.getPossibleActions)
+ else:
+ decisionCallback = a.getAction
+
+ # RUN EPISODES
+ if opts.episodes > 0:
+ print
+ print "RUNNING", opts.episodes, "EPISODES"
+ print
+ returns = 0
+ for episode in range(1, opts.episodes+1):
+ returns += runEpisode(a, env, opts.discount, decisionCallback, displayCallback, messageCallback, pauseCallback, episode)
+ if opts.episodes > 0:
+ print
+ print "AVERAGE RETURNS FROM START STATE: "+str((returns+0.0) / opts.episodes)
+ print
+ print
+
+ # DISPLAY POST-LEARNING VALUES / Q-VALUES
+ if opts.agent == 'q' and not opts.manual:
+ try:
+ display.displayQValues(a, message = "Q-VALUES AFTER "+str(opts.episodes)+" EPISODES")
+ display.pause()
+ display.displayValues(a, message = "VALUES AFTER "+str(opts.episodes)+" EPISODES")
+ display.pause()
+ except KeyboardInterrupt:
+ sys.exit(0)