// core_rust/core/runner.rs

1use crate::core::agent::mcts::MCTSAgent;
2use crate::core::*;
3use game::game_logic::Game;
4use ordered_float::Pow;
5use simulation::mapper::Mapper;
6use simulation::simulator::{AbstractSim, GroundSim, Simulator};
7
8/// Execute MCTS episodes in either the ground or abstract MDP.
9pub struct Runner {
10    game: Game,
11    mapper: Option<Mapper>,
12}
13
14impl Runner {
15    /// Construct a new `Runner`.
16    ///
17    /// If `abstracted` is true, the runner builds a `Mapper` using either the
18    /// provided `abstraction` or a learned abstraction; otherwise the runner
19    /// interacts with the ground MDP directly.
20    pub fn new(game: &Game, abstracted: bool, abstraction: Option<Vec<Vec<isize>>>) -> Self {
21        let mapper = if abstracted {
22            Some(Mapper::new(game, abstraction).unwrap())
23        } else {
24            None
25        };
26
27        Runner {
28            game: game.clone(),
29            mapper,
30        }
31    }
32
33    /// Compute discounted return for a terminal trajectory of the given length.
34    fn compute_discounted_returns(gamma: f32, turns_taken: i32) -> f32 {
35        gamma.pow(turns_taken)
36    }
37
38    /// Run `runs` independent episodes and return per-run (turns, score, final_pos).
39    pub fn run(
40        &mut self,
41        sim_limit: i32,
42        sim_depth: i32,
43        c: f32,
44        gamma: f32,
45        seed: Option<u64>,
46        max_turns: i32,
47        runs: usize,
48        debug: bool,
49        show_mcts: bool,
50    ) -> Vec<(i32, f32, (usize, usize))> {
51        let mut results = Vec::with_capacity(runs);
52
53        for run_idx in 0..runs {
54            // Pick the appropriate simulator once for this run
55            let simulator: Box<dyn Simulator> = if let Some(mapper) = &self.mapper {
56                Box::new(AbstractSim::new(&self.game, mapper.clone()))
57            } else {
58                Box::new(GroundSim::new(&self.game))
59            };
60
61            let mut agent = MCTSAgent::new(sim_limit, sim_depth, c, gamma, seed, simulator);
62
63            // Track one local state through the turns
64            let mut current_state = self.game.get_state();
65            let mut game_done = false;
66
67            for turn in 1..=max_turns {
68                if debug {
69                    println!(
70                        "[Run {} Turn {}] State unit position = {:?}",
71                        run_idx, turn, current_state.unit_position
72                    );
73                }
74
75                // MCTS returns an Action (abstract or ground)
76                let action = if let Some(mapper) = &self.mapper {
77                    let abs_action = agent.run(current_state.clone(), debug, show_mcts);
78                    let abs_state = mapper.ground_state_to_abstract(&current_state);
79                    let (_, ground_action) =
80                        mapper.abstract_state_action_to_ground(&abs_state, abs_action);
81                    ground_action
82                } else {
83                    agent.run(current_state.clone(), debug, show_mcts)
84                };
85
86                // Simulate that action via the same Simulator
87                let (next_state, game_vars) = self.game.step(&action).unwrap_or_else(|err| {
88                        panic!("Runner simulation error at run {}, turn: {}:\nState: {:?}\nAction: {:?}\nError: {:?}", run_idx, turn, current_state, action, err)
89                    });
90
91                if game_vars.done {
92                    let score = Self::compute_discounted_returns(gamma, turn);
93                    if debug {
94                        println!(
95                            "[Run {}] finished in {} turns, score = {}",
96                            run_idx, turn, score
97                        );
98                    }
99                    results.push((turn, score, next_state.unit_position));
100                    game_done = true;
101                    break;
102                }
103
104                current_state = next_state;
105            }
106
107            if !game_done {
108                // ran all the way to max_turns without finishing
109                if debug {
110                    println!(
111                        "[Run {}] hit turn limit = {}, score = 0",
112                        run_idx, max_turns
113                    );
114                }
115                results.push((max_turns, 0.0, current_state.unit_position));
116            }
117
118            // Reset the underlying world for the next run
119            self.game.reset();
120        }
121
122        results
123    }
124}
125
#[cfg(test)]
mod runner_tests {
    use super::*;
    use crate::core::game::game_logic::Game;

    /// Build a simple open 3×3 world with the goal ('G') in the bottom-right corner.
    fn make_game() -> Game {
        let world = vec![
            vec!['.', '.', '.'],
            vec!['.', '.', '.'],
            vec!['.', '.', 'G'],
        ];
        Game::new(world).unwrap()
    }

    #[test]
    fn test_runner_max_turns_zero() {
        let game = make_game();
        // Ground simulation, no abstraction.
        let mut runner = Runner::new(&game, false, None);

        // With max_turns = 0 no turn is ever played, so every one of the `runs`
        // entries should be (0 turns, score 0.0).
        let runs = 5;
        let out = runner.run(
            /*sim_limit=*/ 1, /*sim_depth=*/ 1, /*c=*/ 1.0, /*gamma=*/ 1.0,
            /*seed=*/ None, /*max_turns=*/ 0, /*runs=*/ runs, /*debug=*/ false,
            /*show_mcts=*/ false,
        );
        assert_eq!(out.len(), runs);
        assert!(out.iter().all(|&(t, s, _)| t == 0 && s == 0.0));
    }

    #[test]
    fn test_runner_one_turn_timeout() {
        let game = make_game();
        let mut runner = Runner::new(&game, false, None);

        // With max_turns = 1 the goal cannot be reached, so every run should
        // return (1, 0.0). NOTE(review): assumes the unit starts at least two
        // moves from the goal (e.g. the top-left corner); confirm against
        // `Game::new`.
        let runs = 3;
        let out = runner.run(1, 1, 1.0, 1.0, None, 1, runs, false, false);
        assert_eq!(out.len(), runs);
        assert!(out.iter().all(|&(t, s, _)| t == 1 && s == 0.0));
    }

    #[test]
    fn test_runner_identity_abstraction_equivalent_to_ground() {
        // With a generous simulation budget, both the ground planner and a
        // planner over the identity abstraction should solve this 3×3 world
        // optimally, so their outcomes must match exactly. A fixed seed keeps
        // the comparison reproducible instead of depending on unseeded MCTS
        // randomness.
        let seed = Some(42);
        let game = make_game();

        // Ground runner.
        let mut ground_runner = Runner::new(&game, false, None);
        let ground_out = ground_runner.run(128, 16, 1.4, 0.85, seed, 10, 1, false, false);

        // Abstract runner with the identity abstraction: each of the 3×3 = 9
        // cells is its own cluster.
        let identity_clusters: Vec<Vec<isize>> = (0..9).map(|i| vec![i]).collect();
        let mut abs_runner = Runner::new(&game, true, Some(identity_clusters));
        let abs_out = abs_runner.run(128, 16, 1.4, 0.85, seed, 10, 1, false, false);

        // The premise of the comparison: the ground planner actually reached the goal.
        assert!(
            ground_out.iter().all(|&(_, score, _)| score > 0.0),
            "ground planner did not reach the goal: {:?}",
            ground_out
        );
        // Both planners should produce the same (turns, score, final_pos).
        assert_eq!(ground_out, abs_out);
    }
}