This commit is contained in:
Wojciech Janota 2023-11-10 23:20:47 +01:00
parent 4ca41494ee
commit 83d5eb9ad7
8 changed files with 147 additions and 0 deletions

8
longinus/.idea/.gitignore generated vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

13
longinus/.idea/longinus.iml generated Normal file
View File

@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="EMPTY_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/helpers/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/helpers/target" />
<excludeFolder url="file://$MODULE_DIR$/target" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

8
longinus/.idea/modules.xml generated Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/longinus.iml" filepath="$PROJECT_DIR$/.idea/longinus.iml" />
</modules>
</component>
</project>

6
longinus/.idea/vcs.xml generated Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>

9
longinus/Cargo.toml Normal file
View File

@ -0,0 +1,9 @@
[package]
name = "longinus"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
helpers = { path = "helpers" }

View File

@ -0,0 +1,9 @@
[package]
name = "helpers"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
num-traits = "0.2.17"

View File

@ -0,0 +1,91 @@
mod min_max_stats {
use std::f64;
#[derive(Debug)]
#[derive(PartialEq)]
pub struct MinMaxStats {
pub(crate) maximum: f64,
pub(crate) minimum: f64
}
pub struct KnownBounds {
pub(crate) maximum: f64,
pub(crate) minimum: f64
}
pub fn initialize(known_bounds: Option<KnownBounds>) -> MinMaxStats {
let f64_max: f64 = f64::MAX;
//let default_known_bounds: KnownBounds = KnownBounds {maximum: - f64_max, minimum: f64_max };
let int_known_bounds: KnownBounds = known_bounds.unwrap_or(KnownBounds {maximum: -f64_max, minimum: f64_max} );
let min_max_stats: MinMaxStats = MinMaxStats {maximum: int_known_bounds.maximum, minimum: int_known_bounds.minimum};
return min_max_stats;
}
pub fn normalize(min_max_stats_obj: MinMaxStats, value: Option<f64>) -> Option<f64> {
if min_max_stats_obj.maximum > min_max_stats_obj.minimum {
if value.is_none(){
return None;
}
else {
let result: Option<f64> = Some((value.unwrap() - min_max_stats_obj.minimum) / (min_max_stats_obj.maximum - min_max_stats_obj.minimum));
return result;
}
}
else {
return None;
}
}
}
mod config{
use crate::min_max_stats;
pub struct MuZeroConfig {
pub(crate) action_space_size: i64,
pub(crate) observation_space_size: i64,
pub(crate) max_moves: i64,
pub(crate) discount: f64,
pub(crate) dirichlet_alpha_value: f64,
pub(crate) simulations_count: i64,
pub(crate) batch_size: i64,
pub(crate) td_steps: i64,
pub(crate) actors_count: i64,
pub(crate) lr_init: i64,
pub(crate) lr_decay_steps: i64,
pub(crate) training_episodes: i64,
pub(crate) hidden_layer_size: i64,
pub(crate) visit_softmax_temp_fn: i64,
pub(crate) known_bounds: min_max_stats::KnownBounds
}
}
#[cfg(test)]
mod tests {
use crate::min_max_stats::KnownBounds;
use super::*;
#[test]
fn initialize_min_max_stats_no_args() {
let result: min_max_stats::MinMaxStats = min_max_stats::initialize(None);
let test: min_max_stats::MinMaxStats = min_max_stats::MinMaxStats {minimum: f64::MAX, maximum: -f64::MAX};
assert_eq!(test, result);
}
#[test]
fn initialize_min_max_stats_bounds_args() {
let bounds: min_max_stats::KnownBounds = KnownBounds {minimum: 12.0, maximum: 80.0};
let result: min_max_stats::MinMaxStats = min_max_stats::initialize(Some(bounds));
let test: min_max_stats::MinMaxStats = min_max_stats::MinMaxStats {minimum: 12.0, maximum: 80.0};
assert_eq!(test, result);
}
#[test]
fn min_max_stats_normalize() {
let bounds: min_max_stats::KnownBounds = KnownBounds {minimum: 12.0, maximum: 80.0};
let min_max_stats_obj: min_max_stats::MinMaxStats = min_max_stats::initialize(Some(bounds));
let value: f64 = 30.0;
let value_normalized: f64 = min_max_stats::normalize(min_max_stats_obj, Some(value)).unwrap_or(0.0);
let test: f64 = (30.0 - 12.0) / (80.0 - 12.0);
assert_eq!(test, value_normalized);
}
}

3
longinus/src/main.rs Normal file
View File

@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}