Skip to content

Commit b4dc079

Browse files
committed
add_two_dataset
1 parent 3eff2b7 commit b4dc079

File tree

12 files changed

+464
-33
lines changed

12 files changed

+464
-33
lines changed

ddlearn/DG_aug.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ def __init__(self, n_feature, n_act_class, n_aug_class, dataset, dp):
1616
self.n_aug_class = n_aug_class
1717
self.dataset = dataset
1818
self.dp = dp
19-
self.feature_module = net.Network(n_feature, dataset)
19+
if dataset == 'uschad':
20+
self.feature_module = net.Network_usc(n_feature, dataset)
21+
else:
22+
self.feature_module = net.Network(n_feature, dataset)
2023
self.act_cls = nn.Linear(n_feature, n_act_class)
2124
self.aug_cls = nn.Linear(n_feature, n_aug_class)
2225
self.criterion = nn.CrossEntropyLoss()

ddlearn/data_preprocess/deal_dsads.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def load_data(root_path, winsize, overlapsize):
3131
# merge p/ s01-s60 as data_sub
3232
data_sub = np.zeros((1, 45))
3333
for j in range(len(subname_list)):
34-
# subfile = subname_list[j]
3534
data_i = []
3635
if j < 9:
3736
name = '0' + str(j+1)

ddlearn/data_preprocess/deal_pamap.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import numpy as np
4+
import os
5+
from math import isnan
6+
7+
8+
def load_data(root_path, winsize, overlap):
9+
file_list = os.listdir(root_path)
10+
list_len = len(file_list)
11+
x_all, y_all, s_all = [], [], []
12+
for filenum in range(list_len):
13+
data_i = []
14+
filename = file_list[filenum]
15+
data_i = np.loadtxt(os.path.join(root_path, filename))
16+
subject = int(filename.split('0')[1].split('.')[0])
17+
x_i = np.hstack((data_i[:, 4:7], data_i[:, 7:13], data_i[:, 21:24],
18+
data_i[:, 27:33], data_i[:, 38:41], data_i[:, 44:50]))
19+
y_i = data_i[:, 1]
20+
tx, ty, ts = getwin_replace(x_i, y_i, subject,
21+
winsize=winsize, overlap=overlap)
22+
if filenum == 0:
23+
x_all, y_all, s_all = tx, ty, ts
24+
else:
25+
x_all = np.vstack((x_all, tx))
26+
y_all = np.vstack((y_all, ty))
27+
s_all = np.vstack((s_all, ts))
28+
print('a')
29+
return x_all, y_all, s_all
30+
31+
32+
def getwin_replace(x, y, s, winsize, overlap):
33+
data_num = len(x)
34+
overlap_size = int(winsize*overlap)
35+
stepsize = winsize-overlap_size
36+
head, tail = 0, winsize
37+
xx, yy = [], []
38+
while tail <= data_num:
39+
ry = np.unique(y[head:tail])
40+
if len(ry) == 1:
41+
x_win = x[head:tail, :]
42+
x_new = replace_nan(x_win)
43+
xx.append(x_new)
44+
yy.append(y[head])
45+
head += stepsize
46+
tail += stepsize
47+
else:
48+
head = tail-1
49+
while y[head] == y[head-1]:
50+
head -= 1
51+
tail = head + winsize
52+
ss = np.ones(len(yy)) * s
53+
return np.array(xx), np.array(yy).reshape(-1, 1), np.array(ss).reshape(-1, 1)
54+
55+
56+
def replace_nan(x_win):
57+
x_new = []
58+
for col in range(x_win.shape[1]):
59+
x_col = x_win[:, col]
60+
x_col_mean = calculate_mean_value(x_col)
61+
index_nan = np.argwhere(np.isnan(x_col))
62+
x_col[index_nan] = x_col_mean
63+
if col == 0:
64+
x_new = x_col.reshape(-1, 1)
65+
else:
66+
x_new = np.hstack((x_new, x_col.reshape(-1, 1)))
67+
return x_new
68+
69+
70+
def calculate_mean_value(x):
71+
x_new = []
72+
for x_i in x:
73+
if isnan(x_i):
74+
continue
75+
else:
76+
x_new.append(x_i)
77+
x_mean = np.mean(np.array(x_new), axis=0)
78+
return x_mean
79+
80+
81+
def get_pamap_npy(root_path, save_path, winsize, overlap):
82+
if os.path.exists(save_path+'pamap_processwin.npz'):
83+
pass
84+
else:
85+
x, y, s = load_data(root_path, winsize, overlap)
86+
np.savez(save_path+'pamap_processwin.npz', x=x, y=y, s=s)
87+
88+
89+
if __name__ == '__main__':
90+
root_path = '/home/data/process/raw/PAMAP/PAMAP2_Dataset/Protocol/'
91+
save_path = '/home/data/process/pamap/'
92+
winsize = 512
93+
overlap = 0.5
94+
get_pamap_npy(root_path, save_path, winsize, overlap)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
import numpy as np
4+
import scipy.io
5+
import os
6+
7+
8+
def getwin(x, y, s, winsize, overlapsize):
9+
l = len(x)
10+
stepsize = winsize-overlapsize
11+
h, t = 0, winsize
12+
xx, yy, ss = [], [], []
13+
while t <= l:
14+
ry = np.unique(y[h:t])
15+
rs = np.unique(s[h:t])
16+
if len(ry) == 1 and len(rs) == 1:
17+
xx.append(x[h:t, :])
18+
yy.append(y[h])
19+
ss.append(s[h])
20+
else:
21+
print("error!")
22+
h += stepsize
23+
t += stepsize
24+
return np.array(xx), np.array(yy).reshape(-1, 1), np.array(ss).reshape(-1, 1)
25+
26+
27+
def get_npy(root_path, save_path, winsize, overlapsize):
28+
if os.path.exists(save_path+'uschad_processwin.npz'):
29+
pass
30+
else:
31+
x, y, s = get_raw_data_deal(root_path, winsize, overlapsize)
32+
np.savez(save_path+'uschad_processwin.npz', x=x, y=y, s=s)
33+
34+
35+
def get_raw_data_deal(root_path, winsize, overlapsize):
36+
file_name = os.listdir(root_path)
37+
x_all, y_all, s_all = np.zeros(
38+
(1, winsize, 6)), np.zeros((1, 1)), np.zeros((1, 1))
39+
sub_folder_list = []
40+
for i in file_name:
41+
if i == 'Readme.txt' or i == 'displayData_acc.m' or i == 'displayData_gyro.m':
42+
continue
43+
else:
44+
sub_folder_list.append(i)
45+
for subfolder in sub_folder_list:
46+
sub = subfolder.split('t')[1]
47+
path = os.path.join(root_path, subfolder)
48+
file_list = os.listdir(path)
49+
for file in file_list:
50+
data = scipy.io.loadmat(os.path.join(path, file))
51+
x, act_num = data['sensor_readings'], data['activity_number'] if 'activity_number' in data else data['activity_numbr']
52+
y = np.ones(x.shape[0]) * int(act_num[0])
53+
s = np.ones(x.shape[0]) * int(sub)
54+
tx, ty, ts = getwin(x, y, s, winsize, overlapsize)
55+
x_all, y_all, s_all = np.vstack((x_all, tx)), np.vstack(
56+
(y_all, ty)), np.vstack((s_all, ts))
57+
x_all, y_all, s_all = x_all[1:], y_all[1:], s_all[1:]
58+
return x_all, y_all, s_all
59+
60+
61+
if __name__ == '__main__':
62+
winsize = 500
63+
overlap = 0.5
64+
overlapsize = int(winsize*overlap)
65+
root_path = '/home/data/usc-had/raw/USC-HAD/'
66+
save_path = '/home/data/process/uschad/'
67+
x, y, s = get_raw_data_deal(root_path, winsize, overlapsize)
68+
get_npy(root_path, save_path, winsize, overlapsize)

