Skip to content

Commit 127da1f

Browse files
committed
init
0 parents  commit 127da1f

File tree

91 files changed

+8468
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+8468
-0
lines changed

.gitignore

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
/.vscode
2+
/.idea
3+
/code/
4+
/raw_data/*/
5+
/libcity/cache/*
6+
/libcity/log/*
7+
/libcity/__pycache__
8+
/libcity/*/__pycache__

LICENSE

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2022 aptx1231
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

bj.json

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"n_layers": 6,
3+
"d_model": 256,
4+
"attn_heads": 8,
5+
"max_epoch": 30,
6+
"batch_size": 64,
7+
"grad_accmu_steps": 1,
8+
"learning_rate": 2e-4,
9+
"dataset": "bj",
10+
"roadnetwork": "bj_roadmap_edge_bj_True_1_merge",
11+
"geo_file": "bj_roadmap_edge_bj_True_1_merge_withdegree",
12+
"rel_file": "bj_roadmap_edge_bj_True_1_merge_withdegree",
13+
"merge": true,
14+
"min_freq": 1,
15+
"seq_len": 128,
16+
"test_every": 50,
17+
"temperature": 0.05,
18+
"contra_loss_type": "simclr",
19+
"classify_label": "vflag",
20+
"type_ln": "post",
21+
"add_cls": true,
22+
"add_time_in_day": true,
23+
"add_day_in_week": true,
24+
"add_pe": true,
25+
"add_temporal_bias": true,
26+
"temporal_bias_dim": 64,
27+
"use_mins_interval": false,
28+
"add_gat": true,
29+
"gat_heads_per_layer": [8, 16, 1],
30+
"gat_features_per_layer": [16, 16, 256],
31+
"gat_dropout": 0.1,
32+
"gat_K": 1,
33+
"gat_avg_last": true,
34+
"load_trans_prob": true,
35+
"append_degree2gcn": true,
36+
"normal_feature": false,
37+
"pooling": "cls"
38+
}

framework.png

268 KB
Loading

libcity/__init__.py

Whitespace-only changes.

libcity/config/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from libcity.config.config_parser import ConfigParser
2+
3+
__all__ = [
4+
'ConfigParser'
5+
]

libcity/config/config_parser.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import os
2+
import json
3+
import torch
4+
5+
6+
class ConfigParser(object):
7+
8+
def __init__(self, task, model, dataset, config_file=None,
9+
saved_model=True, train=True, other_args=None, hyper_config_dict=None):
10+
self.config = {}
11+
self._parse_external_config(task, model, dataset, saved_model, train, other_args, hyper_config_dict)
12+
self._parse_config_file(config_file)
13+
self._load_default_config()
14+
self._init_device()
15+
16+
def _parse_external_config(self, task, model, dataset,
17+
saved_model=True, train=True, other_args=None, hyper_config_dict=None):
18+
if task is None:
19+
raise ValueError('the parameter task should not be None!')
20+
if model is None:
21+
raise ValueError('the parameter model should not be None!')
22+
if dataset is None:
23+
raise ValueError('the parameter dataset should not be None!')
24+
self.config['task'] = task
25+
self.config['model'] = model
26+
self.config['dataset'] = dataset
27+
self.config['saved_model'] = saved_model
28+
self.config['train'] = False if task == 'map_matching' else train
29+
if other_args is not None:
30+
for key in other_args:
31+
self.config[key] = other_args[key]
32+
if hyper_config_dict is not None:
33+
for key in hyper_config_dict:
34+
self.config[key] = hyper_config_dict[key]
35+
36+
def _parse_config_file(self, config_file):
37+
if config_file is not None:
38+
if os.path.exists('./{}.json'.format(config_file)):
39+
with open('./{}.json'.format(config_file), 'r') as f:
40+
x = json.load(f)
41+
for key in x:
42+
if key not in self.config:
43+
self.config[key] = x[key]
44+
else:
45+
raise FileNotFoundError(
46+
'Config file {}.json is not found. Please ensure \
47+
the config file is in the root dir and is a JSON \
48+
file.'.format(config_file))
49+
50+
def _load_default_config(self):
51+
with open('./libcity/config/task_config.json', 'r') as f:
52+
task_config = json.load(f)
53+
task_config = task_config[self.config['task']]
54+
model = self.config['model']
55+
if 'dataset_class' not in self.config:
56+
self.config['dataset_class'] = task_config[model]['dataset_class']
57+
if self.config['task'] == 'traj_loc_pred' and 'traj_encoder' not in self.config:
58+
self.config['traj_encoder'] = task_config[model]['traj_encoder']
59+
if 'executor' not in self.config:
60+
self.config['executor'] = task_config[model]['executor']
61+
if 'evaluator' not in self.config:
62+
self.config['evaluator'] = task_config[model]['evaluator']
63+
if self.config['model'].upper() in ['LSTM', 'GRU', 'RNN']:
64+
self.config['rnn_type'] = self.config['model']
65+
self.config['model'] = 'RNN'
66+
if self.config['model'] == 'BERTContrastive':
67+
if 'split' in self.config and self.config['split'] is True:
68+
self.config['dataset_class'] = 'ContrastiveSplitDataset'
69+
self.config['executor'] = 'ContrastiveSplitExecutor'
70+
if self.config['model'] == 'BERTContrastiveLM':
71+
if 'split' in self.config and self.config['split'] is True:
72+
self.config['dataset_class'] = 'ContrastiveSplitLMDataset'
73+
self.config['executor'] = 'ContrastiveSplitMLMExecutor'
74+
if self.config['model'] == 'LinearClassify':
75+
if 'classify_label' in self.config and self.config['classify_label'] == 'usrid':
76+
self.config['evaluator'] = 'MultiClassificationEvaluator'
77+
default_file_list = []
78+
# model
79+
default_file_list.append('model/{}/{}.json'.format(self.config['task'], self.config['model']))
80+
# dataset
81+
default_file_list.append('data/{}.json'.format(self.config['dataset_class']))
82+
# executor
83+
default_file_list.append('executor/{}.json'.format(self.config['executor']))
84+
# evaluator
85+
default_file_list.append('evaluator/{}.json'.format(self.config['evaluator']))
86+
for file_name in default_file_list:
87+
with open('./libcity/config/{}'.format(file_name), 'r') as f:
88+
x = json.load(f)
89+
for key in x:
90+
if key not in self.config:
91+
self.config[key] = x[key]
92+
93+
def _init_device(self):
94+
use_gpu = self.config.get('gpu', True)
95+
gpu_id = self.config.get('gpu_id', 0)
96+
if use_gpu:
97+
torch.cuda.set_device(gpu_id)
98+
self.config['device'] = torch.device(
99+
"cuda:%d" % gpu_id if torch.cuda.is_available() and use_gpu else "cpu")
100+
101+
def get(self, key, default=None):
102+
return self.config.get(key, default)
103+
104+
def __getitem__(self, key):
105+
if key in self.config:
106+
return self.config[key]
107+
else:
108+
raise KeyError('{} is not in the config'.format(key))
109+
110+
def __setitem__(self, key, value):
111+
self.config[key] = value
112+
113+
def __contains__(self, key):
114+
return key in self.config
115+
116+
def __iter__(self):
117+
return self.config.__iter__()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"masking_ratio": 0.2,
10+
"masking_mode": "together",
11+
"distribution": "random",
12+
"avg_mask_len": 3
13+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"cluster_data_path": null
10+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false
9+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"masking_ratio": 0.2,
10+
"masking_mode": "together",
11+
"distribution": "random",
12+
"avg_mask_len": 3
13+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"out_data_argument1": null,
10+
"out_data_argument2": null
11+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"out_data_argument1": null,
10+
"out_data_argument2": null
11+
}

