rllib
rl-mountain-car.hpp
1 /* This file is part of rl-lib
2  *
3  * Copyright (C) 2010, Supelec
4  *
5  * Author : Herve Frezza-Buet and Matthieu Geist
6  *
7  * Contributor :
8  *
9  * This library is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU General Public
11  * License (GPL) as published by the Free Software Foundation; either
12  * version 3 of the License, or any later version.
13  *
14  * This library is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17  * General Public License for more details.
18  *
19  * You should have received a copy of the GNU General Public
20  * License along with this library; if not, write to the Free Software
21  * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
22  *
23  * Contact : Herve.Frezza-Buet@supelec.fr Matthieu.Geist@supelec.fr
24  *
25  */
26 
27 #pragma once
28 
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <random>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include <rlAlgo.hpp>
#include <rlEpisode.hpp>
#include <rlException.hpp>
41 
42 namespace rl {
43  namespace problem {
44  namespace mountain_car {
45 
// The action space: the three discrete controls that can be applied to
// the car at each time step. The numeric values are pinned explicitly
// because they are written as integers into the generated gnuplot data.
enum class Action : int {
  actionNone     = 0,  // no acceleration applied
  actionBackward = 1,  // accelerate toward decreasing positions
  actionForward  = 2   // accelerate toward increasing positions
};

// Number of enumerators declared in Action.
constexpr int actionSize = 3;
53 
54  // some exceptions for state and action consistancy
55  class BadAction : public rl::exception::Any {
56  public:
57  BadAction(std::string comment)
58  : Any(std::string("Bad action performed : ")+comment) {}
59  };
60 
61  class BadState : public rl::exception::Any {
62  public:
63  BadState(std::string comment)
64  : Any(std::string("Bad state found : ")+comment) {}
65  };
66 
67 
// Mountain car parameters.
//
// Default parameter set, used as the PARAM template argument of Phase
// and Simulator. Every value is exposed through a static constexpr
// function, so it can be used in constant expressions as well as at
// run time.
class DefaultParam {
public:
  // Bounds of the valid position range.
  static constexpr double minPosition() { return -1.200; }
  static constexpr double maxPosition() { return  0.500; }

  // Bounds of the valid speed range (the simulator saturates the speed
  // to this interval).
  static constexpr double minSpeed() { return -0.070; }
  static constexpr double maxSpeed() { return  0.070; }

  // Goal position; by default the upper position bound.
  static constexpr double goalPosition() { return maxPosition(); }

  // The goal is considered reached when the car crosses the upper
  // position bound with a speed in
  // [goalSpeed(), goalSpeed() + goalSpeedMargin()].
  static constexpr double goalSpeed()       { return 0.000; }
  static constexpr double goalSpeedMargin() { return maxSpeed(); }

  // Reward given when the goal is reached, and on every other step.
  static constexpr double rewardGoal() { return 1.0; }
  static constexpr double rewardStep() { return 0.0; }
};
81 
82  // This is the phase space
83  template<typename PARAM>
84  class Phase {
85  public:
86  using param_type = PARAM;
87 
88  double position,speed;
89 
90  Phase(void) {}
91  Phase(const Phase& copy) : position(copy.position), speed(copy.speed) {}
92  Phase(double p, double s) : position(p), speed(s) {}
93  ~Phase(void) {}
94  Phase& operator=(const Phase& copy) {
95  if(this != &copy) {
96  position = copy.position;
97  speed = copy.speed;
98  }
99  return *this;
100  }
101 
102  void check(void) const {
103  if( (position > param_type::maxPosition()) || (position < param_type::minPosition())
104  || (speed > param_type::maxSpeed()) || (speed < param_type::minSpeed()) ) {
105  std::ostringstream ostr;
106  ostr << "mountain_car::Phase::check : At position = " << position << " and speed = " << speed << ".";
107  throw BadState(ostr.str());
108  }
109  }
110 
111  template<typename RANDOM_DEVICE>
112  static Phase<PARAM> random(RANDOM_DEVICE& gen) {
113  return Phase<PARAM>(std::uniform_real_distribution<>(param_type::minPosition(), param_type::maxPosition())(gen),
114  std::uniform_real_distribution<>(param_type::minSpeed(), param_type::maxSpeed())(gen));
115  }
116 
117  void saturateSpeed(void) {
118  if(speed < param_type::minSpeed())
119  speed = param_type::minSpeed();
120  else if(speed > param_type::maxSpeed())
121  speed = param_type::maxSpeed();
122  }
123  };
124 
129  template<typename MOUNTAIN_CAR_PARAM>
130  class Simulator {
131 
132  public:
133 
134  using param_type = MOUNTAIN_CAR_PARAM;
135 
138  using action_type = Action;
139  using reward_type = double;
140 
141  private:
142 
143  phase_type current_state;
144  double r;
145 
146  public:
147 
148  // This can be usefull for drawing graphics.
149  void location(double& position,
150  double& speed,
151  double& height) {
152  position = current_state.position;
153  speed = current_state.speed;
154  height = heightOf(position);
155  }
156 
157  static double heightOf(double position) {
158  return sin(3*position);
159  }
160 
161  // The bottom position
162  static double bottom(void) {
163  return - M_PI/6;
164  }
165 
166  void setPhase(const phase_type& s) {
167  current_state = s;
168  current_state.check();
169  }
170 
171  const observation_type& sense(void) const {
172  current_state.check();
173  return current_state;
174  }
175 
176  void timeStep(const action_type& a) {
177  double aa;
178 
179  switch(a) {
180  case Action::actionForward:
181  aa = 1;
182  break;
183  case Action::actionBackward:
184  aa = -1;
185  break;
186  case Action::actionNone:
187  aa = 0;
188  break;
189  default:
190  std::ostringstream ostr;
191  ostr << "mountain_car::Simulator::timeStep(" << static_cast<int>(a) << ")";
192  throw BadAction(ostr.str());
193  }
194 
195  current_state.speed += (0.001*aa - 0.0025*cos(3*current_state.position));
196  current_state.saturateSpeed();
197  current_state.position += current_state.speed;
198 
199  r=param_type::rewardStep();
200  if(current_state.position < param_type::minPosition()) {
201  current_state.position = param_type::minPosition();
202  current_state.speed = 0;
203  }
204  else if(current_state.position > param_type::maxPosition()) {
205 
206  if((current_state.speed >= param_type::goalSpeed())
207  &&
208  (current_state.speed <= param_type::goalSpeed() + param_type::goalSpeedMargin())) {
209  r = param_type::rewardGoal();
210  throw rl::exception::Terminal("Goal reached");
211  }
212 
213  throw rl::exception::Terminal("Upper position bound violated");
214  }
215 
216  }
217 
218  reward_type reward(void) const {
219  return r;
220  }
221 
222  Simulator(void) : current_state(), r(0) {}
223  Simulator(const Simulator& copy)
224  : current_state(copy.current_state),
225  r(copy.r) {}
226  ~Simulator(void) {}
227 
228  Simulator& operator=(const Simulator& copy) {
229  if(this != &copy)
230  current_state = copy.current_state;
231  return *this;
232  }
233  };
234 
239  template<typename SIMULATOR>
240  class Gnuplot {
241  private:
242 
243 
244  template<typename Q, typename POLICY>
245  static void Qdata(std::ostream& file,
246  const Q& q,
247  const POLICY& policy,
248  int points_per_side,
249  bool draw_q) {
250 
251  double coef_p,coef_s;
252  double position,speed;
253  int p,s;
254  Action a;
255  typename SIMULATOR::phase_type current;
256  coef_p = (SIMULATOR::param_type::maxPosition()-SIMULATOR::param_type::minPosition())/((double)(points_per_side-1));
257  coef_s = (SIMULATOR::param_type::maxSpeed()-SIMULATOR::param_type::minSpeed())/((double)(points_per_side-1));
258  for(s=0;s<points_per_side;++s) {
259  speed = SIMULATOR::param_type::minSpeed() + coef_s*s;
260  for(p=0;p<points_per_side;++p) {
261  position = SIMULATOR::param_type::minPosition() + coef_p*p;
262  current = typename SIMULATOR::phase_type(position,speed);
263  a = policy(current);
264  if(draw_q)
265  file << position << ' ' << speed << ' ' << q(current,a) << ' ' << static_cast<int>(a) << std::endl;
266  else
267  file << position << ' ' << speed << ' ' << static_cast<int>(a) << std::endl;
268  }
269  file << std::endl;
270  }
271  }
272 
273  public:
274 
275  template<typename Q, typename POLICY>
276  static void drawQ(std::string title,
277  std::string file_prefix, int rank,
278  const Q& q,
279  const POLICY& policy,
280  int points_per_side=50) {
281  std::ostringstream ostr;
282  std::ofstream file;
283  std::string numbered_prefix;
284  std::string filename;
285 
286  ostr << file_prefix;
287  if(rank >=0)
288  ostr << '-' << std::setfill('0') << std::setw(6) << rank;
289  numbered_prefix = ostr.str();
290  filename = numbered_prefix + ".plot";
291 
292  file.open(filename.c_str());
293  if(!file) {
294  std::cerr << "Cannot open \"" << filename << "\". Plotting skipped." << std::endl;
295  return;
296  }
297  file << "unset hidden3d;" << std::endl
298  << "set xrange [" << SIMULATOR::param_type::minPosition()
299  << ":" << SIMULATOR::param_type::maxPosition() << "];" << std::endl
300  << "set yrange [" << SIMULATOR::param_type::minSpeed()
301  << ":" << SIMULATOR::param_type::maxSpeed() << "];" << std::endl
302  << "set zrange [-1:1.5];" << std::endl
303  << "set cbrange [0:2];" << std::endl
304  << "set view 48,336;" << std::endl
305  << "set palette defined ( 0 \"yellow\", 1 \"red\",2 \"blue\");" << std::endl
306  << "set ticslevel 0;" << std::endl
307  << "set title \"" << title << "\";" << std::endl
308  << "set xlabel \"position\";" << std::endl
309  << "set ylabel \"speed\";" << std::endl
310  << "set zlabel \"Q(max_a)\";" << std::endl
311  << "set cblabel \"none=" << static_cast<int>(Action::actionNone)
312  << ", forward=" << static_cast<int>(Action::actionForward)
313  << ", backward=" << static_cast<int>(Action::actionBackward)
314  << "\";" << std::endl
315  << "set style line 100 linecolor rgb \"black\";" << std::endl
316  << "set pm3d at s hidden3d 100;" << std::endl
317  << "set output \"" << numbered_prefix << ".png\";" << std::endl
318  << "set term png enhanced size 600,400;"<< std::endl
319  << "splot '-' using 1:2:3:4 with pm3d notitle;" << std::endl;
320 
321 
322 
323  Qdata(file,q,policy,points_per_side,true);
324  file.close();
325  std::cout << "\"" << filename << "\" generated." << std::endl;
326  }
327 
331  template<typename Q,typename POLICY>
332  static void drawEpisode(std::string title,
333  std::string file_prefix, int rank,
334  SIMULATOR& simulator,
335  const Q& q,
336  const POLICY& policy,
337  unsigned int max_episode_length,
338  int points_per_side=50) {
339  std::ostringstream ostr;
340  std::ostringstream titleostr;
341  std::ofstream file;
342  std::string numbered_prefix;
343  std::string filename,policyfilename;
344  double cumrew;
345 
346 
347 
348  ostr << file_prefix;
349  if(rank >=0)
350  ostr << '-' << std::setfill('0') << std::setw(6) << rank;
351  numbered_prefix = ostr.str();
352  filename = numbered_prefix + ".plot";
353  policyfilename = numbered_prefix + "-policy.data";
354 
355  file.open(filename.c_str());
356  if(!file) {
357  std::cerr << "Cannot open \"" << filename << "\". Plotting skipped." << std::endl;
358  return;
359  }
360 
361  std::vector<std::pair<typename SIMULATOR::phase_type,typename SIMULATOR::reward_type>> transitions;
362  rl::episode::run(simulator,policy,
363  std::back_inserter(transitions),
364  [](const typename SIMULATOR::phase_type& s,
365  const typename SIMULATOR::action_type& a,
366  const typename SIMULATOR::reward_type r,
367  const typename SIMULATOR::phase_type& s_)
368  -> std::pair<typename SIMULATOR::phase_type,typename SIMULATOR::reward_type> {return std::make_pair(s,r);},
369  [](const typename SIMULATOR::phase_type& s,
370  const typename SIMULATOR::action_type& a,
371  const typename SIMULATOR::reward_type r)
372  -> std::pair<typename SIMULATOR::phase_type,typename SIMULATOR::reward_type> {return std::make_pair(s,r);},
373  max_episode_length);
374 
375  cumrew=0;
376  for(auto& t : transitions)
377  cumrew += t.second;
378 
379  titleostr << title << "\\n cumulated reward = " << cumrew;
380 
381  file << "set xrange [" << SIMULATOR::param_type::minPosition()
382  << ":" << SIMULATOR::param_type::maxPosition() << "];" << std::endl
383  << "set yrange [" << SIMULATOR::param_type::minSpeed()
384  << ":" << SIMULATOR::param_type::maxSpeed() << "];" << std::endl
385  << "set zrange [0:3];" << std::endl
386  << "set cbrange [0:3];" << std::endl
387  << "set title \"" << titleostr.str() << "\";" << std::endl
388  << "set palette defined ( 0 \"yellow\", 1 \"red\",2 \"blue\", 3 \"black\");" << std::endl
389  << "set xlabel \"position\";" << std::endl
390  << "set ylabel \"speed\";" << std::endl
391  << "set cblabel \"none=" << static_cast<int>(Action::actionNone)
392  << ", forward=" << static_cast<int>(Action::actionForward)
393  << ", backward=" << static_cast<int>(Action::actionBackward)
394  << "\";" << std::endl
395  << "set view map;" << std::endl
396  << "set pm3d at s;" << std::endl
397  << "splot '" << policyfilename << "' with pm3d notitle, \\" << std::endl
398  << " '-' with linespoints notitle pt 7 ps 0.5 lc rgb \"black\"" << std::endl;
399 
400  for(auto& t : transitions)
401  file << t.first.position << ' '
402  << t.first.speed << ' '
403  << 3 << std::endl;
404  file.close();
405  std::cout << "\"" << filename << "\" generated." << std::endl;
406 
407 
408  file.open(policyfilename.c_str());
409  if(!file) {
410  std::cerr << "Cannot open \"" << filename << "\". Plotting skipped." << std::endl;
411  return;
412  }
413  Qdata(file,q,policy,points_per_side,false);
414  file.close();
415  std::cout << "\"" << policyfilename << "\" generated." << std::endl;
416  }
417  };
418  }
419  }
420 }
rl::exception::Any
Definition: rlException.hpp:35
rl::problem::mountain_car::BadAction
Definition: rl-mountain-car.hpp:55
rl::problem::mountain_car::Phase
Definition: rl-mountain-car.hpp:84
rl::problem::mountain_car::Gnuplot
This plots nice graphics for representing the Q function.
Definition: rl-mountain-car.hpp:240
rl::problem::mountain_car::DefaultParam
Definition: rl-mountain-car.hpp:69
rl::exception::Terminal
Terminal state.
Definition: rlException.hpp:56
rl::problem::mountain_car::BadState
Definition: rl-mountain-car.hpp:61
rl::problem::mountain_car::Simulator
Definition: rl-mountain-car.hpp:130
rl::problem::mountain_car::Gnuplot::drawEpisode
static void drawEpisode(std::string title, std::string file_prefix, int rank, SIMULATOR &simulator, const Q &q, const POLICY &policy, unsigned int max_episode_length, int points_per_side=50)
Definition: rl-mountain-car.hpp:332