core_rust/
lib.rs

1//! Python bindings for the core GridWorld + MCTS abstraction crate.
2//!
3//! This crate exposes a minimal, typed interface to Python via pyo3. It allows
4//! you to:
5//! - Compute optimal discounted returns and shortest paths for a world.
6//! - Visualize worlds and learned abstractions.
7//! - Run MCTS either in the ground MDP or in an abstracted MDP.
8//! - Build transition/reward matrices and enumerate state counts.
9//!
10//! The Rust types and modules remain available for native Rust use under
11//! `crate::core`. The Python module installs under the name `core_rust`.
12#![allow(clippy::too_many_arguments, clippy::type_complexity, deprecated)]
13
14use ordered_float::Pow;
15use pyo3::exceptions::{PyRuntimeError, PyValueError};
16use pyo3::prelude::*;
17use pyo3::types::PyDict;
18pub mod core;
19use crate::core::abstraction::homomorphism::get_abstraction;
20use crate::core::game::game_logic::Game;
21use crate::core::runner::Runner;
22use crate::core::utils::min_turns::min_turns_to_finish;
23use crate::core::utils::plotting::{draw_abstraction, draw_world};
24use core::abstraction::homomorphism::get_all_states;
25use core::game::utils::actions::Action;
26use core::utils::matrices::build_matrices;
27use core::utils::representation::generate_representations;
28use std::path::Path;
29
30/// Convert a Python ``list[list[str]]`` map into a Rust grid of chars.
31fn py_to_world(py_world: Vec<Vec<String>>) -> PyResult<Vec<Vec<char>>> {
32    py_world
33        .into_iter()
34        .map(|row| {
35            row.into_iter()
36                .map(|s| {
37                    // if you want to allow only single-character strings:
38                    let mut chars = s.chars();
39                    if let (Some(ch), None) = (chars.next(), chars.next()) {
40                        Ok(ch)
41                    } else {
42                        Err(PyValueError::new_err(
43                            "Each string must be exactly one character",
44                        ))
45                    }
46                })
47                .collect()
48        })
49        .collect()
50}
51
52/// Python wrapper for the Rust `Runner`, capable of executing MCTS episodes.
53#[pyclass]
54pub struct PyRunner {
55    inner: Runner,
56}
57
58#[pymethods]
59impl PyRunner {
60    #[new]
61    /// Create a new `PyRunner`.
62    ///
63    /// Parameters
64    /// ----------
65    /// py_world
66    ///     2D map of single-character strings.
67    /// abstracted
68    ///     If true, run in the abstract MDP; otherwise run in the ground MDP.
69    /// py_abstraction
70    ///     Optional custom abstraction (clusters of ground-state IDs).
71    pub fn new(
72        py_world: Vec<Vec<String>>,
73        abstracted: bool,
74        py_abstraction: Option<Vec<Vec<isize>>>,
75    ) -> PyResult<Self> {
76        let world = py_to_world(py_world)?;
77        let game = Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
78        let runner = Runner::new(&game, abstracted, py_abstraction);
79        Ok(PyRunner { inner: runner })
80    }
81
82    /// Run `runs` episodes of MCTS and return per-episode results.
83    pub fn run(
84        &mut self,
85        sim_limit: i32,
86        sim_depth: i32,
87        c: f32,
88        gamma: f32,
89        seed: Option<u64>,
90        max_turns: i32,
91        runs: usize,
92        debug: bool,
93        show_mcts: bool,
94    ) -> PyResult<Vec<(i32, f32, (usize, usize))>> {
95        Ok(self.inner.run(
96            sim_limit, sim_depth, c, gamma, seed, max_turns, runs, debug, show_mcts,
97        ))
98    }
99}
100
101/// Compute the maximum achievable discounted return from the initial state.
102#[pyfunction]
103fn max_returns(py_world: Vec<Vec<String>>, gamma: f32) -> PyResult<f32> {
104    let world = py_to_world(py_world)?;
105    let game = Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
106
107    let min_turns =
108        min_turns_to_finish(&game).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))? as i32;
109
110    let max_returns = gamma.pow(min_turns);
111
112    Ok(max_returns)
113}
114
115/// Compute the minimum number of turns to reach the goal.
116#[pyfunction]
117fn min_turns(py_world: Vec<Vec<String>>) -> PyResult<usize> {
118    let world = py_to_world(py_world)?;
119    let game = Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
120    let min_turns =
121        min_turns_to_finish(&game).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
122
123    Ok(min_turns)
124}
125
126/// Save a rasterized visualization of the map to ``<output_dir>/map.png``.
127#[pyfunction]
128fn visualize_world_map(py_world: Vec<Vec<String>>, output_dir: &str) -> PyResult<()> {
129    let world = py_to_world(py_world)?;
130    let world_size = world.len() as u32;
131
132    // Render a 500×500 map, computing cell size from dimensions
133    let cell_size = 500 / world_size;
134
135    let dir = Path::new(output_dir);
136    let out_file = dir.join("map.png");
137    let out_path_str = out_file
138        .to_str()
139        .ok_or_else(|| PyRuntimeError::new_err("Invalid output path"))?;
140
141    draw_world(&world, out_path_str, cell_size)
142        .map_err(|_| PyRuntimeError::new_err("Plotting error"))?;
143    println!("Saved world visualization to: {:?}", out_path_str);
144
145    Ok(())
146}
147
148/// Save a rasterized visualization of the learned abstraction to
149/// ``<output_dir>/abstraction.png``.
150#[pyfunction]
151fn visualize_abstraction(py_world: Vec<Vec<String>>, output_dir: &str) -> PyResult<()> {
152    let world = py_to_world(py_world)?;
153    let world_size = world.len() as u32;
154
155    // Render a 500×500 abstraction, computing cell size from dimensions
156    let cell_size = 500 / world_size;
157
158    let dir = Path::new(output_dir);
159    let out_file = dir.join("abstraction.png");
160    let out_path_str = out_file
161        .to_str()
162        .ok_or_else(|| PyRuntimeError::new_err("Invalid output path"))?;
163
164    let game = Game::new(world.clone())
165        .map_err(|e| PyRuntimeError::new_err(format!("abstraction failed: {:?}", e)))?;
166    let (states, clusters) =
167        get_abstraction(&game).map_err(|_| PyRuntimeError::new_err("Failed to get abstraction"))?;
168
169    draw_abstraction(&world, &states, &clusters, out_path_str, cell_size)
170        .map_err(|_| PyRuntimeError::new_err("Plotting error"))?;
171    println!("Saved abstraction to: {:?}", out_path_str);
172
173    Ok(())
174}
175
176/// Generate multiple textual/JSON representations of the map for prompting.
177#[pyfunction]
178fn generate_representations_py(py: Python, py_world: Vec<Vec<String>>) -> PyResult<PyObject> {
179    // Reconstruct the Game
180    let world = py_world
181        .into_iter()
182        .map(|row| row.into_iter().map(|s| s.chars().next().unwrap()).collect())
183        .collect();
184    let mut game =
185        Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("invalid world: {:?}", e)))?;
186    // Generate representations
187    let (js, txt, adj) = generate_representations(&mut game);
188    // Convert JSON values to strings for Python
189    let json_str = serde_json::to_string(&js)
190        .map_err(|e| PyRuntimeError::new_err(format!("json serialization error: {}", e)))?;
191    let adj_str = serde_json::to_string(&adj)
192        .map_err(|e| PyRuntimeError::new_err(format!("adj serialization error: {}", e)))?;
193    // Build a Python dict
194    let dict = PyDict::new(py);
195    dict.set_item("json", json_str)?;
196    dict.set_item("text", txt)?;
197    dict.set_item("adj", adj_str)?;
198    // Return it as a PyObject
199    Ok(dict.into_py(py))
200}
201
202/// Build transition and reward matrices along with the learned abstraction.
203#[pyfunction]
204fn generate_mdp(py: Python<'_>, py_world: Vec<Vec<String>>) -> PyResult<PyObject> {
205    let world = py_to_world(py_world)?;
206
207    let game =
208        Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("invalid world: {:?}", e)))?;
209
210    let (states, clusters) = get_abstraction(&game)
211        .map_err(|e| PyRuntimeError::new_err(format!("abstraction failed: {:?}", e)))?;
212    let actions = [Action::Up, Action::Down, Action::Left, Action::Right];
213
214    let (t, r) = build_matrices(&game, &states, &actions);
215
216    let dict = PyDict::new(py);
217    dict.set_item("T", t.clone().into_py(py))?;
218    dict.set_item("R", r.clone().into_py(py))?;
219    dict.set_item("abstraction", clusters.clone().into_py(py))?;
220
221    Ok(dict.into_py(py))
222}
223
224/// Return the total number of reachable ground states in the map.
225#[pyfunction]
226fn get_number_of_states(py_world: Vec<Vec<String>>) -> PyResult<usize> {
227    let world = py_to_world(py_world)?;
228    let game =
229        Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("invalid world: {:?}", e)))?;
230
231    let all_states = get_all_states(&game)
232        .map_err(|e| PyRuntimeError::new_err(format!("abstraction failed: {:?}", e)))?;
233    let num_states = all_states.len();
234
235    Ok(num_states)
236}
237
238/// Python module initializer for `core_rust`.
239#[pymodule]
240fn core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
241    m.add_class::<PyRunner>()?;
242
243    m.add_function(wrap_pyfunction!(max_returns, m)?)?;
244
245    m.add_function(wrap_pyfunction!(min_turns, m)?)?;
246
247    m.add_function(wrap_pyfunction!(visualize_world_map, m)?)?;
248
249    m.add_function(wrap_pyfunction!(visualize_abstraction, m)?)?;
250
251    m.add_function(wrap_pyfunction!(generate_representations_py, m)?)?;
252
253    m.add_function(wrap_pyfunction!(generate_mdp, m)?)?;
254
255    m.add_function(wrap_pyfunction!(get_number_of_states, m)?)?;
256
257    Ok(())
258}