core_rust/core/abstraction/
homomorphism.rs1use crate::core::abstraction::*;
2use crate::core::game::{utils::errors::GameError, *};
3use errors::AbstractionError;
4use game_logic::Game;
5use ordered_float::OrderedFloat;
6use rayon::prelude::*;
7use state::State;
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::time::Instant;
10use storing::{load_cache, save_cache, AbstractionEntry};
11
12pub fn get_all_states(game: &Game) -> Result<Vec<State>, GameError> {
16 let mut visited_states: HashSet<State> = HashSet::new();
18 let mut possible_states = Vec::new();
19
20 let mut state_queue = VecDeque::from([game.get_state()]);
22
23 while let Some(current_state) = state_queue.pop_front() {
24 if visited_states.contains(¤t_state) {
26 continue;
27 }
28
29 visited_states.insert(current_state.clone());
31 possible_states.push(current_state.clone());
32
33 for action in current_state.valid_moves() {
35 let new_state = match game.simulate(¤t_state, action) {
36 Ok((state, _)) => state,
37 Err(e) => return Err(e),
38 };
39
40 state_queue.push_back(new_state);
41 }
42 }
43
44 for (index, state) in possible_states.iter_mut().enumerate() {
46 state.index = Some(index as isize);
47 }
48
49 Ok(possible_states)
50}
51
52pub fn compute_signature_parallel(
57 state: &State,
58 partition: &HashMap<isize, usize>,
59 game: &Game,
60 position_lookup: &HashMap<(usize, usize), isize>,
61) -> Vec<(OrderedFloat<f32>, usize)> {
62 let mut outcomes = vec![];
63
64 for action in state.valid_moves.iter() {
66 let (_, vars) = game.simulate(state, action).unwrap();
70 let pos = game.simulate(state, action).unwrap().0.unit_position;
71
72 let next_index = match position_lookup.get(&pos) {
74 Some(idx) => *idx,
75 None => {
76 eprintln!("Simulated state not found in state set.");
77 eprintln!("Missing position: {:?}", pos);
78 eprintln!(
79 "State details: {:?}",
80 game.simulate(state, action).unwrap().0
81 );
82 panic!("Abstraction failed: new state was not found in state set.");
83 }
84 };
85
86 let partition_id = *partition
88 .get(&next_index)
89 .expect("Partition must contain all state indices");
90
91 outcomes.push((OrderedFloat(vars.score), partition_id));
93 }
94
95 outcomes.sort_by(|a, b| a.partial_cmp(b).unwrap());
97 outcomes
98}
99
100pub fn compute_mdp_homomorphism(states: &[State], game: &Game) -> Vec<Vec<isize>> {
108 let mut partition: HashMap<isize, usize> = HashMap::new();
109
110 for state in states.iter() {
112 let done = game.goal() == state.unit_position;
113 partition.insert(
114 state.index.expect("State must be indexed"),
115 if done { 0 } else { 1 },
116 );
117 }
118
119 let mut position_lookup: HashMap<(usize, usize), isize> = HashMap::new();
123 for state in states.iter() {
124 position_lookup.insert(state.unit_position, state.index.unwrap());
125 }
126
127 let total_states = states.len();
132 let mut changed = true;
133 let mut iteration = 0;
134 let min_iters = 10000;
135 let max_stagnant_iters = 100;
136 let mut stagnant_count = 0;
137 let mut prev_partition_count = 0;
138
139 while changed {
143 let sig_state_pairs: Vec<(Vec<(OrderedFloat<f32>, usize)>, isize)> = states
145 .par_iter()
146 .map(|state| {
147 let sig = compute_signature_parallel(state, &partition, game, &position_lookup);
148 (sig, state.index.unwrap())
149 })
150 .collect();
151
152 let mut groups_by_signature: HashMap<Vec<(OrderedFloat<f32>, usize)>, Vec<isize>> =
154 HashMap::new();
155 for (sig, idx) in sig_state_pairs {
156 groups_by_signature.entry(sig).or_default().push(idx);
157 }
158
159 let mut new_partition: HashMap<isize, usize> = HashMap::new();
161
162 for (pid, group) in groups_by_signature.values().enumerate() {
163 for idx in group.iter() {
164 new_partition.insert(*idx, pid);
165 }
166 }
167
168 let new_partition_count = new_partition.len();
170 if iteration >= min_iters {
171 if new_partition_count == prev_partition_count {
172 stagnant_count += 1;
173 } else {
174 stagnant_count = 0;
175 }
176
177 if stagnant_count >= max_stagnant_iters {
180 println!(
181 "Early stop after {} stagnant iterations ({} total states, {} groups)",
182 stagnant_count, total_states, new_partition_count
183 );
184 break;
185 }
186 }
187
188 prev_partition_count = new_partition_count;
189 iteration += 1;
190
191 if new_partition == partition {
192 changed = false;
193 } else {
194 partition = new_partition;
195 }
196 }
197
198 let mut groups: HashMap<usize, Vec<isize>> = HashMap::new();
200 for (idx, group_id) in partition {
201 groups.entry(group_id).or_default().push(idx);
202 }
203
204 let mut clusters: Vec<Vec<isize>> = groups
207 .into_values()
208 .map(|mut v| {
209 v.sort_unstable();
210 v
211 })
212 .collect();
213 clusters.sort_unstable_by_key(|cluster| cluster[0]);
214
215 println!("Required {} iterations to converge", iteration);
216
217 clusters
218}
219
220pub fn get_abstraction(game: &Game) -> Result<(Vec<State>, Vec<Vec<isize>>), AbstractionError> {
223 let now = Instant::now();
224 let config = &game.world_configuration();
225 let cache_file = "abstraction_cache.json";
226
227 let mut cache = load_cache(cache_file).map_err(AbstractionError::Io)?;
229
230 if let Some(entry) = cache.iter().find(|e| &e.config == config) {
232 return Ok((entry.states.clone(), entry.clusters.clone()));
233 }
234
235 let game_clone = game.clone();
237 let all_states = get_all_states(&game_clone).map_err(|e| AbstractionError::Computation {
238 error: e.to_string(),
239 })?;
240
241 let clusters = compute_mdp_homomorphism(all_states.as_slice(), game);
242
243 cache.push(AbstractionEntry {
245 config: config.clone(),
246 states: all_states.clone(),
247 clusters: clusters.clone(),
248 });
249 println!("Saving config...");
250 save_cache(cache_file, &cache).map_err(AbstractionError::Io)?;
251
252 let elapsed_time = now.elapsed();
253 println!(
254 "Took {} seconds to calculate exact homomorphism",
255 elapsed_time.as_secs()
256 );
257
258 Ok((all_states, clusters))
259}