libcity/config/data/ETADataset.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false
9+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"query_data_path": null,
10+
"detour_data_path": null,
11+
"origin_big_data_path": null
12+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"batch_size": 32,
3+
"num_workers": 0,
4+
"vocab_path": null,
5+
"seq_len": 128,
6+
"min_freq": 1,
7+
"merge": true,
8+
"bidir_adj_mx": false,
9+
"classify_label": "vflag"
10+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
3+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"metrics": ["Precision", "Recall", "F1", "MRR", "NDCG"],
3+
"save_modes": ["csv", "json"],
4+
"topk": [1, 5, 10]
5+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"cluster_kinds": 0,
3+
"roadnetwork": "bj_roadmap_edge"
4+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"metrics": ["Precision", "Recall", "F1", "MRR", "NDCG", "microF1", "macroF1"],
3+
"save_modes": ["json"],
4+
"topk": [1, 5, 10]
5+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"metrics": ["MAE", "RMSE", "MAPE", "R2", "EVAR"],
3+
"save_modes": ["csv", "json"]
4+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"sim_select_num": 5,
3+
"roadnetwork": "bj_roadmap_edge",
4+
"metrics": ["MR", "MRR", "HR"],
5+
"topk": [1, 5, 10],
6+
"sim_mode": "most"
7+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"metrics": ["Accuracy", "Precision", "Recall", "F1", "AUC", "kappa"],
3+
"save_modes": ["csv", "json"]
4+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
{
2+
"batch_size": 32,
3+
"grad_accmu_steps": 1,
4+
"max_epoch": 20,
5+
"learner": "adamw",
6+
"learning_rate": 2e-4,
7+
"lr_eta_min": 0,
8+
"lr_warmup_epoch": 4,
9+
"lr_warmup_init": 1e-6,
10+
"lr_decay": true,
11+
"lr_scheduler": "cosinelr",
12+
"lr_decay_ratio": 0.1,
13+
"t_in_epochs": true,
14+
"clip_grad_norm": true,
15+
"max_grad_norm": 5,
16+
"use_early_stop": true,
17+
"patience": 50,
18+
"test_every": 10,
19+
"log_batch": 500,
20+
"log_every": 1,
21+
"l2_reg": null
22+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
"batch_size": 32,
3+
"grad_accmu_steps": 1,
4+
"max_epoch": 20,
5+
"learner": "adamw",
6+
"learning_rate": 2e-4,
7+
"lr_eta_min": 0,
8+
"lr_warmup_epoch": 4,
9+
"lr_warmup_init": 1e-6,
10+
"lr_decay": true,
11+
"lr_scheduler": "cosinelr",
12+
"lr_decay_ratio": 0.1,
13+
"t_in_epochs": true,
14+
"clip_grad_norm": true,
15+
"max_grad_norm": 5,
16+
"use_early_stop": true,
17+
"patience": 50,
18+
"test_every": 10,
19+
"log_batch": 500,
20+
"log_every": 1,
21+
"l2_reg": null,
22+
"add_cls": true,
23+
"pooling": "cls",
24+
"pretrain_path": null,
25+
"freeze": false
26+
}

0 commit comments

Comments
 (0)