-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcheck_data_structure.py
75 lines (68 loc) · 2.76 KB
/
check_data_structure.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import pickle
'''
Script used to visualize the structure and dimensions of a pickle data file.
Data structure:
list(
    dict(
        "observations": nparray(nparray(np.float32)),
        "next_observations": nparray(nparray(np.float32)),
        "actions": nparray(nparray(np.float32)),
        "rewards": nparray(np.float32),
        "terminals": nparray(np.bool_)
    )
)
'''
if __name__ == '__main__':
    # Inspect a pickled dataset: print the container types and lengths of the
    # first element's fields, then a summary of the dataset size on disk and
    # in memory.
    file = "./data/rb1_8759x5x1.pkl"

    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only run this script on data files from a trusted source.
    with open(file, "rb") as f:
        data = pickle.load(f)
    print("data ", type(data), "length: ", len(data))

    if len(data) > 0:
        dict0 = data[0]
        print("data[0] ", type(dict0), "length: ", len(dict0))
        print("data[0].keys() ", dict0.keys())
        print()

        # Fields that are indexed two levels deep: report the outer
        # container, its first row, and the type of the first scalar.
        for key in ("observations", "next_observations", "actions"):
            rows = dict0[key]
            print(key + " ", type(rows), "length: ", len(rows))
            first_row = rows[0]
            print(key + "[0] ", type(first_row), "length: ", len(first_row))
            print(key + "[0][0] ", type(first_row[0]))
            print()

        # Fields that are indexed one level deep: report the container and
        # the type of its first entry.
        for key in ("rewards", "terminals"):
            values = dict0[key]
            print(key + " ", type(values), "length: ", len(values))
            print(key + "[0] ", type(values[0]))
            print()
    else:
        # Nothing to index into; skip the per-field inspection instead of
        # failing on data[0], and still print the size summary below.
        print("data is empty; no element to inspect")
        print()

    print("========================= Data Size =============================")
    # Longest "observations" sequence over all elements (0 when data is empty).
    length = max((len(d["observations"]) for d in data), default=0)
    print("Amount Of Sequences: ", len(data))
    print("Longest Sequence: ", length)

    # Report the on-disk size in whole MB above 1e6 bytes, otherwise whole kB.
    file_size = os.stat(file).st_size
    if file_size > 1e+6:
        string_byte = "(" + str(round(file_size / 1e+6)) + " MB)"
    else:
        string_byte = "(" + str(round(file_size / 1e+3)) + " kB)"
    print(file, string_byte)