Skip to content
Snippets Groups Projects
Commit 2d53557c authored by Andi Gerken's avatar Andi Gerken
Browse files

Initial commit

parents
Branches
No related tags found
No related merge requests found
%% Cell type:code id: tags:
```
%reload_ext autoreload
%autoreload 2
import fish_models
import numpy as np
import copy
from pathlib import Path
import json
requested_dset = "lfftr_aav_4_2piw_10_2pif_rd2_mf10_oris"
dataset_json = json.load(open("../buffer/datasets.json", "r"))
meta_config = dataset_json[requested_dset]
config = meta_config["dataset_config"]
```
%% Cell type:code id: tags:
```
modified_config = copy.deepcopy(config)
if config["data_path"] == "live_female_female/train":
modified_config["data_path"] = fish_models.live_female_female_data() / "train"
else:
raise Exception("Not supported data")
translate = {
"pi":np.pi,
"2pi":2 * np.pi
}
for k, v in modified_config["raycast"].items():
if type(v) is not list and v in translate:
modified_config["raycast"][k] = translate[v]
if "speed_bins" in modified_config:
sb = modified_config["speed_bins"]
modified_config["speed_bin_borders"] = np.linspace(sb[0], sb[1], sb[2])
modified_config.pop("speed_bins")
if "turn_bins" in modified_config:
sb = modified_config["turn_bins"]
modified_config["turn_bin_borders"] = np.linspace(sb[0], sb[1], sb[2])
modified_config.pop("turn_bins")
raycast = fish_models.gym_interface.Raycast(world_bounds=([-50, -50], [50, 50]),**modified_config["raycast"])
modified_config["raycast"] = raycast
dset = fish_models.datasets.io_dataset.IoDataset(**modified_config)
storage_path = Path("../buffer") / meta_config["path"]
if storage_path.exists():
raise Exception(f"File exists and will be overwritten.")
dset.store(storage_path)
```
%% Output
Loading data from 20 files.
100%|██████████| 20/20 [00:23<00:00, 1.17s/it]
Calculating views from 20 files.
20%|██ | 4/20 [00:11<00:48, 3.01s/it]
%% Cell type:code id: tags:
```
```
# Correct settings for php server
find . -type d -exec chmod 755 {} \;
find . -type f -exec chmod 655 {} \;
%% Cell type:code id: tags:
```
import datetime
import json
dataset_config= {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 4,
"n_fish_bins": 10,
"fov_angle_fish_bins": "2pi",
"fov_angle_wall_raycasts": "2pi"
},
"output_strings": [
"actions",
"actions_binned",
"views"
],
"speed_bins":(-2,20,21),
"turn_bins":(-10,10,21),
"reduce_dim": 2,
"max_files": 10
}
name = ""
if dataset_config["data_path"] == "live_female_female/train":
name += "lfftr_"
for o in dataset_config["output_strings"]:
# add the first character of the output strings
name += o[0]
ray_config = dataset_config["raycast"]
name += f"_{ray_config['n_wall_raycasts']}_{ray_config['fov_angle_wall_raycasts']}w_{ray_config['n_fish_bins']}_{ray_config['fov_angle_fish_bins']}f_rd{dataset_config['reduce_dim']}"
if dataset_config["max_files"] is not None:
name += f"_mf{dataset_config['max_files']}"
today = datetime.date.today().strftime("%d.%m.%Y")
output_config = {name:{"path":f"datasets/{name + '.pickle'}", "date":today, "dataset_config":dataset_config}}
print(json.dumps(output_config, indent=4))
```
%% Output
{
"lfftr_av_4_2piw_10_2pif_rd2_mf20": {
"path": "datasets/lfftr_av_4_2piw_10_2pif_rd2_mf20.pickle",
"date": "01.09.2021",
"dataset_config": {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 4,
"n_fish_bins": 10,
"fov_angle_fish_bins": "2pi",
"fov_angle_wall_raycasts": "2pi"
},
"output_strings": [
"actions",
"views"
],
"reduce_dim": 2,
"max_files": 20
}
}
}
%% Cell type:code id: tags:
```
```
%% Cell type:code id: tags:
```
%reload_ext autoreload
%autoreload 2
import fish_models
import numpy as np
import copy
from pathlib import Path
import json
requested_dset = "lfftr_aav_4_2piw_10_2pif_rd2_mf10_oris"
dataset_json = json.load(open("../buffer/datasets.json", "r"))
meta_config = dataset_json[requested_dset]
config = meta_config["dataset_config"]
```
%% Cell type:code id: tags:
```
modified_config = copy.deepcopy(config)
if config["data_path"] == "live_female_female/train":
modified_config["data_path"] = fish_models.live_female_female_data() / "train"
else:
raise Exception("Not supported data")
translate = {
"pi":np.pi,
"2pi":2 * np.pi
}
for k, v in modified_config["raycast"].items():
if type(v) is not list and v in translate:
modified_config["raycast"][k] = translate[v]
if "speed_bins" in modified_config:
sb = modified_config["speed_bins"]
modified_config["speed_bin_borders"] = np.linspace(sb[0], sb[1], sb[2])
modified_config.pop("speed_bins")
if "turn_bins" in modified_config:
sb = modified_config["turn_bins"]
modified_config["turn_bin_borders"] = np.linspace(sb[0], sb[1], sb[2])
modified_config.pop("turn_bins")
raycast = fish_models.gym_interface.Raycast(world_bounds=([-50, -50], [50, 50]),**modified_config["raycast"])
modified_config["raycast"] = raycast
dset = fish_models.datasets.io_dataset.IoDataset(**modified_config)
storage_path = Path("../buffer") / meta_config["path"]
if storage_path.exists():
raise Exception(f"File exists and will be overwritten.")
dset.store(storage_path)
```
%% Output
Loading data from 20 files.
100%|██████████| 20/20 [00:23<00:00, 1.17s/it]
Calculating views from 20 files.
100%|██████████| 20/20 [01:00<00:00, 3.02s/it]
Reducing shape of actions.
100%|██████████| 359560/359560 [00:02<00:00, 147861.50it/s]
Reducing shape of actions_binned.
100%|██████████| 359560/359560 [00:02<00:00, 156212.84it/s]
Reducing shape of views.
100%|██████████| 359560/359560 [00:02<00:00, 153981.95it/s]
Status of IoDataset:
The first 3 dimensions are reduced from (20, 2, 8989) to (359560)
actions (359560, 2): consisting of speed [cm/s] and turn [rad/s].
actions_binned (359560, 2): consisting of speed [discretized into 20 bins] and turn [discretized into 20 bins]
views (359560, 24): consisting of 10 fish bins, 4 wall raycasts and 10 fish oris.
%% Cell type:code id: tags:
```
```
%% Cell type:code id: tags:
```
%reload_ext autoreload
%autoreload 2
import fish_models
import inspect
from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt
requested_model = "ClassificationModel_v0"
models_config = json.load(open("../buffer/models.json", "r"))
config = models_config[requested_model]
# Load Dataset
dset = fish_models.datasets.IoDataset.load_from_name(config["dataset"], verbose=True)
# Load model class
mod = fish_models.utils.load_class_from_string(config["model_class"])
# Get the arguments for the model and extract the raycast if it is needed.
args = config["model_args"]
if "raycast" in inspect.signature(mod.__init__).parameters:
args["raycast"] = dset.raycast
# Initialize and train the model
model = mod(**config["model_args"])
model.train(dset, **config["train_args"])
# Store the model to buffer
storage_path = Path("../buffer") /config["path"]
if storage_path.exists():
print(f"File exists and will be overwritten.")
zip_location = fish_models.ModelStorage.store(storage_path, model, verbose=True, overwrite=True)
```
%% Output
Dataset already in storage, loading from /home/andi/blubber_workspace/tmp/fish_models/storage/datasets/lfftr_aav_4_2piw_10_2pif_rd2_mf10.pickle
dict_keys(['verbose', 'max_iter', 'hidden_layer_sizes'])
Iteration 1, loss = 7.29166196
Iteration 2, loss = 6.84582425
Iteration 3, loss = 6.81930853
Iteration 4, loss = 6.79500032
Iteration 5, loss = 6.77759737
Iteration 6, loss = 6.76833231
Iteration 7, loss = 6.76213111
Iteration 8, loss = 6.75681177
Iteration 9, loss = 6.75277971
Iteration 10, loss = 6.74918457
Iteration 11, loss = 6.74572998
Iteration 12, loss = 6.74265458
Iteration 13, loss = 6.73982576
Iteration 14, loss = 6.73700406
Iteration 15, loss = 6.73430564
Iteration 16, loss = 6.73222687
Iteration 17, loss = 6.72970860
Iteration 18, loss = 6.72762182
Iteration 19, loss = 6.72566669
Iteration 20, loss = 6.72377237
Iteration 21, loss = 6.72211411
Iteration 22, loss = 6.72056207
Iteration 23, loss = 6.71902463
Iteration 24, loss = 6.71801982
Iteration 25, loss = 6.71639951
Iteration 26, loss = 6.71515730
Iteration 27, loss = 6.71445103
Iteration 28, loss = 6.71337329
Iteration 29, loss = 6.71228083
Iteration 30, loss = 6.71136609
File exists and will be overwritten.
Stored model as ../buffer/models/ClassificationModel_v0.zip
/home/andi/.local/lib/python3.8/site-packages/scikit_learn-0.24.2-py3.8-linux-x86_64.egg/sklearn/neural_network/_multilayer_perceptron.py:614: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (30) reached and the optimization hasn't converged yet.
warnings.warn(
%% Cell type:code id: tags:
```
# TODO: This should be dependent on the models needs
size = dset["views"].shape[0]
actions = np.zeros((size, 2))
for i in range(size):
actions[i] = model.choose_action(dset["views"][i])
plt.figure()
plt.hist([actions[:,0], dset["actions"][:,0]], bins=np.linspace(-3,20,20), label=["real speed", "predicted speed"])
plt.legend()
plt.figure()
plt.hist([actions[:,1], dset["actions"][:,1]], bins=np.linspace(-10,10,20), label=["real turn", "predicted turn"])
plt.legend()
plt.show()
```
%% Output
Source diff could not be displayed: it is too large. Options to address this: view the blob.
inotifywait -m ~/blubber_workspace/model_server/public_html/ -e create -e moved_to -e modify -e attrib -e close_write -e delete |
while read dir action file; do
if [[ ! "$file" =~ ".+\.swp" ]]; then
echo "$action happened to $file, copying everything"
rsync -a --delete /home/andi/blubber_workspace/model_server/public_html/ login_zedat:public_html/model_server/
fi
done
{
"lfftr_av_6w_6f_rd2_mf5": {
"path": "datasets/lfftr_av_6w_6f_rd2_mf5.pickle",
"date": "31.08.2021",
"dataset_config": {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 6,
"n_fish_bins": 6,
"fov_angle_fish_bins": "pi",
"fov_angle_wall_raycasts": "pi"
},
"output_strings": [
"actions",
"views"
],
"reduce_dim": 2,
"max_files": 5
}
},
"lfftr_pav_10w_10f_rd2": {
"path": "datasets/lfftr_pav_10w_10f_rd2.pickle",
"date": "31.08.2021",
"dataset_config": {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 10,
"n_fish_bins": 10,
"fov_angle_fish_bins": "pi",
"fov_angle_wall_raycasts": "pi"
},
"output_strings": [
"poses",
"actions",
"views"
],
"reduce_dim": 2,
"max_files": null
}
},
"lfftr_av_4_2piw_10_2pif_rd2_mf20": {
"path": "datasets/lfftr_av_4_2piw_10_2pif_rd2_mf20.pickle",
"date": "01.09.2021",
"dataset_config": {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 4,
"n_fish_bins": 10,
"fov_angle_fish_bins": "2pi",
"fov_angle_wall_raycasts": "2pi"
},
"output_strings": [
"actions",
"views"
],
"reduce_dim": 2,
"max_files": 20
}
},
"lfftr_aav_4_2piw_10_2pif_rd2_mf10": {
"path": "datasets/lfftr_aav_4_2piw_10_2pif_rd2_mf10.pickle",
"date": "05.10.2021",
"dataset_config": {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 4,
"n_fish_bins": 10,
"fov_angle_fish_bins": "2pi",
"fov_angle_wall_raycasts": "2pi"
},
"output_strings": [
"actions",
"actions_binned",
"views"
],
"speed_bins": [
-2,
20,
21
],
"turn_bins": [
-10,
10,
21
],
"reduce_dim": 2,
"max_files": 20
}
},
"lfftr_aav_4_2piw_10_2pif_rd2_mf10_oris": {
"path": "datasets/lfftr_aav_4_2piw_10_2pif_rd2_mf10.pickle",
"date": "05.10.2021",
"dataset_config": {
"data_path": "live_female_female/train",
"raycast": {
"n_wall_raycasts": 4,
"n_fish_bins": 10,
"fov_angle_fish_bins": "2pi",
"fov_angle_wall_raycasts": "2pi",
"view_of": [
"fish",
"walls",
"fish_oris"
]
},
"output_strings": [
"actions",
"actions_binned",
"views"
],
"speed_bins": [
-2,
20,
21
],
"turn_bins": [
-10,
10,
21
],
"reduce_dim": 2,
"max_files": 20
}
}
}
\ No newline at end of file
{
"ReplayModel_v0": {
"path": "models/ReplayModel_v0.zip",
"date": "31.08.2021",
"model_class": "fish_models.models.released.replay_model.ReplayModel",
"model_args": {},
"train_args": {},
"dataset": "lfftr_av_6w_6f_rd2_mf5"
},
"KNNModel_v0": {
"path": "models/KNNModel_v0.zip",
"date": "31.08.2021",
"model_class": "fish_models.models.released.knn_model.KNNModel",
"model_args": {
"k": 3
},
"train_args": {},
"dataset": "lfftr_av_6w_6f_rd2_mf5"
},
"ClusterModel_v0": {
"path": "models/ClusterModel_v0.zip",
"date": "01.09.2021",
"model_class": "fish_models.models.released.cluster_model.ClusterModel",
"model_args": {
"cluster_method": "KMeans",
"cluster_args": {
"n_init": 100,
"n_clusters": 200
}
},
"train_args": {},
"dataset": "lfftr_av_4_2piw_10_2pif_rd2_mf20"
},
"ClusterModel_v1": {
"path": "models/ClusterModel_v1.zip",
"date": "01.09.2021",
"model_class": "fish_models.models.released.cluster_model.ClusterModel",
"model_args": {
"cluster_method": "KMeans",
"cluster_args": {
"n_init": 100,
"n_clusters": 200
}
},
"description": "This model was trained with fish, wall and fish_oris inputs.",
"train_args": {},
"dataset": "lfftr_aav_4_2piw_10_2pif_rd2_mf10"
},
"ClassificationModel_v0": {
"path": "models/ClassificationModel_v0.zip",
"date": "05.10.2021",
"model_class": "fish_models.models.andi.classification_model.ClassificationModel",
"model_args": {
"classifier_method": "MLP",
"classifier_args": {
"verbose": true,
"max_iter": 30,
"hidden_layer_sizes": [
128
]
}
},
"train_args": {},
"dataset": "lfftr_aav_4_2piw_10_2pif_rd2_mf10"
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment