core_rust/core/simulation/mapper.rs

1use crate::core::abstraction::*;
2use crate::core::game::*;
3use std::collections::HashMap;
4use std::collections::HashSet;
5
6use errors::AbstractionError;
7use game_logic::Game;
8use homomorphism::{get_abstraction, get_all_states};
9
10use state::State;
11use utils::actions::Action;
12
13/// Bidirectional mappings between ground and abstract MDPs.
14#[derive(Debug, Clone)]
15pub struct Mapper {
16    all_ground_states: Vec<State>,
17    abstraction: Vec<Vec<isize>>,
18    _abstract_transition_map: HashMap<(isize, Action), usize>,
19    abstract_action_map: HashMap<(isize, Action), Action>,
20    abstract_to_ground_map: HashMap<(usize, Action), (isize, Action)>,
21}
22
23impl Mapper {
24    /// Build a mapper either from a supplied abstraction or a learned one.
25    pub fn new(
26        game: &Game,
27        abstraction: Option<Vec<Vec<isize>>>,
28    ) -> Result<Self, AbstractionError> {
29        let (ground_states, abstraction) = match abstraction {
30            Some(abstraction) => {
31                let states = get_all_states(game).map_err(|e| AbstractionError::Computation {
32                    error: e.to_string(),
33                })?;
34                (states, abstraction)
35            }
36            None => {
37                let (states, abstraction) =
38                    get_abstraction(game).map_err(|e| AbstractionError::Computation {
39                        error: e.to_string(),
40                    })?;
41                (states, abstraction)
42            }
43        };
44
45        let transition_map: HashMap<(isize, Action), usize> =
46            Mapper::build_abstract_transition_map(game, &ground_states, &abstraction)?;
47
48        let (abstract_action_map, abstract_to_ground_map) =
49            Self::build_action_maps(&transition_map, &abstraction)?;
50        // println!("actions: {:?}\n", abstract_action_map);
51
52        Ok(Mapper {
53            all_ground_states: ground_states,
54            abstraction,
55            _abstract_transition_map: transition_map,
56            abstract_action_map,
57            abstract_to_ground_map,
58        })
59    }
60
61    fn get_ground_id(state: &State, all_states: &[State]) -> isize {
62        let state_id = all_states
63            .iter()
64            .find(|&ground_state| ground_state.unit_position == state.unit_position)
65            .expect("No matching state")
66            .index
67            .expect("Gound state list did not generate with ids");
68        state_id
69    }
70
71    fn build_abstract_transition_map(
72        game: &Game,
73        ground_states: &[State],
74        abstraction: &[Vec<isize>],
75    ) -> Result<HashMap<(isize, Action), usize>, AbstractionError> {
76        let mut abstract_transition_map: HashMap<(isize, Action), usize> = HashMap::new();
77        let ground_actions = [Action::Up, Action::Down, Action::Left, Action::Right];
78
79        for ground_state in ground_states.iter().cloned() {
80            let ground_state_index = match ground_state.index {
81                Some(index) => index,
82                None => Mapper::get_ground_id(&ground_state, ground_states),
83            };
84            for ground_action in ground_actions {
85                let (next_ground_state, _) =
86                    game.simulate(&ground_state, &ground_action).map_err(|e| {
87                        AbstractionError::Computation {
88                            error: e.to_string(),
89                        }
90                    })?;
91                let next_ground_state_id = Mapper::get_ground_id(&next_ground_state, ground_states);
92                let next_abtract_state_id = abstraction
93                    .iter()
94                    .position(|cluster| cluster.contains(&next_ground_state_id))
95                    .expect("ground state id not found in abstraction");
96                abstract_transition_map
97                    .insert((ground_state_index, ground_action), next_abtract_state_id);
98            }
99        }
100
101        Ok(abstract_transition_map)
102    }
103
104    /// Construct forward (ground→abstract) and reverse (abstract→ground) action maps.
105    fn build_action_maps(
106        transition_map: &HashMap<(isize, Action), usize>,
107        abstraction: &[Vec<isize>],
108    ) -> Result<
109        (
110            HashMap<(isize, Action), Action>,
111            HashMap<(usize, Action), (isize, Action)>,
112        ),
113        AbstractionError,
114    > {
115        let mut ga2aa = HashMap::new();
116        let mut aa2ga = HashMap::new();
117        let ground_actions = [Action::Up, Action::Down, Action::Left, Action::Right];
118
119        for (src_abs, cluster) in abstraction.iter().enumerate() {
120            // Group all (ground_state, ground_action) pairs by their destination abstract state
121            let mut buckets: Vec<(usize, Vec<(isize, Action)>)> = Vec::new();
122            for &gs in cluster {
123                for &ga in &ground_actions {
124                    let next_abs = transition_map[&(gs, ga)];
125                    if let Some((_, ref mut vec)) = buckets.iter_mut().find(|(k, _)| *k == next_abs)
126                    {
127                        vec.push((gs, ga));
128                    } else {
129                        buckets.push((next_abs, vec![(gs, ga)]));
130                    }
131                }
132            }
133
134            // Assign abstract-action IDs in the order of discovered destinations
135            let mut next_aa = 5; // AbstractAction1 == 5
136            for (_next_abs, pairs) in buckets {
137                let aa = Action::from_id(next_aa).map_err(|e| AbstractionError::Computation {
138                    error: e.to_string(),
139                })?;
140                next_aa += 1;
141
142                // Pick a representative pair where the ground-state id is minimal
143                let &(rep_gs, rep_ga) = pairs
144                    .iter()
145                    .min_by_key(|(gs, _)| *gs)
146                    .expect("bucket never empty");
147
148                // record the reverse mapping once
149                aa2ga.entry((src_abs, aa)).or_insert((rep_gs, rep_ga));
150
151                // lift that same aa onto *every* pair in this bucket
152                for (gs, ga) in pairs {
153                    ga2aa.insert((gs, ga), aa);
154                }
155            }
156        }
157
158        Ok((ga2aa, aa2ga))
159    }
160
161    pub fn ground_state_to_abstract(&self, state: &State) -> State {
162        let ground_state_id = Mapper::get_ground_id(state, &self.all_ground_states);
163
164        let abstract_id = self
165            .abstraction
166            .iter()
167            .position(|cluster| cluster.contains(&ground_state_id))
168            .expect("ground state not in any cluster") as isize;
169
170        let mut set = HashSet::new();
171        for &ground_action in state.valid_moves().iter() {
172            if let Some(&abstract_action) = self
173                .abstract_action_map
174                .get(&(ground_state_id, ground_action))
175            {
176                set.insert(abstract_action);
177            }
178        }
179        let valid_abstract_moves: Vec<Action> = set.into_iter().collect();
180
181        let mut abstract_state = State::new(state.unit_position, valid_abstract_moves);
182        abstract_state.index = Some(abstract_id);
183        abstract_state
184    }
185
186    pub fn abstract_state_action_to_ground(
187        &self,
188        state: &State,
189        action: Action,
190    ) -> (State, Action) {
191        let abs_id = state.index.expect("abstract state needs an index") as usize;
192        let (gs, ga) = if let Some(&(gs, ga)) = self.abstract_to_ground_map.get(&(abs_id, action)) {
193            (gs, ga)
194        } else {
195            // fallback: pick the cluster representative and loop in place
196            let cluster = &self.abstraction[abs_id];
197            let &rep_gs = cluster
198                .iter()
199                .min()
200                .expect("cluster should have at least one state");
201
202            // scan in fixed order for a move that stays in this abstract state
203            let ground_moves = [Action::Up, Action::Down, Action::Left, Action::Right];
204            let ga = ground_moves
205                .iter()
206                .copied()
207                .find(|&ga| {
208                    self._abstract_transition_map
209                        .get(&(rep_gs, ga))
210                        .copied()
211                        .expect("every (gs,ga) must be in transition_map")
212                        == abs_id
213                })
214                .expect("no looping move found for abstract state");
215
216            (rep_gs, ga)
217        };
218
219        // Build the new ground‐state
220        let mut ground_state = self.all_ground_states[gs as usize].clone();
221        ground_state.index = Some(gs);
222        (ground_state, ga)
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// 3×3 open world with the goal in the bottom-right corner.
    fn make_game() -> Game {
        let world = vec![
            vec!['.', '.', '.'],
            vec!['.', '.', '.'],
            vec!['.', '.', 'G'],
        ];

        Game::new(world).expect("failed to build test game")
    }

    /// 4×4 open world with the goal at (3, 3).
    fn make_game_4x4() -> Game {
        let world = vec![
            vec!['.', '.', '.', '.'],
            vec!['.', '.', '.', '.'],
            vec!['.', '.', '.', '.'],
            vec!['.', '.', '.', 'G'],
        ];

        Game::new(world).expect("failed to build 4x4 test game")
    }

    /// Check each `(initial_abs, abstract_action, expected_abs)` case. Start from the first
    /// ground state of `initial_abs`, lower the abstract action to a ground action, simulate
    /// one ground step, and lift the result back to an abstract state.
    fn assert_abstract_transitions(game: &Game, cases: &[(usize, Action, usize)]) {
        let mapper = Mapper::new(game, None).unwrap();

        for &(init_abs, aa, want_abs) in cases {
            let gs_id = mapper.abstraction[init_abs][0] as usize;
            let ground_state = &mapper.all_ground_states[gs_id];

            let abstract_state = mapper.ground_state_to_abstract(ground_state);
            let (_sel_gs, ground_action) =
                mapper.abstract_state_action_to_ground(&abstract_state, aa);
            let (new_ground_state, _) = game.simulate(ground_state, &ground_action).unwrap();
            let got_abs = mapper
                .ground_state_to_abstract(&new_ground_state)
                .index
                .unwrap() as usize;

            assert_eq!(
                got_abs, want_abs,
                "From abstract state {} via {:?}, expected {} but got {}",
                init_abs, aa, want_abs, got_abs
            );
        }
    }

    /// Every abstract action available in an abstract state must lower to that cluster's
    /// minimal ground state.
    fn assert_min_representative(game: &Game) {
        let mapper = Mapper::new(game, None).expect("failed to build Mapper");
        let (all_states, clusters) = get_abstraction(game).expect("homomorphism failed");

        for (abs_id, cluster) in clusters.iter().enumerate() {
            // cluster is sorted, so the first element is the minimum ground-id
            let rep_ground_id = cluster[0];

            let rep_state = &all_states[rep_ground_id as usize];
            let abs_state = mapper.ground_state_to_abstract(rep_state);
            assert_eq!(
                abs_state.index.unwrap() as usize,
                abs_id,
                "representative state had wrong abstract index"
            );

            for &abs_action in abs_state.valid_moves().iter() {
                let (gs, _ga) = mapper.abstract_state_action_to_ground(&abs_state, abs_action);
                let mapped_id = Mapper::get_ground_id(&gs, &all_states);

                assert_eq!(
                    mapped_id, rep_ground_id,
                    "abstract state {abs_id}, action {abs_action:?} mapped back to ground \
                     {mapped_id} but expected the minimal representative {rep_ground_id}"
                );
            }
        }
    }

    #[test]
    fn test_mapper_init_no_abstraction() {
        let game = make_game();

        let mapper = Mapper::new(&game, None).unwrap();

        let expected = vec![
            vec![0],
            vec![1, 2],
            vec![3, 5],
            vec![4],
            vec![6, 7],
            vec![8],
        ];
        assert_eq!(mapper.abstraction, expected);
    }

    #[test]
    fn test_mapper_init_with_abstraction() {
        let game = make_game();

        let supplied = vec![
            vec![0],
            vec![1, 2],
            vec![3],
            vec![4],
            vec![5],
            vec![6, 7],
            vec![8],
        ];
        let mapper = Mapper::new(&game, Some(supplied.clone())).unwrap();
        assert_eq!(mapper.abstraction, supplied);
    }

    #[test]
    fn test_all_states() {
        let game = make_game();

        let all_states = get_all_states(&game).unwrap();
        assert_eq!(all_states.len(), 9);
    }

    #[test]
    fn test_state_mapping() {
        let game = make_game();
        let all_states = get_all_states(&game).unwrap();
        let mapper = Mapper::new(&game, None).unwrap();

        for (ground_id, state) in all_states.iter().enumerate() {
            let abstract_idx = mapper.ground_state_to_abstract(state).index.unwrap() as usize;

            let want_idx = mapper
                .abstraction
                .iter()
                .position(|cluster| cluster.contains(&(ground_id as isize)))
                .unwrap();
            assert_eq!(abstract_idx, want_idx);
        }
    }

    #[test]
    fn test_abstract_action_transitions() {
        // (initial_abs_state, abstract_action, expected_abs_state)
        let cases: &[(usize, Action, usize)] = &[
            (0, Action::AbstractAction1, 0),
            (0, Action::AbstractAction2, 1),
            (1, Action::AbstractAction1, 0),
            (1, Action::AbstractAction2, 2),
            (1, Action::AbstractAction3, 1),
            (1, Action::AbstractAction4, 3),
            (2, Action::AbstractAction1, 1),
            (2, Action::AbstractAction2, 2),
            (2, Action::AbstractAction3, 4),
            (3, Action::AbstractAction1, 1),
            (3, Action::AbstractAction2, 4),
            (4, Action::AbstractAction1, 3),
            (4, Action::AbstractAction2, 4),
            (4, Action::AbstractAction3, 2),
            (4, Action::AbstractAction4, 5),
            (5, Action::AbstractAction2, 5),
        ];

        assert_abstract_transitions(&make_game(), cases);
    }

    #[test]
    fn test_abstract_action_transitions_4x4() {
        // (initial_abs_state, abstract_action, expected_abs_state)
        let cases: &[(usize, Action, usize)] = &[
            (0, Action::AbstractAction1, 0),
            (0, Action::AbstractAction2, 1),
            (1, Action::AbstractAction1, 0),
            (1, Action::AbstractAction2, 2),
            (1, Action::AbstractAction3, 1),
            (1, Action::AbstractAction4, 3),
            (2, Action::AbstractAction1, 1),
            (2, Action::AbstractAction2, 4),
            (2, Action::AbstractAction3, 2),
            (2, Action::AbstractAction4, 5),
            (3, Action::AbstractAction1, 1),
            (3, Action::AbstractAction2, 5),
            (4, Action::AbstractAction1, 2),
            (4, Action::AbstractAction2, 4),
            (4, Action::AbstractAction3, 6),
            (5, Action::AbstractAction1, 3),
            (5, Action::AbstractAction2, 6),
            (5, Action::AbstractAction3, 2),
            (5, Action::AbstractAction4, 7),
            (6, Action::AbstractAction1, 5),
            (6, Action::AbstractAction2, 6),
            (6, Action::AbstractAction3, 4),
            (6, Action::AbstractAction4, 8),
            (7, Action::AbstractAction1, 5),
            (7, Action::AbstractAction2, 8),
            (8, Action::AbstractAction1, 7),
            (8, Action::AbstractAction2, 8),
            (8, Action::AbstractAction3, 6),
            (8, Action::AbstractAction4, 9),
            (9, Action::AbstractAction2, 9),
        ];

        assert_abstract_transitions(&make_game_4x4(), cases);
    }

    #[test]
    fn test_abstract_to_ground_uses_minimum_representative_3() {
        assert_min_representative(&make_game());
    }

    #[test]
    fn test_abstract_to_ground_uses_minimum_representative_4() {
        assert_min_representative(&make_game_4x4());
    }
}