1use crate::core::abstraction::*;
2use crate::core::game::*;
3use std::collections::HashMap;
4use std::collections::HashSet;
5
6use errors::AbstractionError;
7use game_logic::Game;
8use homomorphism::{get_abstraction, get_all_states};
9
10use state::State;
11use utils::actions::Action;
12
13#[derive(Debug, Clone)]
15pub struct Mapper {
16 all_ground_states: Vec<State>,
17 abstraction: Vec<Vec<isize>>,
18 _abstract_transition_map: HashMap<(isize, Action), usize>,
19 abstract_action_map: HashMap<(isize, Action), Action>,
20 abstract_to_ground_map: HashMap<(usize, Action), (isize, Action)>,
21}
22
23impl Mapper {
24 pub fn new(
26 game: &Game,
27 abstraction: Option<Vec<Vec<isize>>>,
28 ) -> Result<Self, AbstractionError> {
29 let (ground_states, abstraction) = match abstraction {
30 Some(abstraction) => {
31 let states = get_all_states(game).map_err(|e| AbstractionError::Computation {
32 error: e.to_string(),
33 })?;
34 (states, abstraction)
35 }
36 None => {
37 let (states, abstraction) =
38 get_abstraction(game).map_err(|e| AbstractionError::Computation {
39 error: e.to_string(),
40 })?;
41 (states, abstraction)
42 }
43 };
44
45 let transition_map: HashMap<(isize, Action), usize> =
46 Mapper::build_abstract_transition_map(game, &ground_states, &abstraction)?;
47
48 let (abstract_action_map, abstract_to_ground_map) =
49 Self::build_action_maps(&transition_map, &abstraction)?;
50 Ok(Mapper {
53 all_ground_states: ground_states,
54 abstraction,
55 _abstract_transition_map: transition_map,
56 abstract_action_map,
57 abstract_to_ground_map,
58 })
59 }
60
61 fn get_ground_id(state: &State, all_states: &[State]) -> isize {
62 let state_id = all_states
63 .iter()
64 .find(|&ground_state| ground_state.unit_position == state.unit_position)
65 .expect("No matching state")
66 .index
67 .expect("Gound state list did not generate with ids");
68 state_id
69 }
70
71 fn build_abstract_transition_map(
72 game: &Game,
73 ground_states: &[State],
74 abstraction: &[Vec<isize>],
75 ) -> Result<HashMap<(isize, Action), usize>, AbstractionError> {
76 let mut abstract_transition_map: HashMap<(isize, Action), usize> = HashMap::new();
77 let ground_actions = [Action::Up, Action::Down, Action::Left, Action::Right];
78
79 for ground_state in ground_states.iter().cloned() {
80 let ground_state_index = match ground_state.index {
81 Some(index) => index,
82 None => Mapper::get_ground_id(&ground_state, ground_states),
83 };
84 for ground_action in ground_actions {
85 let (next_ground_state, _) =
86 game.simulate(&ground_state, &ground_action).map_err(|e| {
87 AbstractionError::Computation {
88 error: e.to_string(),
89 }
90 })?;
91 let next_ground_state_id = Mapper::get_ground_id(&next_ground_state, ground_states);
92 let next_abtract_state_id = abstraction
93 .iter()
94 .position(|cluster| cluster.contains(&next_ground_state_id))
95 .expect("ground state id not found in abstraction");
96 abstract_transition_map
97 .insert((ground_state_index, ground_action), next_abtract_state_id);
98 }
99 }
100
101 Ok(abstract_transition_map)
102 }
103
104 fn build_action_maps(
106 transition_map: &HashMap<(isize, Action), usize>,
107 abstraction: &[Vec<isize>],
108 ) -> Result<
109 (
110 HashMap<(isize, Action), Action>,
111 HashMap<(usize, Action), (isize, Action)>,
112 ),
113 AbstractionError,
114 > {
115 let mut ga2aa = HashMap::new();
116 let mut aa2ga = HashMap::new();
117 let ground_actions = [Action::Up, Action::Down, Action::Left, Action::Right];
118
119 for (src_abs, cluster) in abstraction.iter().enumerate() {
120 let mut buckets: Vec<(usize, Vec<(isize, Action)>)> = Vec::new();
122 for &gs in cluster {
123 for &ga in &ground_actions {
124 let next_abs = transition_map[&(gs, ga)];
125 if let Some((_, ref mut vec)) = buckets.iter_mut().find(|(k, _)| *k == next_abs)
126 {
127 vec.push((gs, ga));
128 } else {
129 buckets.push((next_abs, vec![(gs, ga)]));
130 }
131 }
132 }
133
134 let mut next_aa = 5; for (_next_abs, pairs) in buckets {
137 let aa = Action::from_id(next_aa).map_err(|e| AbstractionError::Computation {
138 error: e.to_string(),
139 })?;
140 next_aa += 1;
141
142 let &(rep_gs, rep_ga) = pairs
144 .iter()
145 .min_by_key(|(gs, _)| *gs)
146 .expect("bucket never empty");
147
148 aa2ga.entry((src_abs, aa)).or_insert((rep_gs, rep_ga));
150
151 for (gs, ga) in pairs {
153 ga2aa.insert((gs, ga), aa);
154 }
155 }
156 }
157
158 Ok((ga2aa, aa2ga))
159 }
160
161 pub fn ground_state_to_abstract(&self, state: &State) -> State {
162 let ground_state_id = Mapper::get_ground_id(state, &self.all_ground_states);
163
164 let abstract_id = self
165 .abstraction
166 .iter()
167 .position(|cluster| cluster.contains(&ground_state_id))
168 .expect("ground state not in any cluster") as isize;
169
170 let mut set = HashSet::new();
171 for &ground_action in state.valid_moves().iter() {
172 if let Some(&abstract_action) = self
173 .abstract_action_map
174 .get(&(ground_state_id, ground_action))
175 {
176 set.insert(abstract_action);
177 }
178 }
179 let valid_abstract_moves: Vec<Action> = set.into_iter().collect();
180
181 let mut abstract_state = State::new(state.unit_position, valid_abstract_moves);
182 abstract_state.index = Some(abstract_id);
183 abstract_state
184 }
185
186 pub fn abstract_state_action_to_ground(
187 &self,
188 state: &State,
189 action: Action,
190 ) -> (State, Action) {
191 let abs_id = state.index.expect("abstract state needs an index") as usize;
192 let (gs, ga) = if let Some(&(gs, ga)) = self.abstract_to_ground_map.get(&(abs_id, action)) {
193 (gs, ga)
194 } else {
195 let cluster = &self.abstraction[abs_id];
197 let &rep_gs = cluster
198 .iter()
199 .min()
200 .expect("cluster should have at least one state");
201
202 let ground_moves = [Action::Up, Action::Down, Action::Left, Action::Right];
204 let ga = ground_moves
205 .iter()
206 .copied()
207 .find(|&ga| {
208 self._abstract_transition_map
209 .get(&(rep_gs, ga))
210 .copied()
211 .expect("every (gs,ga) must be in transition_map")
212 == abs_id
213 })
214 .expect("no looping move found for abstract state");
215
216 (rep_gs, ga)
217 };
218
219 let mut ground_state = self.all_ground_states[gs as usize].clone();
221 ground_state.index = Some(gs);
222 (ground_state, ga)
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `size` x `size` grid of open cells (`'.'`) whose bottom-right
    /// cell is the goal (`'G'`).
    fn make_open_game(size: usize) -> Game {
        let mut world = vec![vec!['.'; size]; size];
        world[size - 1][size - 1] = 'G';
        Game::new(world).expect("failed to build test game")
    }

    /// The 3x3 world shared by most tests.
    fn make_game() -> Game {
        make_open_game(3)
    }

    /// For each `(start, abstract_action, expected)` case: start from the first
    /// ground state of abstract state `start`, map `abstract_action` back to a
    /// ground action, simulate it, and check the resulting abstract state.
    fn assert_abstract_transitions(game: &Game, cases: &[(usize, Action, usize)]) {
        let mapper = Mapper::new(game, None).unwrap();

        for &(init_abs, aa, want_abs) in cases {
            let gs_id = mapper.abstraction[init_abs][0] as usize;
            let ground_state = &mapper.all_ground_states[gs_id];

            let abstract_state = mapper.ground_state_to_abstract(ground_state);
            let (_sel_gs, ground_action) =
                mapper.abstract_state_action_to_ground(&abstract_state, aa);

            let (new_ground_state, _) = game.simulate(ground_state, &ground_action).unwrap();
            let got_abs = mapper
                .ground_state_to_abstract(&new_ground_state)
                .index
                .unwrap() as usize;

            assert_eq!(
                got_abs, want_abs,
                "From abstract state {} via {:?}, expected {} but got {}",
                init_abs, aa, want_abs, got_abs
            );
        }
    }

    /// Checks that every valid abstract action of each cluster's first ground
    /// state maps back to that same (minimal) ground state.
    fn assert_minimum_representatives(game: &Game) {
        let mapper = Mapper::new(game, None).expect("failed to build Mapper");
        let (all_states, clusters) = get_abstraction(game).expect("homomorphism failed");

        for (abs_id, cluster) in clusters.iter().enumerate() {
            let rep_ground_id = cluster[0];
            let rep_state = &all_states[rep_ground_id as usize];

            let abs_state = mapper.ground_state_to_abstract(rep_state);
            assert_eq!(
                abs_state.index.unwrap() as usize,
                abs_id,
                "representative state had wrong abstract index"
            );

            for &abs_action in abs_state.valid_moves().iter() {
                let (gs, _ga) = mapper.abstract_state_action_to_ground(&abs_state, abs_action);
                let mapped_id = Mapper::get_ground_id(&gs, &all_states);

                assert_eq!(
                    mapped_id, rep_ground_id,
                    "abstract state {abs_id}, action {abs_action:?} mapped back to ground \
                     {mapped_id} but expected the minimal representative {rep_ground_id}"
                );
            }
        }
    }

    #[test]
    fn test_mapper_init_no_abstraction() {
        let game = make_game();

        let mapper = Mapper::new(&game, None).unwrap();

        let expected = vec![
            vec![0],
            vec![1, 2],
            vec![3, 5],
            vec![4],
            vec![6, 7],
            vec![8],
        ];
        assert_eq!(mapper.abstraction, expected);
    }

    #[test]
    fn test_mapper_init_with_abstraction() {
        let game = make_game();

        let supplied = vec![
            vec![0],
            vec![1, 2],
            vec![3],
            vec![4],
            vec![5],
            vec![6, 7],
            vec![8],
        ];
        let mapper = Mapper::new(&game, Some(supplied.clone())).unwrap();
        assert_eq!(mapper.abstraction, supplied);
    }

    #[test]
    fn test_all_states() {
        let game = make_game();

        let all_states = get_all_states(&game).unwrap();
        assert_eq!(all_states.len(), 9);
    }

    #[test]
    fn test_state_mapping() {
        let game = make_game();
        let all_states = get_all_states(&game).unwrap();
        let mapper = Mapper::new(&game, None).unwrap();

        for (i, state) in all_states.iter().enumerate() {
            // Ground ids are expected to equal positions in the state list;
            // `abstract_state_action_to_ground` indexes by id as well.
            let ground_id = Mapper::get_ground_id(state, &all_states);
            assert_eq!(
                ground_id as usize, i,
                "ground id of state {i} does not match its position"
            );

            let abstract_idx = mapper.ground_state_to_abstract(state).index.unwrap() as usize;

            let want_idx = mapper
                .abstraction
                .iter()
                .position(|cluster| cluster.contains(&ground_id))
                .unwrap();
            assert_eq!(abstract_idx, want_idx);
        }
    }

    #[test]
    fn test_abstract_action_transitions() {
        let game = make_game();
        let cases = [
            (0, Action::AbstractAction1, 0),
            (0, Action::AbstractAction2, 1),
            (1, Action::AbstractAction1, 0),
            (1, Action::AbstractAction2, 2),
            (1, Action::AbstractAction3, 1),
            (1, Action::AbstractAction4, 3),
            (2, Action::AbstractAction1, 1),
            (2, Action::AbstractAction2, 2),
            (2, Action::AbstractAction3, 4),
            (3, Action::AbstractAction1, 1),
            (3, Action::AbstractAction2, 4),
            (4, Action::AbstractAction1, 3),
            (4, Action::AbstractAction2, 4),
            (4, Action::AbstractAction3, 2),
            (4, Action::AbstractAction4, 5),
            (5, Action::AbstractAction2, 5),
        ];
        assert_abstract_transitions(&game, &cases);
    }

    #[test]
    fn test_abstract_action_transitions_4x4() {
        let game = make_open_game(4);
        let cases = [
            (0, Action::AbstractAction1, 0),
            (0, Action::AbstractAction2, 1),
            (1, Action::AbstractAction1, 0),
            (1, Action::AbstractAction2, 2),
            (1, Action::AbstractAction3, 1),
            (1, Action::AbstractAction4, 3),
            (2, Action::AbstractAction1, 1),
            (2, Action::AbstractAction2, 4),
            (2, Action::AbstractAction3, 2),
            (2, Action::AbstractAction4, 5),
            (3, Action::AbstractAction1, 1),
            (3, Action::AbstractAction2, 5),
            (4, Action::AbstractAction1, 2),
            (4, Action::AbstractAction2, 4),
            (4, Action::AbstractAction3, 6),
            (5, Action::AbstractAction1, 3),
            (5, Action::AbstractAction2, 6),
            (5, Action::AbstractAction3, 2),
            (5, Action::AbstractAction4, 7),
            (6, Action::AbstractAction1, 5),
            (6, Action::AbstractAction2, 6),
            (6, Action::AbstractAction3, 4),
            (6, Action::AbstractAction4, 8),
            (7, Action::AbstractAction1, 5),
            (7, Action::AbstractAction2, 8),
            (8, Action::AbstractAction1, 7),
            (8, Action::AbstractAction2, 8),
            (8, Action::AbstractAction3, 6),
            (8, Action::AbstractAction4, 9),
            (9, Action::AbstractAction2, 9),
        ];
        assert_abstract_transitions(&game, &cases);
    }

    #[test]
    fn test_abstract_to_ground_uses_minimum_representative_3() {
        assert_minimum_representatives(&make_game());
    }

    #[test]
    fn test_abstract_to_ground_uses_minimum_representative_4() {
        assert_minimum_representatives(&make_open_game(4));
    }
}