// core_rust/core/agent/mcts.rs
use std::cell::RefCell;
use std::rc::Rc;

use crate::core::agent::node::{MCTSError, MCTSNode, NodeRef};
use crate::core::game::*;
use crate::core::simulation::simulator::{Simulator, SimulatorType};
use ordered_float::Pow;
use rand::prelude::*;
use rand::{Rng, SeedableRng};
use state::State;
use utils::actions::Action;

13#[derive(Debug, Clone)]
15pub struct MCTSAgent<S: Simulator> {
16 simulation_limit: i32,
17 simulation_depth: i32,
18 c: f32,
19 gamma: f32,
20 index_count: i32,
21 rng: StdRng,
22 simulator: S,
23}
24
25impl<S: Simulator> MCTSAgent<S> {
26 pub fn new(
28 simulation_limit: i32,
29 simulation_depth: i32,
30 c: f32,
31 gamma: f32,
32 seed: Option<u64>,
33 simulator: S,
34 ) -> Self {
35 let rng = match seed {
36 Some(value) => StdRng::seed_from_u64(value),
37 None => StdRng::from_os_rng(),
38 };
39
40 MCTSAgent {
41 simulation_limit,
42 simulation_depth,
43 c,
44 gamma,
45 index_count: 0,
46 rng,
47 simulator,
48 }
49 }
50
51 fn choose_random_action(&mut self, state: &State) -> Action {
52 let valid_actions = state.valid_moves();
53 let idx = self.rng.random_range(0..valid_actions.len());
54 valid_actions[idx]
55 }
56
57 fn rollout(&mut self, start_node: &NodeRef, remaining_depth: i32, debug: bool) -> f32 {
59 if debug {
60 println!(
61 "Rollout node {:?} with budget {:?}",
62 start_node.borrow(),
63 remaining_depth
64 )
65 };
66
67 if start_node.borrow().is_terminal() || remaining_depth <= 0 {
68 if debug {
69 println!("ROLLOUT ABORTED: Terminal node or no budget")
70 }
71 return 0.0;
72 }
73
74 let node_state = start_node.borrow().get_state();
75 let mut total_reward: f32 = 0.0;
76
77 for depth in 1..=remaining_depth {
78 let action = self.choose_random_action(&node_state);
79 let (state, game_vars) = self.simulator.simulate(&node_state, action);
80 let reward = game_vars.score;
81
82 total_reward += (self.gamma.pow(depth - 1)) * reward;
83
84 if debug {
85 println!("Rollout step {:?}: {:?} -> {:?}", depth, action, state);
86 println!(
87 "Rollout step {:?}: R = {:?}, tot_reward += {:?}^{:?} * {:?}",
88 depth,
89 game_vars.score,
90 self.gamma,
91 depth - 1,
92 reward
93 );
94 println!(
95 "Rollout step {:?}: Total reward = {:?}",
96 depth, total_reward
97 );
98 }
99
100 if game_vars.done {
101 if debug {
102 println!(
103 "Rollout terminated at step {:?} with total_reward = {:?}",
104 depth, total_reward
105 )
106 }
107 return total_reward;
108 }
109 }
110
111 total_reward
112 }
113
114 fn expansion(&mut self, node: &NodeRef, debug: bool) -> Result<NodeRef, MCTSError> {
116 let action = {
117 let mut n = node.borrow_mut();
118 n.expand(&mut self.rng)
119 };
120 let (state_before, depth) = {
121 let n = node.borrow();
122 (n.get_state(), n.get_depth() + 1)
123 };
124
125 if debug {
126 println!(
127 "Expanding node {:?} with action {:?}",
128 node.borrow(),
129 action
130 )
131 };
132
133 let (new_state, game_vars) = self.simulator.simulate(&state_before, action);
134
135 let idx = self.index_count;
136 let child = MCTSNode::new(
137 new_state,
138 Some(Rc::downgrade(node)),
139 action,
140 depth,
141 game_vars.score,
142 game_vars.done,
143 idx,
144 self.c,
145 self.gamma,
146 );
147 let child_rc = Rc::new(RefCell::new(child));
148
149 node.borrow_mut().add_child(Rc::clone(&child_rc));
150 self.index_count += 1;
151
152 Ok(child_rc)
153 }
154
155 pub fn run(&mut self, state: State, debug: bool, show_mcts: bool) -> Action {
157 let initial_state = match self.simulator.simulator_type() {
158 SimulatorType::Ground => {
159 if debug {
160 println!("Running MCTS in ground")
161 }
162 state
163 }
164 SimulatorType::Abstract => {
165 if debug {
166 println!("Running MCTS in abstract")
167 }
168 self.simulator.get_initial_state(state)
169 }
170 };
171
172 let root = Rc::new(RefCell::new(MCTSNode::new(
173 initial_state,
174 None,
175 Action::Root,
176 0,
177 0.0,
178 false,
179 self.index_count,
180 self.c,
181 self.gamma,
182 )));
183
184 let limit = self.simulation_limit as usize;
185 for _ in 0..limit {
186 let leaf = MCTSNode::find_leaf_node(&root);
187 if debug {
188 println!("selected node: {:?}", leaf.borrow())
189 }
190
191 {
192 let depth = leaf.borrow().get_depth();
193 let remaining = self.simulation_depth - depth;
194 if remaining > 0 && !leaf.borrow().is_terminal() {
195 self.index_count += 1;
196
197 let child = self.expansion(&leaf, debug).expect("expand should succeed");
198
199 let value = self.rollout(&child, remaining - 1, debug);
200
201 MCTSNode::backpropagate(&child, value);
202 }
203 }
204 }
205
206 if debug || show_mcts {
207 MCTSNode::print_tree(&root);
208 println!("{}", "-".repeat(40));
209 }
210
211 let best_action = root.borrow().best_action();
212 best_action
213 }
214}