core_rust/core/agent/
mcts.rs

use std::cell::RefCell;
use std::rc::Rc;

use crate::core::agent::node::{MCTSError, MCTSNode, NodeRef};
use crate::core::game::*;
use crate::core::simulation::simulator::{Simulator, SimulatorType};
use rand::prelude::*;
use state::State;
use utils::actions::Action;

/// Simple Monte Carlo Tree Search agent parameterized by a `Simulator`.
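///
/// A minimal usage sketch; the concrete `Simulator` implementation and the
/// `State` constructor below are illustrative placeholders, not names defined
/// in this module:
///
/// ```ignore
/// // Hypothetical simulator and state; substitute the real types from this crate.
/// let simulator = GroundSimulator::default();
/// let initial_state = State::default();
///
/// // 500 iterations, rollouts capped at depth 20, exploration constant 1.4, discount 0.99.
/// let mut agent = MCTSAgent::new(500, 20, 1.4, 0.99, Some(42), simulator);
/// let action = agent.run(initial_state, false, false);
/// ```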
#[derive(Debug, Clone)]
pub struct MCTSAgent<S: Simulator> {
    simulation_limit: i32,
    simulation_depth: i32,
    c: f32,
    gamma: f32,
    index_count: i32,
    rng: StdRng,
    simulator: S,
}

impl<S: Simulator> MCTSAgent<S> {
    /// Create a new agent with the given search budget and parameters.
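    ///
    /// * `simulation_limit` — number of search iterations per call to `run`.
    /// * `simulation_depth` — maximum combined tree depth plus rollout length.
    /// * `c` — exploration constant passed to each `MCTSNode` for its selection rule.
    /// * `gamma` — discount factor applied to rollout rewards.
    /// * `seed` — optional RNG seed; `None` seeds from OS entropy.
    /// * `simulator` — forward model used for expansion and rollouts.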
    pub fn new(
        simulation_limit: i32,
        simulation_depth: i32,
        c: f32,
        gamma: f32,
        seed: Option<u64>,
        simulator: S,
    ) -> Self {
        let rng = match seed {
            Some(value) => StdRng::seed_from_u64(value),
            None => StdRng::from_os_rng(),
        };

        MCTSAgent {
            simulation_limit,
            simulation_depth,
            c,
            gamma,
            index_count: 0,
            rng,
            simulator,
        }
    }

    /// Pick a uniformly random valid action; panics if `state` has no valid moves.
    fn choose_random_action(&mut self, state: &State) -> Action {
        let valid_actions = state.valid_moves();
        let idx = self.rng.random_range(0..valid_actions.len());
        valid_actions[idx]
    }

    /// Perform a default-policy rollout from `start_node` for at most `remaining_depth` steps.
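    ///
    /// The rollout picks random valid actions and accumulates the discounted
    /// return `gamma^(t-1) * r_t` summed over steps `t = 1..=remaining_depth`,
    /// stopping early when the simulator reports the episode as done.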
    fn rollout(&mut self, start_node: &NodeRef, remaining_depth: i32, debug: bool) -> f32 {
        if debug {
            println!(
                "Rollout node {:?} with budget {:?}",
                start_node.borrow(),
                remaining_depth
            )
        };

        if start_node.borrow().is_terminal() || remaining_depth <= 0 {
            if debug {
                println!("ROLLOUT ABORTED: Terminal node or no budget")
            }
            return 0.0;
        }

        // Chain simulated states so each step continues from the previous one.
        let mut current_state = start_node.borrow().get_state();
        let mut total_reward: f32 = 0.0;

        for depth in 1..=remaining_depth {
            let action = self.choose_random_action(&current_state);
            let (next_state, game_vars) = self.simulator.simulate(&current_state, action);
            let reward = game_vars.score;

            total_reward += self.gamma.powi(depth - 1) * reward;

            if debug {
                println!("Rollout step {:?}: {:?} -> {:?}", depth, action, next_state);
                println!(
                    "Rollout step {:?}: R = {:?}, tot_reward += {:?}^{:?} * {:?}",
                    depth,
                    game_vars.score,
                    self.gamma,
                    depth - 1,
                    reward
                );
                println!(
                    "Rollout step {:?}: Total reward = {:?}",
                    depth, total_reward
                );
            }

            if game_vars.done {
                if debug {
                    println!(
                        "Rollout terminated at step {:?} with total_reward = {:?}",
                        depth, total_reward
                    )
                }
                return total_reward;
            }

            current_state = next_state;
        }

        total_reward
    }

    /// Expand one untried action from `node` and return the new child.
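    ///
    /// The child's state and immediate reward come from a single call to
    /// `Simulator::simulate`; the new node is linked under `node` and receives
    /// the next value of `index_count` as its identifier.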
    fn expansion(&mut self, node: &NodeRef, debug: bool) -> Result<NodeRef, MCTSError> {
        let action = {
            let mut n = node.borrow_mut();
            n.expand(&mut self.rng)
        };
        let (state_before, depth) = {
            let n = node.borrow();
            (n.get_state(), n.get_depth() + 1)
        };

        if debug {
            println!(
                "Expanding node {:?} with action {:?}",
                node.borrow(),
                action
            )
        };

        let (new_state, game_vars) = self.simulator.simulate(&state_before, action);

        let idx = self.index_count;
        let child = MCTSNode::new(
            new_state,
            Some(Rc::downgrade(node)),
            action,
            depth,
            game_vars.score,
            game_vars.done,
            idx,
            self.c,
            self.gamma,
        );
        let child_rc = Rc::new(RefCell::new(child));

        node.borrow_mut().add_child(Rc::clone(&child_rc));
        self.index_count += 1;

        Ok(child_rc)
    }

    /// Run a full MCTS search (`simulation_limit` iterations) from `state` and
    /// return the best action at the root.
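    ///
    /// Each iteration selects a leaf with `MCTSNode::find_leaf_node`, expands one
    /// child (while depth budget remains and the leaf is non-terminal), rolls it
    /// out with the remaining depth, and backpropagates the discounted return.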
    pub fn run(&mut self, state: State, debug: bool, show_mcts: bool) -> Action {
        let initial_state = match self.simulator.simulator_type() {
            SimulatorType::Ground => {
                if debug {
                    println!("Running MCTS in ground")
                }
                state
            }
            SimulatorType::Abstract => {
                if debug {
                    println!("Running MCTS in abstract")
                }
                self.simulator.get_initial_state(state)
            }
        };

        let root = Rc::new(RefCell::new(MCTSNode::new(
            initial_state,
            None,
            Action::Root,
            0,
            0.0,
            false,
            self.index_count,
            self.c,
            self.gamma,
        )));
        // Reserve the root's index; `expansion` assigns and bumps the counter for each child.
        self.index_count += 1;

        let limit = self.simulation_limit as usize;
        for _ in 0..limit {
            let leaf = MCTSNode::find_leaf_node(&root);
            if debug {
                println!("selected node: {:?}", leaf.borrow())
            }

            {
                let depth = leaf.borrow().get_depth();
                let remaining = self.simulation_depth - depth;
                if remaining > 0 && !leaf.borrow().is_terminal() {
                    let child = self.expansion(&leaf, debug).expect("expand should succeed");

                    let value = self.rollout(&child, remaining - 1, debug);

                    MCTSNode::backpropagate(&child, value);
                }
            }
        }

        if debug || show_mcts {
            MCTSNode::print_tree(&root);
            println!("{}", "-".repeat(40));
        }

        root.borrow().best_action()
    }
}