ddlearn/data_util/data_preprocess_devide_domain.py

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,106 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
33

4+
import pickle
5+
from sklearn.model_selection import train_test_split
6+
import numpy as np
7+
import utils
8+
from main import args_parse
9+
from raw_aug_loader import set_param
410
import sys
511
import os
612
sys.path.append(os.path.dirname(sys.path[0]))
7-
from raw_aug_loader import set_param
8-
from main import args_parse
9-
import utils
10-
import numpy as np
11-
from sklearn.model_selection import train_test_split
12-
import pickle
13+
14+
# ============ PAMAP2 ===============
15+
16+
17+
def merge_split_pamap(seed, root_path="/home/data/process/pamap/pamap_processwin.npz", n_domain=4, save_file='/home/data/process/pamap/pamap_subject_final.pkl'):
18+
d = np.load(root_path)
19+
x, y, s = d['x'], d['y'].reshape(-1,), d['s'].reshape(-1,)
20+
x_new, y_new, s_new = select_sub_act(x, y, s)
21+
y_new = y_new-1
22+
s_new = s_new-1
23+
data_lst = []
24+
for i in range(n_domain):
25+
data_i = []
26+
d_index = np.argwhere((s_new == 2*i) | (s_new == 2*i+1)).reshape(-1,)
27+
x_i = x_new[d_index, :, :]
28+
y_i = y_new[d_index]
29+
data_i.append(x_i)
30+
data_i.append(y_i)
31+
data_lst.append(data_i)
32+
33+
devide_train_val_test(data_lst, n_domain, save_file, seed)
34+
35+
36+
def select_sub_act(x, y, s):
37+
x_new, y_new, s_new = [], [], []
38+
sub_list = [1, 2, 3, 4, 5, 6, 7, 8]
39+
act_list = [1, 2, 3, 4, 12, 13, 16, 17]
40+
for index in range(len(y)):
41+
if (s[index] in sub_list) and (y[index] in act_list):
42+
x_new.append(x[index])
43+
y_new.append(y[index])
44+
s_new.append(s[index])
45+
else:
46+
continue
47+
x_new, y_new, s_new = np.array(x_new), np.array(y_new), np.array(s_new)
48+
index_5 = np.argwhere(y_new == 12)
49+
y_new[index_5] = 5
50+
index_6 = np.argwhere(y_new == 13)
51+
y_new[index_6] = 6
52+
index_7 = np.argwhere(y_new == 16)
53+
y_new[index_7] = 7
54+
index_8 = np.argwhere(y_new == 17)
55+
y_new[index_8] = 8
56+
return x_new, y_new, s_new
57+
58+
59+
# ============ USC-HAD ===============
60+
def merge_split_uschad(seed, root_path='/home/data/process/uschad/uschad_processwin.npz', n_domain=5, save_file='/home/data/process/uschad/uschad_subject_final.pkl'):
61+
d = np.load(root_path)
62+
x, y, s = d['x'], (d['y']-1).reshape(-1,), d['s'].reshape(-1,)
63+
data_lst = []
64+
data_0, data_1, data_2, data_3, data_4 = [], [], [], [], []
65+
66+
d_index_0 = np.argwhere((s == 1) | (s == 3) | (s == 10)).reshape(-1,)
67+
x_0 = x[d_index_0]
68+
y_0 = y[d_index_0]
69+
data_0.append(x_0)
70+
data_0.append(y_0)
71+
data_lst.append(data_0)
72+
73+
d_index_1 = np.argwhere((s == 2) | (s == 5) | (s == 13)).reshape(-1,)
74+
x_1 = x[d_index_1]
75+
y_1 = y[d_index_1]
76+
data_1.append(x_1)
77+
data_1.append(y_1)
78+
data_lst.append(data_1)
79+
80+
d_index_2 = np.argwhere((s == 4) | (s == 7) | (s == 9)).reshape(-1,)
81+
x_2 = x[d_index_2]
82+
y_2 = y[d_index_2]
83+
data_2.append(x_2)
84+
data_2.append(y_2)
85+
data_lst.append(data_2)
86+
87+
d_index_3 = np.argwhere((s == 6) | (s == 8) | (s == 14)).reshape(-1,)
88+
x_3 = x[d_index_3]
89+
y_3 = y[d_index_3]
90+
data_3.append(x_3)
91+
data_3.append(y_3)
92+
data_lst.append(data_3)
93+
94+
d_index_4 = np.argwhere((s == 11) | (s == 12)).reshape(-1,)
95+
x_4 = x[d_index_4]
96+
y_4 = y[d_index_4]
97+
data_4.append(x_4)
98+
data_4.append(y_4)
99+
data_lst.append(data_4)
100+
101+
devide_train_val_test(data_lst, n_domain, save_file, seed)
102+
103+
# =================DSADS=====================
13104

