// core_rust/core/abstraction/homomorphism.rs

1use crate::core::abstraction::*;
2use crate::core::game::{utils::errors::GameError, *};
3use errors::AbstractionError;
4use game_logic::Game;
5use ordered_float::OrderedFloat;
6use rayon::prelude::*;
7use state::State;
8use std::collections::{HashMap, HashSet, VecDeque};
9use std::time::Instant;
10use storing::{load_cache, save_cache, AbstractionEntry};
11
12/// BFS to enumerate all reachable states from the initial game state.
13/// Each unique state is assigned a unique index from [0, N].
14/// N is the number of reachable states.
15pub fn get_all_states(game: &Game) -> Result<Vec<State>, GameError> {
16    // Hashset to track visited states
17    let mut visited_states: HashSet<State> = HashSet::new();
18    let mut possible_states = Vec::new();
19
20    // Standard BFS queue starting with the root game state
21    let mut state_queue = VecDeque::from([game.get_state()]);
22
23    while let Some(current_state) = state_queue.pop_front() {
24        // Skip states that have already been visited
25        if visited_states.contains(&current_state) {
26            continue;
27        }
28
29        // Mark state as visited
30        visited_states.insert(current_state.clone());
31        possible_states.push(current_state.clone());
32
33        // Find new reachable states based on actions --> append to queue
34        for action in current_state.valid_moves() {
35            let new_state = match game.simulate(&current_state, action) {
36                Ok((state, _)) => state,
37                Err(e) => return Err(e),
38            };
39
40            state_queue.push_back(new_state);
41        }
42    }
43
44    // Assign each state an index
45    for (index, state) in possible_states.iter_mut().enumerate() {
46        state.index = Some(index as isize);
47    }
48
49    Ok(possible_states)
50}
51
52/// Compute the “signature” of one state under the current partitioning.
53/// For each action, record (reward, next_partition_id).  
54/// Sorting these pairs gives us a fingerprint used to decide which states are equivalent.
55/// This version is specifically so it can be used in parallel with rayon.
56pub fn compute_signature_parallel(
57    state: &State,
58    partition: &HashMap<isize, usize>,
59    game: &Game,
60    position_lookup: &HashMap<(usize, usize), isize>,
61) -> Vec<(OrderedFloat<f32>, usize)> {
62    let mut outcomes = vec![];
63
64    // For each action, simulate and record (reward, partition_of_successor).
65    for action in state.valid_moves.iter() {
66        // Simulate next game state
67        // WARNING: simulated game states do not have an index therefore mapping to indexes is done with unit position
68        // Couldn't think of a better way to do this
69        let (_, vars) = game.simulate(state, action).unwrap();
70        let pos = game.simulate(state, action).unwrap().0.unit_position;
71
72        // Get next state index from mapping to unit position
73        let next_index = match position_lookup.get(&pos) {
74            Some(idx) => *idx,
75            None => {
76                eprintln!("Simulated state not found in state set.");
77                eprintln!("Missing position: {:?}", pos);
78                eprintln!(
79                    "State details: {:?}",
80                    game.simulate(state, action).unwrap().0
81                );
82                panic!("Abstraction failed: new state was not found in state set.");
83            }
84        };
85
86        // Find which partition that successor state currently belongs to.
87        let partition_id = *partition
88            .get(&next_index)
89            .expect("Partition must contain all state indices");
90
91        // Push the reward + partition pair.
92        outcomes.push((OrderedFloat(vars.score), partition_id));
93    }
94
95    // Sort so that signature is order-invariant across action enumeration.
96    outcomes.sort_by(|a, b| a.partial_cmp(b).unwrap());
97    outcomes
98}
99
100/// Main loop of MDP‐homomorphism refinement:
101/// 1. Initialize coarse partition: terminal vs. nonterminal.
102/// 2. Repeat until (no change) or early-stop:
103///    - For *each* state, compute its signature.
104///    - Group states by identical signature.
105///    - Reassign each group a new unique partition id.
106/// 3. Return the final clusters of state‐indices.
107pub fn compute_mdp_homomorphism(states: &[State], game: &Game) -> Vec<Vec<isize>> {
108    let mut partition: HashMap<isize, usize> = HashMap::new();
109
110    // Start with two partitions: goal‐states (pid=0) vs. everything else (pid=1)
111    for state in states.iter() {
112        let done = game.goal() == state.unit_position;
113        partition.insert(
114            state.index.expect("State must be indexed"),
115            if done { 0 } else { 1 },
116        );
117    }
118
119    // Build unit_position -> index lookup table
120    // States that come from simulation don't have an index so we need to compare unit positions since they are unique
121    // Used in the signature function, this spares us time having to recalculate it
122    let mut position_lookup: HashMap<(usize, usize), isize> = HashMap::new();
123    for state in states.iter() {
124        position_lookup.insert(state.unit_position, state.index.unwrap());
125    }
126
127    // Heuristic early stop variables - tuned by just playing around
128    // IMPORTANT HERE:
129    // `min_iters`: Minimum amount of iterations before we early stop
130    // `max_stagnant_iters`: How many iterations we see barely any change before we stop
131    let total_states = states.len();
132    let mut changed = true;
133    let mut iteration = 0;
134    let min_iters = 10000;
135    let max_stagnant_iters = 100;
136    let mut stagnant_count = 0;
137    let mut prev_partition_count = 0;
138
139    // Refinement
140    // O(S * A * number of iterations)
141    // Computationally very expensive, therefore using early stopping
142    while changed {
143        // Compute signatures in parallel with rayon to distribute over cores
144        let sig_state_pairs: Vec<(Vec<(OrderedFloat<f32>, usize)>, isize)> = states
145            .par_iter()
146            .map(|state| {
147                let sig = compute_signature_parallel(state, &partition, game, &position_lookup);
148                (sig, state.index.unwrap())
149            })
150            .collect();
151
152        // Regroup states by identical signature
153        let mut groups_by_signature: HashMap<Vec<(OrderedFloat<f32>, usize)>, Vec<isize>> =
154            HashMap::new();
155        for (sig, idx) in sig_state_pairs {
156            groups_by_signature.entry(sig).or_default().push(idx);
157        }
158
159        // Build a new partition map by assigning each group a new pid
160        let mut new_partition: HashMap<isize, usize> = HashMap::new();
161
162        for (pid, group) in groups_by_signature.values().enumerate() {
163            for idx in group.iter() {
164                new_partition.insert(*idx, pid);
165            }
166        }
167
168        // Get the new partitions and see how they have changed compared to the last cycle
169        let new_partition_count = new_partition.len();
170        if iteration >= min_iters {
171            if new_partition_count == prev_partition_count {
172                stagnant_count += 1;
173            } else {
174                stagnant_count = 0;
175            }
176
177            // If we get a `somewhat` stable grouping after `n` iterations we can assume its gtg
178            // If partially or fully abstractable we would be able to converge to a different grouping over time
179            if stagnant_count >= max_stagnant_iters {
180                println!(
181                    "Early stop after {} stagnant iterations ({} total states, {} groups)",
182                    stagnant_count, total_states, new_partition_count
183                );
184                break;
185            }
186        }
187
188        prev_partition_count = new_partition_count;
189        iteration += 1;
190
191        if new_partition == partition {
192            changed = false;
193        } else {
194            partition = new_partition;
195        }
196    }
197
198    // Convert final partition map into Vec<Vec<isize>>
199    let mut groups: HashMap<usize, Vec<isize>> = HashMap::new();
200    for (idx, group_id) in partition {
201        groups.entry(group_id).or_default().push(idx);
202    }
203
204    // Group the abstraction by ground state index
205    // Similar to how it was done in Python prototype to allow comparing the results
206    let mut clusters: Vec<Vec<isize>> = groups
207        .into_values()
208        .map(|mut v| {
209            v.sort_unstable();
210            v
211        })
212        .collect();
213    clusters.sort_unstable_by_key(|cluster| cluster[0]);
214
215    println!("Required {} iterations to converge", iteration);
216
217    clusters
218}
219
220/// Top-level API: either load from cache to save time or compute the exact homomorphism.
221/// Returns (all_states, clusters), and saves to disk for next time.
222pub fn get_abstraction(game: &Game) -> Result<(Vec<State>, Vec<Vec<isize>>), AbstractionError> {
223    let now = Instant::now();
224    let config = &game.world_configuration();
225    let cache_file = "abstraction_cache.json";
226
227    // Load cache if available
228    let mut cache = load_cache(cache_file).map_err(AbstractionError::Io)?;
229
230    // Look for config
231    if let Some(entry) = cache.iter().find(|e| &e.config == config) {
232        return Ok((entry.states.clone(), entry.clusters.clone()));
233    }
234
235    // No config so we get all states and run compute function
236    let game_clone = game.clone();
237    let all_states = get_all_states(&game_clone).map_err(|e| AbstractionError::Computation {
238        error: e.to_string(),
239    })?;
240
241    let clusters = compute_mdp_homomorphism(all_states.as_slice(), game);
242
243    // Save to file
244    cache.push(AbstractionEntry {
245        config: config.clone(),
246        states: all_states.clone(),
247        clusters: clusters.clone(),
248    });
249    println!("Saving config...");
250    save_cache(cache_file, &cache).map_err(AbstractionError::Io)?;
251
252    let elapsed_time = now.elapsed();
253    println!(
254        "Took {} seconds to calculate exact homomorphism",
255        elapsed_time.as_secs()
256    );
257
258    Ok((all_states, clusters))
259}