Skip to content

Commit

Permalink
Merge pull request Cryolite#54 from Cryolite/feature/simulation_log
Browse files Browse the repository at this point in the history
Add support for new simulation log
  • Loading branch information
Cryolite authored Sep 7, 2023
2 parents 94cfc49 + 74dc10c commit 0cdfbd0
Show file tree
Hide file tree
Showing 42 changed files with 879 additions and 475 deletions.
76 changes: 41 additions & 35 deletions kanachan/simulation/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def _main():
raise RuntimeError(f'{config.baseline_model}: Not a file.')
baseline_model = load_model(config.baseline_model, map_location='cpu')
baseline_model.to(device=device, dtype=dtype)
baseline_model.requires_grad_(False)
baseline_model.eval()

if config.baseline_grade < 0 or 15 < config.baseline_grade:
Expand All @@ -97,6 +98,7 @@ def _main():
raise RuntimeError(f'{config.proposed_model}: Not a file.')
proposed_model = load_model(config.proposed_model, map_location='cpu')
proposed_model.to(device=device, dtype=dtype)
proposed_model.requires_grad_(False)
proposed_model.eval()

if config.proposed_grade < 0 or 15 < config.proposed_grade:
Expand Down Expand Up @@ -128,10 +130,14 @@ def _main():

with torch.no_grad():
start_time = datetime.datetime.now()
results = simulate(
game_logs = simulate(
device, dtype, config.baseline_grade, baseline_model,
config.proposed_grade, proposed_model, mode, config.n,
config.batch_size, config.concurrency)
game_results = []
for game_log in game_logs:
game_result = game_log.get_result()
game_results.append(game_result)

elapsed_time = datetime.datetime.now() - start_time
if config.non_duplicated:
Expand All @@ -142,7 +148,7 @@ def _main():
assert config.mode == '1vs3'
print(f'Elapsed time: {elapsed_time} ({elapsed_time / (config.n * 4.0)}/game)')

num_games = len(results)
num_games = len(game_results)

def get_grading_point(ranking: int, score: int) -> int:
return [125, 60, -5, -255][ranking] + (score - 25000) // 1000
Expand All @@ -153,31 +159,30 @@ def get_soul_point(ranking: int, _: int) -> float:
Statistic = Tuple[float, float]

def get_statistic(
results: List[object], proposed: int,
game_results: List[object], proposed: int,
callback: Callable[[int, int], float]) -> Statistic:
assert proposed in (0, 1)
average = 0.0
num_proposed = 0
for game in results:
assert len(game['proposed']) == 4
assert len(game['ranking']) == 4
for game_result in game_results:
assert len(game_result) == 4
for i in range(4):
assert game['proposed'][i] in (0, 1)
ranking = game['ranking'][i]
score = game['scores'][i]
if game['proposed'][i] == proposed:
ranking = game_result[i]['ranking']
score = game_result[i]['score']
if game_result[i]['proposed'] == proposed:
average += callback(ranking, score)
num_proposed += 1
assert num_proposed >= 1
average /= num_proposed

variance = 0.0
num_proposed = 0
for game in results:
for game_result in game_results:
assert len(game_result) == 4
for i in range(4):
ranking = game['ranking'][i]
score = game['scores'][i]
if game['proposed'][i] == proposed:
ranking = game_result[i]['ranking']
score = game_result[i]['score']
if game_result[i]['proposed'] == proposed:
variance += (callback(ranking, score) - average) ** 2.0
num_proposed += 1
assert num_proposed >= 1
Expand All @@ -188,31 +193,32 @@ def get_statistic(

Statistics = Tuple[Statistic, Statistic, Statistic, Statistic, Statistic]

def get_statistics(results: List[object], proposed: int) -> Statistics:
ranking_statistic = get_statistic(results, proposed, lambda r, s: r)
grading_point_statistic = get_statistic(
results, proposed, get_grading_point)
soul_point_statistic = get_statistic(results, proposed, get_soul_point)
def get_statistics(game_results: List[object], proposed: int) -> Statistics:
ranking_statistic = get_statistic(game_results, proposed, lambda r, s: r)
grading_point_statistic = get_statistic(game_results, proposed, get_grading_point)
soul_point_statistic = get_statistic(game_results, proposed, get_soul_point)

top_rate = 0.0
num_proposed = 0
for game in results:
for game_result in game_results:
assert len(game_result) == 4
for i in range(4):
if game['proposed'][i] == proposed:
if game_result[i]['proposed'] == proposed:
num_proposed += 1
if game['ranking'][i] == 0:
if game_result[i]['ranking'] == 0:
top_rate += 1.0
top_rate /= num_proposed
# Unbiased sample variance.
top_rate_variance = top_rate * (1.0 - top_rate) / (num_proposed - 1)

quinella_rate = 0.0
num_proposed = 0
for game in results:
for game_result in game_results:
assert len(game_result) == 4
for i in range(4):
if game['proposed'][i] == proposed:
if game_result[i]['proposed'] == proposed:
num_proposed += 1
if game['ranking'][i] <= 1:
if game_result[i]['ranking'] <= 1:
quinella_rate += 1.0
quinella_rate /= num_proposed
# Unbiased sample variance.
Expand All @@ -226,22 +232,22 @@ def get_statistics(results: List[object], proposed: int) -> Statistics:
(top_rate, top_rate_variance),
(quinella_rate, quinella_rate_variance),)

baseline_statistics = get_statistics(results, 0)
proposed_statistics = get_statistics(results, 1)
baseline_statistics = get_statistics(game_results, 0)
proposed_statistics = get_statistics(game_results, 1)

ranking_diff_average = 0.0
for game in results:
for game_result in game_results:
assert len(game_result) == 4
num_baseline = 0
baseline_ranking = 0.0
num_proposed = 0
proposed_ranking = 0.0
for i in range(4):
ranking = game['ranking'][i]
if game['proposed'][i] == 0:
ranking = game_result[i]['ranking']
if game_result[i]['proposed']:
num_baseline += 1
baseline_ranking += ranking
else:
assert game['proposed'][i] == 1
num_proposed += 1
proposed_ranking += ranking
assert num_baseline >= 1
Expand All @@ -253,18 +259,18 @@ def get_statistics(results: List[object], proposed: int) -> Statistics:
ranking_diff_average /= num_games

ranking_diff_variance = 0.0
for game in results:
for game_result in game_results:
assert len(game_result) == 4
num_baseline = 0
baseline_ranking = 0.0
num_proposed = 0
proposed_ranking = 0.0
for i in range(4):
ranking = game['ranking'][i]
if game['proposed'][i] == 0:
ranking = game_result[i]['ranking']
if game_result[i]['proposed']:
num_baseline += 1
baseline_ranking += ranking
else:
assert game['proposed'][i] == 1
num_proposed += 1
proposed_ranking += ranking
assert num_baseline >= 1
Expand Down
2 changes: 2 additions & 0 deletions src/simulation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ add_library(simulation SHARED
jiagang.cpp
angang.cpp
zimo.cpp
game_log.cpp
round_state.cpp
shoupai.cpp
xiangting_calculator.cpp
paishan.cpp
game_state.cpp
round_result.cpp
decision_maker.cpp
gil.cpp
utility.cpp)
Expand Down
18 changes: 7 additions & 11 deletions src/simulation/angang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

#include "simulation/hule.hpp"
#include "simulation/zimo.hpp"
#include "simulation/game_log.hpp"
#include "simulation/round_state.hpp"
#include "common/throw.hpp"
#include <boost/python/dict.hpp>
#include <functional>
#include <any>
#include <utility>
Expand All @@ -16,39 +16,35 @@
namespace {

using std::placeholders::_1;
namespace python = boost::python;

} // namespace `anonymous`

namespace Kanachan{

std::any angang(
Kanachan::RoundState &round_state, std::uint_fast8_t const zimo_tile,
std::uint_fast8_t const encode, python::dict result)
std::uint_fast8_t const encode, Kanachan::GameLog &game_log)
{
if (zimo_tile >= 37u) {
KANACHAN_THROW<std::invalid_argument>(_1) << static_cast<unsigned>(zimo_tile);
}
if (encode >= 34u) {
KANACHAN_THROW<std::invalid_argument>(_1) << static_cast<unsigned>(encode);
}
if (result.is_none()) {
KANACHAN_THROW<std::invalid_argument>("`result` must not be a `None`.");
}

std::uint_fast16_t const action = round_state.onAngang(zimo_tile, encode);
std::uint_fast16_t const action = round_state.onAngang(zimo_tile, encode, game_log);

if (action == std::numeric_limits<std::uint_fast16_t>::max()) {
if (action == UINT_FAST16_MAX) {
// Si Gang San Le (四槓散了) の成立は打牌直後.
auto zimo = std::bind(&Kanachan::zimo, std::ref(round_state), result);
auto zimo = std::bind(&Kanachan::zimo, std::ref(round_state), std::ref(game_log));
std::function<std::any()> next_step(std::move(zimo));
return next_step;
}

if (action == 543u) {
// Qiang Gang (槍槓)
std::uint_fast8_t const zimo_tile_ = std::numeric_limits<std::uint_fast8_t>::max();
auto hule = std::bind(&Kanachan::hule, std::ref(round_state), zimo_tile_, result);
std::uint_fast8_t const zimo_tile_ = UINT_FAST8_MAX;
auto hule = std::bind(&Kanachan::hule, std::ref(round_state), zimo_tile_, std::ref(game_log));
std::function<std::any()> next_step(std::move(hule));
return next_step;
}
Expand Down
4 changes: 2 additions & 2 deletions src/simulation/angang.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#if !defined(KANACHAN_SIMULATION_ANGANG_HPP_INCLUDE_GUARD)
#define KANACHAN_SIMULATION_ANGANG_HPP_INCLUDE_GUARD

#include "simulation/game_log.hpp"
#include "simulation/round_state.hpp"
#include <boost/python/dict.hpp>
#include <any>
#include <cstdint>

Expand All @@ -11,7 +11,7 @@ namespace Kanachan{

std::any angang(
Kanachan::RoundState &round_state, std::uint_fast8_t zimo_tile, std::uint_fast8_t encode,
boost::python::dict result);
Kanachan::GameLog &game_log);

} // namespace Kanachan

Expand Down
13 changes: 5 additions & 8 deletions src/simulation/chi.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#include "simulation/chi.hpp"

#include "simulation/dapai.hpp"
#include "simulation/game_log.hpp"
#include "simulation/round_state.hpp"
#include "common/assert.hpp"
#include "common/throw.hpp"
#include <boost/python/dict.hpp>
#include <functional>
#include <any>
#include <utility>
Expand All @@ -15,30 +15,27 @@
namespace {

using std::placeholders::_1;
namespace python = boost::python;

} // namespace `anonymous`

namespace Kanachan{

std::any chi(Kanachan::RoundState &round_state, std::uint_fast8_t const encode, python::dict result)
std::any chi(
Kanachan::RoundState &round_state, std::uint_fast8_t const encode, Kanachan::GameLog &game_log)
{
if (encode >= 90u) {
KANACHAN_THROW<std::invalid_argument>(_1) << static_cast<unsigned>(encode);
}
if (result.is_none()) {
KANACHAN_THROW<std::invalid_argument>("`result` must not be a `None`.");
}

std::uint_fast16_t const action = round_state.onChi(encode);
std::uint_fast16_t const action = round_state.onChi(encode, game_log);

if (action <= 147u) {
std::uint_fast8_t const tile = action / 4u;
bool const moqi = ((action - tile * 4u) / 2u >= 2u);
KANACHAN_ASSERT((!moqi));
bool const lizhi = ((action - tile * 4u - moqi * 2u) == 1u);
KANACHAN_ASSERT((!lizhi));
auto dapai = std::bind(&Kanachan::dapai, std::ref(round_state), tile, moqi, lizhi, result);
auto dapai = std::bind(&Kanachan::dapai, std::ref(round_state), tile, moqi, lizhi, std::ref(game_log));
std::function<std::any()> next_step(std::move(dapai));
return next_step;
}
Expand Down
4 changes: 2 additions & 2 deletions src/simulation/chi.hpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
#if !defined(KANACHAN_SIMULATION_CHI_HPP_INCLUDE_GUARD)
#define KANACHAN_SIMULATION_CHI_HPP_INCLUDE_GUARD

#include "simulation/game_log.hpp"
#include "simulation/round_state.hpp"
#include <boost/python/dict.hpp>
#include <any>
#include <cstdint>


namespace Kanachan{

std::any chi(
Kanachan::RoundState &round_state, std::uint_fast8_t encode, boost::python::dict result);
Kanachan::RoundState &round_state, std::uint_fast8_t encode, Kanachan::GameLog &game_log);

} // namespace Kanachan

Expand Down
20 changes: 4 additions & 16 deletions src/simulation/daminggang.cpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,22 @@
#include "simulation/daminggang.hpp"

#include "simulation/zimo.hpp"
#include "simulation/game_log.hpp"
#include "simulation/round_state.hpp"
#include "common/throw.hpp"
#include <boost/python/dict.hpp>
#include <any>
#include <utility>
#include <stdexcept>


namespace {

namespace python = boost::python;

} // namespace `anonymous`

namespace Kanachan{

std::any daminggang(Kanachan::RoundState &round_state, python::dict result)
std::any daminggang(Kanachan::RoundState &round_state, Kanachan::GameLog &game_log)
{
if (result.is_none()) {
KANACHAN_THROW<std::invalid_argument>("`result` must not be a `None`.");
}

round_state.onDaminggang();
round_state.onDaminggang(game_log);

// Si Gang San Le (四槓散了) の成立は打牌直後.

// Zimo (自摸)
auto zimo = std::bind(&Kanachan::zimo, std::ref(round_state), result);
auto zimo = std::bind(&Kanachan::zimo, std::ref(round_state), std::ref(game_log));
std::function<std::any()> next_step(std::move(zimo));
return next_step;
}
Expand Down
4 changes: 2 additions & 2 deletions src/simulation/daminggang.hpp
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#if !defined(KANACHAN_SIMULATION_DAMINGGANG_HPP_INCLUDE_GUARD)
#define KANACHAN_SIMULATION_DAMINGGANG_HPP_INCLUDE_GUARD

#include "simulation/game_log.hpp"
#include "simulation/round_state.hpp"
#include <boost/python/dict.hpp>
#include <any>


namespace Kanachan{

std::any daminggang(Kanachan::RoundState &round_state, boost::python::dict result);
std::any daminggang(Kanachan::RoundState &round_state, Kanachan::GameLog &game_log);

} // namespace Kanachan

Expand Down
Loading

0 comments on commit 0cdfbd0

Please sign in to comment.