14105

15106
def merge_split_dsads(seed, root_path='/home/data/process/dsads/dsads_processwin.npz', n_domain=4, save_file='/home/data/process/dsads/dsads_subject_final.pkl'):

ddlearn/data_util/raw_aug_loader.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,17 @@ def reshape_data(x, dataset, when):
128128
if when == 'begin':
129129
if dataset == 'dsads':
130130
x = x.reshape(-1, 45)
131+
elif dataset == 'uschad':
132+
x = x.reshape(-1, 6)
133+
elif dataset == 'pamap':
134+
x = x.reshape(-1, 27)
131135
elif when == 'end':
132136
if dataset == 'dsads':
133137
x = x.reshape(-1, 125, 45)
138+
elif dataset == 'uschad':
139+
x = x.reshape(-1, 500, 6)
140+
elif dataset == 'pamap':
141+
x = x.reshape(-1, 512, 27)
134142
else:
135143
print("error")
136144
return x
@@ -171,15 +179,17 @@ def pick_data(data, data_type, data_name, src):
171179
def set_param(dataset):
172180
if dataset == 'dsads':
173181
n_domain = 4
174-
else:
175-
print("no matching dataset")
182+
elif dataset == 'pamap':
183+
n_domain = 4
184+
elif dataset == 'uschad':
185+
n_domain = 5
176186
return n_domain
177187

178188

179189
if __name__ == "__main__":
180190
args = args_parse()
181191
root_path = "/home/data/process/"
182-
for args.dataset in ['dsads']:
192+
for args.dataset in ['dsads','pamap','uschad']:
183193
n_domain = set_param(args.dataset)
184194
for args.scaler_method in ['minmax']:
185195
for remain_data_rate in [0.2, 0.4, 0.6, 0.8, 1.0]:
@@ -207,4 +217,3 @@ def set_param(dataset):
207217
}
208218
with open(save_path, 'wb') as f:
209219
pickle.dump(raw_and_aug, f)
210-
print("successful")

ddlearn/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def args_parse():
4141
parser.add_argument('--root_path', type=str,
4242
default="/home/ddlearn/data/")
4343
parser.add_argument('--data_save_path', type=str,
44-
default='/home/ddlearn/data/')
44+
default="/home/ddlearn/data/")
4545
parser.add_argument('--save_path', type=str,
46-
default="/home/ddlearn/results/")
46+
default="/home/results/")
4747

4848
args = parser.parse_args()
4949
args.step_per_epoch = 100000000000

0 commit comments

Comments
 (0)