Skip to content

Commit

Permalink
Merge branch 'inference_test'
Browse files Browse the repository at this point in the history
  • Loading branch information
kfastov committed Nov 18, 2023
2 parents fec2247 + 5c229d2 commit 654eaba
Showing 1 changed file with 166 additions and 10 deletions.
176 changes: 166 additions & 10 deletions contracts/src/inference.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ fn predict(mut x: Tensor<FP16x16>) -> FP16x16 {
// board_state_copy[i,j]=turn_monitor
// legal_moves_dict[(i,j)]=board_state_copy.flatten()
// return legal_moves_dict
fn legal_moves_generator(current_board_state: Array<u8>, turn_monitor: u8) -> Array<Array<u8>> {
fn legal_moves_generator(
current_board_state: @Array<u8>, turn_monitor: u8
) -> Array<(Array<u8>, u32)> {
let mut moves = ArrayTrait::new();
let mut index = 0;
loop {
Expand All @@ -71,11 +73,10 @@ fn legal_moves_generator(current_board_state: Array<u8>, turn_monitor: u8) -> Ar
// loop body
if *current_board_state.at(index) == MOVE_EMPTY {
let board_state_copy = modify_array_at_index(
@current_board_state, index, turn_monitor.into()
current_board_state, index, turn_monitor.into()
);
moves.append(board_state_copy);
moves.append((board_state_copy, index));
}
let copy = modify_array_at_index(@current_board_state, 1, 2);
// end of loop body
index += 1;
};
Expand Down Expand Up @@ -122,9 +123,71 @@ fn modify_array_at_index(array: @Array<u8>, index: u32, value: u8) -> Array<u8>
// score=tracker[selected_move]
// return selected_move,new_board_state,score

fn move_selector(current_board_state: Array<u8>, turn_monitor: u8) -> u32 { // index of the move
let mut current_max_location = 0;
let mut current_max = FixedTrait::<FP16x16>::new_unscaled(1000, true); // -1000
let legal_moves = legal_moves_generator(@current_board_state, turn_monitor);

let mut i = 0;
loop {
if (i >= legal_moves.len()) {
break;
}

let (state_after, location) = legal_moves.at(i);

// get tensor representation of a board state
let mut tensor_state_after = board_state_to_tensor(state_after);

let value = predict(tensor_state_after);

// compare prediction with a previous one
if value >= current_max {
// set current prediction and index to max prediction
current_max = value;
current_max_location = *location;
}
i += 1;
};
// return the move in the index
current_max_location
}

// TODO impl Into<Array<u8>, Tensor>
fn board_state_to_tensor(board_state: @Array<u8>) -> Tensor<FP16x16> {
// TODO globals?
let p0 = FixedTrait::<FP16x16>::new_unscaled(MOVE_PLAYER0.into(), false);
let p1 = FixedTrait::<FP16x16>::new_unscaled(MOVE_PLAYER1.into(), false);
let empty = FixedTrait::<FP16x16>::new_unscaled(MOVE_EMPTY.into(), false);

let mut tensor_data = ArrayTrait::new();

let mut i = 0;
loop {
if i >= board_state.len() {
break;
}
tensor_data
.append(
// TODO use enum with Into<u8> and match on it
if *board_state.at(i) == MOVE_PLAYER0 {
p0
} else if *board_state.at(i) == MOVE_PLAYER1 {
p1
} else {
empty
}
);
i += 1;
};

Tensor { shape: array![9].span(), data: tensor_data.span() }
}

#[cfg(test)]
mod tests {
use super::{MOVE_PLAYER0, MOVE_PLAYER1, MOVE_EMPTY};
use orion::numbers::{FP16x16, FixedTrait};
#[test]
#[available_gas(2000000000000)]
fn test_modify_array_at_index() {
Expand All @@ -135,7 +198,6 @@ mod tests {
assert(*new_arr.at(2) == 3, 'wrong value at index 2');
}

//fn legal_moves_generator(current_board_state: Array<u8>, turn_monitor: u8) -> Array<Array<u8>> {
#[test]
#[available_gas(2000000000000)]
fn test_legal_moves_generator() {
Expand All @@ -150,12 +212,15 @@ mod tests {
MOVE_EMPTY,
MOVE_PLAYER1,
];
let moves = super::legal_moves_generator(board, MOVE_PLAYER0);
let moves = super::legal_moves_generator(@board, MOVE_PLAYER0);

assert(moves.len() == 2, 'wrong moves len');

let move0 = moves.at(0);
let move1 = moves.at(1);
let (move0, loc0) = moves.at(0);
let (move1, loc1) = moves.at(1);

assert(*loc0 == 2, 'wrong location 0');
assert(*loc1 == 7, 'wrong location 1');

assert(*move0.at(0) == MOVE_PLAYER0, 'wrong value at move 0 index 0');
assert(*move0.at(1) == MOVE_PLAYER0, 'wrong value at move 0 index 1');
Expand All @@ -164,17 +229,108 @@ mod tests {
assert(*move0.at(4) == MOVE_PLAYER1, 'wrong value at move 0 index 4');
assert(*move0.at(5) == MOVE_PLAYER0, 'wrong value at move 0 index 5');
assert(*move0.at(6) == MOVE_PLAYER0, 'wrong value at move 0 index 6');
assert(*move0.at(7) == MOVE_EMPTY, 'wrong value at move 0 index 7');
assert(*move0.at(7) == MOVE_EMPTY, 'wrong value at move 0 index 7');
assert(*move0.at(8) == MOVE_PLAYER1, 'wrong value at move 0 index 8');

assert(*move1.at(0) == MOVE_PLAYER0, 'wrong value at move 1 index 0');
assert(*move1.at(1) == MOVE_PLAYER0, 'wrong value at move 1 index 1');
assert(*move1.at(2) == MOVE_EMPTY, 'wrong value at move 1 index 2');
assert(*move1.at(2) == MOVE_EMPTY, 'wrong value at move 1 index 2');
assert(*move1.at(3) == MOVE_PLAYER1, 'wrong value at move 1 index 3');
assert(*move1.at(4) == MOVE_PLAYER1, 'wrong value at move 1 index 4');
assert(*move1.at(5) == MOVE_PLAYER0, 'wrong value at move 1 index 5');
assert(*move1.at(6) == MOVE_PLAYER0, 'wrong value at move 1 index 6');
assert(*move1.at(7) == MOVE_PLAYER0, 'wrong value at move 1 index 7');
assert(*move1.at(8) == MOVE_PLAYER1, 'wrong value at move 1 index 8');
}

#[test]
#[available_gas(2000000000000)]
fn test_board_state_to_tensor() {
let board = array![
MOVE_PLAYER0,
MOVE_PLAYER0,
MOVE_EMPTY,
MOVE_PLAYER1,
MOVE_PLAYER1,
MOVE_PLAYER0,
MOVE_PLAYER0,
MOVE_EMPTY,
MOVE_PLAYER1,
];
let tensor = super::board_state_to_tensor(@board);

// TODO
// assert(tensor.shape(0) == 9, 'wrong tensor shape');

let p0 = FixedTrait::<FP16x16>::new_unscaled(MOVE_PLAYER0.into(), false);
let p1 = FixedTrait::<FP16x16>::new_unscaled(MOVE_PLAYER1.into(), false);
let empty = FixedTrait::<FP16x16>::new_unscaled(MOVE_EMPTY.into(), false);

assert(*tensor.data.at(0) == p0, 'wrong value at index 0');
assert(*tensor.data.at(1) == p0, 'wrong value at index 1');
assert(*tensor.data.at(2) == empty, 'wrong value at index 2');
assert(*tensor.data.at(3) == p1, 'wrong value at index 3');
assert(*tensor.data.at(4) == p1, 'wrong value at index 4');
assert(*tensor.data.at(5) == p0, 'wrong value at index 5');
assert(*tensor.data.at(6) == p0, 'wrong value at index 6');
assert(*tensor.data.at(7) == empty, 'wrong value at index 7');
assert(*tensor.data.at(8) == p1, 'wrong value at index 8');
}

#[test]
#[available_gas(2000000000000)]
fn test_move_selector_player0() {
// The state looks like this:
// o o _
// x x o
// x _ x
//
// An ideal AI should make a move on (0, 2) when playing with "o"

let state = array![
MOVE_PLAYER0,
MOVE_PLAYER0,
MOVE_EMPTY,
MOVE_PLAYER1,
MOVE_PLAYER1,
MOVE_PLAYER0,
MOVE_PLAYER1,
MOVE_EMPTY,
MOVE_PLAYER1,
];

let current_player = MOVE_PLAYER0;

let move = super::move_selector(state, current_player);

assert(move == 2, 'bad move');
}

#[test]
#[available_gas(2000000000000)]
fn test_move_selector_player1() {
// The state looks like this:
// o x o
// o x _
// x _ o
//

let state = array![
MOVE_PLAYER0,
MOVE_PLAYER1,
MOVE_PLAYER0,
MOVE_PLAYER0,
MOVE_PLAYER1,
MOVE_EMPTY,
MOVE_PLAYER1,
MOVE_EMPTY,
MOVE_PLAYER0,
];

let current_player = MOVE_PLAYER1;

let move = super::move_selector(state, current_player);

assert(move == 7, 'bad move');
}
}

0 comments on commit 654eaba

Please sign in to comment.