1#![allow(clippy::too_many_arguments, clippy::type_complexity, deprecated)]
13
14use ordered_float::Pow;
15use pyo3::exceptions::{PyRuntimeError, PyValueError};
16use pyo3::prelude::*;
17use pyo3::types::PyDict;
18pub mod core;
19use crate::core::abstraction::homomorphism::get_abstraction;
20use crate::core::game::game_logic::Game;
21use crate::core::runner::Runner;
22use crate::core::utils::min_turns::min_turns_to_finish;
23use crate::core::utils::plotting::{draw_abstraction, draw_world};
24use core::abstraction::homomorphism::get_all_states;
25use core::game::utils::actions::Action;
26use core::utils::matrices::build_matrices;
27use core::utils::representation::generate_representations;
28use std::path::Path;
29
30fn py_to_world(py_world: Vec<Vec<String>>) -> PyResult<Vec<Vec<char>>> {
32 py_world
33 .into_iter()
34 .map(|row| {
35 row.into_iter()
36 .map(|s| {
37 let mut chars = s.chars();
39 if let (Some(ch), None) = (chars.next(), chars.next()) {
40 Ok(ch)
41 } else {
42 Err(PyValueError::new_err(
43 "Each string must be exactly one character",
44 ))
45 }
46 })
47 .collect()
48 })
49 .collect()
50}
51
52#[pyclass]
54pub struct PyRunner {
55 inner: Runner,
56}
57
58#[pymethods]
59impl PyRunner {
60 #[new]
61 pub fn new(
72 py_world: Vec<Vec<String>>,
73 abstracted: bool,
74 py_abstraction: Option<Vec<Vec<isize>>>,
75 ) -> PyResult<Self> {
76 let world = py_to_world(py_world)?;
77 let game = Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
78 let runner = Runner::new(&game, abstracted, py_abstraction);
79 Ok(PyRunner { inner: runner })
80 }
81
82 pub fn run(
84 &mut self,
85 sim_limit: i32,
86 sim_depth: i32,
87 c: f32,
88 gamma: f32,
89 seed: Option<u64>,
90 max_turns: i32,
91 runs: usize,
92 debug: bool,
93 show_mcts: bool,
94 ) -> PyResult<Vec<(i32, f32, (usize, usize))>> {
95 Ok(self.inner.run(
96 sim_limit, sim_depth, c, gamma, seed, max_turns, runs, debug, show_mcts,
97 ))
98 }
99}
100
101#[pyfunction]
103fn max_returns(py_world: Vec<Vec<String>>, gamma: f32) -> PyResult<f32> {
104 let world = py_to_world(py_world)?;
105 let game = Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
106
107 let min_turns =
108 min_turns_to_finish(&game).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))? as i32;
109
110 let max_returns = gamma.pow(min_turns);
111
112 Ok(max_returns)
113}
114
115#[pyfunction]
117fn min_turns(py_world: Vec<Vec<String>>) -> PyResult<usize> {
118 let world = py_to_world(py_world)?;
119 let game = Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
120 let min_turns =
121 min_turns_to_finish(&game).map_err(|e| PyRuntimeError::new_err(format!("{:?}", e)))?;
122
123 Ok(min_turns)
124}
125
126#[pyfunction]
128fn visualize_world_map(py_world: Vec<Vec<String>>, output_dir: &str) -> PyResult<()> {
129 let world = py_to_world(py_world)?;
130 let world_size = world.len() as u32;
131
132 let cell_size = 500 / world_size;
134
135 let dir = Path::new(output_dir);
136 let out_file = dir.join("map.png");
137 let out_path_str = out_file
138 .to_str()
139 .ok_or_else(|| PyRuntimeError::new_err("Invalid output path"))?;
140
141 draw_world(&world, out_path_str, cell_size)
142 .map_err(|_| PyRuntimeError::new_err("Plotting error"))?;
143 println!("Saved world visualization to: {:?}", out_path_str);
144
145 Ok(())
146}
147
148#[pyfunction]
151fn visualize_abstraction(py_world: Vec<Vec<String>>, output_dir: &str) -> PyResult<()> {
152 let world = py_to_world(py_world)?;
153 let world_size = world.len() as u32;
154
155 let cell_size = 500 / world_size;
157
158 let dir = Path::new(output_dir);
159 let out_file = dir.join("abstraction.png");
160 let out_path_str = out_file
161 .to_str()
162 .ok_or_else(|| PyRuntimeError::new_err("Invalid output path"))?;
163
164 let game = Game::new(world.clone())
165 .map_err(|e| PyRuntimeError::new_err(format!("abstraction failed: {:?}", e)))?;
166 let (states, clusters) =
167 get_abstraction(&game).map_err(|_| PyRuntimeError::new_err("Failed to get abstraction"))?;
168
169 draw_abstraction(&world, &states, &clusters, out_path_str, cell_size)
170 .map_err(|_| PyRuntimeError::new_err("Plotting error"))?;
171 println!("Saved abstraction to: {:?}", out_path_str);
172
173 Ok(())
174}
175
176#[pyfunction]
178fn generate_representations_py(py: Python, py_world: Vec<Vec<String>>) -> PyResult<PyObject> {
179 let world = py_world
181 .into_iter()
182 .map(|row| row.into_iter().map(|s| s.chars().next().unwrap()).collect())
183 .collect();
184 let mut game =
185 Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("invalid world: {:?}", e)))?;
186 let (js, txt, adj) = generate_representations(&mut game);
188 let json_str = serde_json::to_string(&js)
190 .map_err(|e| PyRuntimeError::new_err(format!("json serialization error: {}", e)))?;
191 let adj_str = serde_json::to_string(&adj)
192 .map_err(|e| PyRuntimeError::new_err(format!("adj serialization error: {}", e)))?;
193 let dict = PyDict::new(py);
195 dict.set_item("json", json_str)?;
196 dict.set_item("text", txt)?;
197 dict.set_item("adj", adj_str)?;
198 Ok(dict.into_py(py))
200}
201
202#[pyfunction]
204fn generate_mdp(py: Python<'_>, py_world: Vec<Vec<String>>) -> PyResult<PyObject> {
205 let world = py_to_world(py_world)?;
206
207 let game =
208 Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("invalid world: {:?}", e)))?;
209
210 let (states, clusters) = get_abstraction(&game)
211 .map_err(|e| PyRuntimeError::new_err(format!("abstraction failed: {:?}", e)))?;
212 let actions = [Action::Up, Action::Down, Action::Left, Action::Right];
213
214 let (t, r) = build_matrices(&game, &states, &actions);
215
216 let dict = PyDict::new(py);
217 dict.set_item("T", t.clone().into_py(py))?;
218 dict.set_item("R", r.clone().into_py(py))?;
219 dict.set_item("abstraction", clusters.clone().into_py(py))?;
220
221 Ok(dict.into_py(py))
222}
223
224#[pyfunction]
226fn get_number_of_states(py_world: Vec<Vec<String>>) -> PyResult<usize> {
227 let world = py_to_world(py_world)?;
228 let game =
229 Game::new(world).map_err(|e| PyRuntimeError::new_err(format!("invalid world: {:?}", e)))?;
230
231 let all_states = get_all_states(&game)
232 .map_err(|e| PyRuntimeError::new_err(format!("abstraction failed: {:?}", e)))?;
233 let num_states = all_states.len();
234
235 Ok(num_states)
236}
237
238#[pymodule]
240fn core_rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
241 m.add_class::<PyRunner>()?;
242
243 m.add_function(wrap_pyfunction!(max_returns, m)?)?;
244
245 m.add_function(wrap_pyfunction!(min_turns, m)?)?;
246
247 m.add_function(wrap_pyfunction!(visualize_world_map, m)?)?;
248
249 m.add_function(wrap_pyfunction!(visualize_abstraction, m)?)?;
250
251 m.add_function(wrap_pyfunction!(generate_representations_py, m)?)?;
252
253 m.add_function(wrap_pyfunction!(generate_mdp, m)?)?;
254
255 m.add_function(wrap_pyfunction!(get_number_of_states, m)?)?;
256
257 Ok(())
258}