1use crate::core::agent::mcts::MCTSAgent;
2use crate::core::*;
3use game::game_logic::Game;
4use ordered_float::Pow;
5use simulation::mapper::Mapper;
6use simulation::simulator::{AbstractSim, GroundSim, Simulator};
7
/// Plays repeated MCTS-driven episodes of a [`Game`] and collects their outcomes.
///
/// When `mapper` is `Some`, planning uses an abstract simulator and the
/// planned actions are mapped back to ground actions before being played.
pub struct Runner {
    /// The runner's own copy of the game (cloned in `new`); episodes are played on it.
    game: Game,
    /// State/action abstraction; `Some` only when the runner was built with `abstracted = true`.
    mapper: Option<Mapper>,
}
13
14impl Runner {
15 pub fn new(game: &Game, abstracted: bool, abstraction: Option<Vec<Vec<isize>>>) -> Self {
21 let mapper = if abstracted {
22 Some(Mapper::new(game, abstraction).unwrap())
23 } else {
24 None
25 };
26
27 Runner {
28 game: game.clone(),
29 mapper,
30 }
31 }
32
33 fn compute_discounted_returns(gamma: f32, turns_taken: i32) -> f32 {
35 gamma.pow(turns_taken)
36 }
37
38 pub fn run(
40 &mut self,
41 sim_limit: i32,
42 sim_depth: i32,
43 c: f32,
44 gamma: f32,
45 seed: Option<u64>,
46 max_turns: i32,
47 runs: usize,
48 debug: bool,
49 show_mcts: bool,
50 ) -> Vec<(i32, f32, (usize, usize))> {
51 let mut results = Vec::with_capacity(runs);
52
53 for run_idx in 0..runs {
54 let simulator: Box<dyn Simulator> = if let Some(mapper) = &self.mapper {
56 Box::new(AbstractSim::new(&self.game, mapper.clone()))
57 } else {
58 Box::new(GroundSim::new(&self.game))
59 };
60
61 let mut agent = MCTSAgent::new(sim_limit, sim_depth, c, gamma, seed, simulator);
62
63 let mut current_state = self.game.get_state();
65 let mut game_done = false;
66
67 for turn in 1..=max_turns {
68 if debug {
69 println!(
70 "[Run {} Turn {}] State unit position = {:?}",
71 run_idx, turn, current_state.unit_position
72 );
73 }
74
75 let action = if let Some(mapper) = &self.mapper {
77 let abs_action = agent.run(current_state.clone(), debug, show_mcts);
78 let abs_state = mapper.ground_state_to_abstract(¤t_state);
79 let (_, ground_action) =
80 mapper.abstract_state_action_to_ground(&abs_state, abs_action);
81 ground_action
82 } else {
83 agent.run(current_state.clone(), debug, show_mcts)
84 };
85
86 let (next_state, game_vars) = self.game.step(&action).unwrap_or_else(|err| {
88 panic!("Runner simulation error at run {}, turn: {}:\nState: {:?}\nAction: {:?}\nError: {:?}", run_idx, turn, current_state, action, err)
89 });
90
91 if game_vars.done {
92 let score = Self::compute_discounted_returns(gamma, turn);
93 if debug {
94 println!(
95 "[Run {}] finished in {} turns, score = {}",
96 run_idx, turn, score
97 );
98 }
99 results.push((turn, score, next_state.unit_position));
100 game_done = true;
101 break;
102 }
103
104 current_state = next_state;
105 }
106
107 if !game_done {
108 if debug {
110 println!(
111 "[Run {}] hit turn limit = {}, score = 0",
112 run_idx, max_turns
113 );
114 }
115 results.push((max_turns, 0.0, current_state.unit_position));
116 }
117
118 self.game.reset();
120 }
121
122 results
123 }
124}
125
#[cfg(test)]
mod runner_tests {
    use super::*;
    use crate::core::game::game_logic::Game;

    /// Open 3x3 grid with a single 'G' cell in the bottom-right corner.
    fn make_game() -> Game {
        let world = vec![
            vec!['.', '.', '.'],
            vec!['.', '.', '.'],
            vec!['.', '.', 'G'],
        ];
        Game::new(world).unwrap()
    }

    /// With a zero turn budget no step is taken: every run reports 0 turns and a 0 score.
    #[test]
    fn test_runner_max_turns_zero() {
        let game = make_game();
        let mut runner = Runner::new(&game, false, None);

        let runs = 5;
        let out = runner.run(1, 1, 1.0, 1.0, None, 0, runs, false, false);
        assert_eq!(out.len(), runs);
        assert!(out.iter().all(|&(t, s, _)| t == 0 && s == 0.0));
    }

    /// With a one-turn budget on this grid, every run hits the cap and scores 0.
    #[test]
    fn test_runner_one_turn_timeout() {
        let game = make_game();
        let mut runner = Runner::new(&game, false, None);

        let runs = 3;
        let out = runner.run(1, 1, 1.0, 1.0, None, 1, runs, false, false);
        assert_eq!(out.len(), runs);
        assert!(out.iter().all(|&(t, s, _)| t == 1 && s == 0.0));
    }

    /// An abstraction that puts every cell in its own cluster should behave
    /// exactly like planning on the ground game.
    #[test]
    fn test_runner_identity_abstraction_equivalent_to_ground() {
        let game = make_game();
        // Both runners get the same fixed seed so the exact-equality comparison
        // does not rest on unseeded randomness. This presumes MCTSAgent derives
        // its RNG from `seed`.
        let seed = Some(42);

        let mut ground_runner = Runner::new(&game, false, None);
        let ground_out = ground_runner.run(128, 16, 1.4, 0.85, seed, 10, 1, false, false);

        // One singleton cluster per cell of the 3x3 grid.
        let identity_clusters: Vec<Vec<isize>> = (0..9).map(|i| vec![i]).collect();
        let mut abs_runner = Runner::new(&game, true, Some(identity_clusters));
        let abs_out = abs_runner.run(128, 16, 1.4, 0.85, seed, 10, 1, false, false);

        assert_eq!(ground_out, abs_out);
    }
}