First
This commit is contained in:
parent
4ca41494ee
commit
83d5eb9ad7
8
longinus/.idea/.gitignore
generated
vendored
Normal file
8
longinus/.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
13
longinus/.idea/longinus.iml
generated
Normal file
13
longinus/.idea/longinus.iml
generated
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="EMPTY_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/helpers/src" isTestSource="false" />
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/helpers/target" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/target" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
</module>
|
8
longinus/.idea/modules.xml
generated
Normal file
8
longinus/.idea/modules.xml
generated
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/longinus.iml" filepath="$PROJECT_DIR$/.idea/longinus.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
6
longinus/.idea/vcs.xml
generated
Normal file
6
longinus/.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
9
longinus/Cargo.toml
Normal file
9
longinus/Cargo.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[package]
|
||||||
|
name = "longinus"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
helpers = { path = "helpers" }
|
9
longinus/helpers/Cargo.toml
Normal file
9
longinus/helpers/Cargo.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[package]
|
||||||
|
name = "helpers"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
num-traits = "0.2.17"
|
91
longinus/helpers/src/lib.rs
Normal file
91
longinus/helpers/src/lib.rs
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
mod min_max_stats {
|
||||||
|
use std::f64;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[derive(PartialEq)]
|
||||||
|
pub struct MinMaxStats {
|
||||||
|
pub(crate) maximum: f64,
|
||||||
|
pub(crate) minimum: f64
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct KnownBounds {
|
||||||
|
pub(crate) maximum: f64,
|
||||||
|
pub(crate) minimum: f64
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn initialize(known_bounds: Option<KnownBounds>) -> MinMaxStats {
|
||||||
|
let f64_max: f64 = f64::MAX;
|
||||||
|
//let default_known_bounds: KnownBounds = KnownBounds {maximum: - f64_max, minimum: f64_max };
|
||||||
|
let int_known_bounds: KnownBounds = known_bounds.unwrap_or(KnownBounds {maximum: -f64_max, minimum: f64_max} );
|
||||||
|
let min_max_stats: MinMaxStats = MinMaxStats {maximum: int_known_bounds.maximum, minimum: int_known_bounds.minimum};
|
||||||
|
return min_max_stats;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn normalize(min_max_stats_obj: MinMaxStats, value: Option<f64>) -> Option<f64> {
|
||||||
|
if min_max_stats_obj.maximum > min_max_stats_obj.minimum {
|
||||||
|
if value.is_none(){
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
let result: Option<f64> = Some((value.unwrap() - min_max_stats_obj.minimum) / (min_max_stats_obj.maximum - min_max_stats_obj.minimum));
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mod config{
|
||||||
|
use crate::min_max_stats;
|
||||||
|
|
||||||
|
pub struct MuZeroConfig {
|
||||||
|
pub(crate) action_space_size: i64,
|
||||||
|
pub(crate) observation_space_size: i64,
|
||||||
|
pub(crate) max_moves: i64,
|
||||||
|
pub(crate) discount: f64,
|
||||||
|
pub(crate) dirichlet_alpha_value: f64,
|
||||||
|
pub(crate) simulations_count: i64,
|
||||||
|
pub(crate) batch_size: i64,
|
||||||
|
pub(crate) td_steps: i64,
|
||||||
|
pub(crate) actors_count: i64,
|
||||||
|
pub(crate) lr_init: i64,
|
||||||
|
pub(crate) lr_decay_steps: i64,
|
||||||
|
pub(crate) training_episodes: i64,
|
||||||
|
pub(crate) hidden_layer_size: i64,
|
||||||
|
pub(crate) visit_softmax_temp_fn: i64,
|
||||||
|
pub(crate) known_bounds: min_max_stats::KnownBounds
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use crate::min_max_stats::KnownBounds;
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn initialize_min_max_stats_no_args() {
|
||||||
|
let result: min_max_stats::MinMaxStats = min_max_stats::initialize(None);
|
||||||
|
let test: min_max_stats::MinMaxStats = min_max_stats::MinMaxStats {minimum: f64::MAX, maximum: -f64::MAX};
|
||||||
|
assert_eq!(test, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn initialize_min_max_stats_bounds_args() {
|
||||||
|
let bounds: min_max_stats::KnownBounds = KnownBounds {minimum: 12.0, maximum: 80.0};
|
||||||
|
let result: min_max_stats::MinMaxStats = min_max_stats::initialize(Some(bounds));
|
||||||
|
let test: min_max_stats::MinMaxStats = min_max_stats::MinMaxStats {minimum: 12.0, maximum: 80.0};
|
||||||
|
assert_eq!(test, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn min_max_stats_normalize() {
|
||||||
|
let bounds: min_max_stats::KnownBounds = KnownBounds {minimum: 12.0, maximum: 80.0};
|
||||||
|
let min_max_stats_obj: min_max_stats::MinMaxStats = min_max_stats::initialize(Some(bounds));
|
||||||
|
let value: f64 = 30.0;
|
||||||
|
let value_normalized: f64 = min_max_stats::normalize(min_max_stats_obj, Some(value)).unwrap_or(0.0);
|
||||||
|
let test: f64 = (30.0 - 12.0) / (80.0 - 12.0);
|
||||||
|
assert_eq!(test, value_normalized);
|
||||||
|
}
|
||||||
|
}
|
3
longinus/src/main.rs
Normal file
3
longinus/src/main.rs
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
fn main() {
|
||||||
|
println!("Hello, world!");
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user