Skip to content

Commit b217e32

Browse files
committed
initial commit
1 parent 435ac51 commit b217e32

File tree

464 files changed

+160045
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

464 files changed

+160045
-2
lines changed

.gitignore

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# files types to exculde
2+
*.tar
3+
*.mp4
4+
*.h5
5+
6+
jester
7+
*.txt
8+
9+
model
10+
*.pth.tar
11+
*.pth
12+
13+
log
14+
*.csv

LICENSE

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
BSD 2-Clause License for Motion Fused Frames
2+
3+
Copyright (c) 2017, Okan Köpüklü
4+
All rights reserved.
5+
6+
Redistribution and use in source and binary forms, with or without
7+
modification, are permitted provided that the following conditions are met:
8+
9+
* Redistributions of source code must retain the above copyright notice, this
10+
list of conditions and the following disclaimer.
11+
12+
* Redistributions in binary form must reproduce the above copyright notice,
13+
this list of conditions and the following disclaimer in the documentation
14+
and/or other materials provided with the distribution.
15+
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
28+
BSD 2-Clause License for TSN-PyTorch
29+
30+
Copyright (c) 2017, Multimedia Laboratary, The Chinese University of Hong Kong
31+
All rights reserved.
32+
33+
Redistribution and use in source and binary forms, with or without
34+
modification, are permitted provided that the following conditions are met:
35+
36+
* Redistributions of source code must retain the above copyright notice, this
37+
list of conditions and the following disclaimer.
38+
39+
* Redistributions in binary form must reproduce the above copyright notice,
40+
this list of conditions and the following disclaimer in the documentation
41+
and/or other materials provided with the distribution.
42+
43+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
44+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
45+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
46+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
47+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
48+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
49+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
50+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
51+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
52+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

MLPmodule.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class MLPmodule(torch.nn.Module):
5+
"""
6+
This is the 2-layer MLP implementation used for linking spatio-temporal
7+
features coming from different segments.
8+
"""
9+
def __init__(self, img_feature_dim, num_frames, num_class):
10+
super(MLPmodule, self).__init__()
11+
self.num_frames = num_frames
12+
self.num_class = num_class
13+
self.img_feature_dim = img_feature_dim
14+
self.num_bottleneck = 512
15+
self.classifier = nn.Sequential(
16+
nn.ReLU(),
17+
nn.Linear(self.num_frames * self.img_feature_dim,
18+
self.num_bottleneck),
19+
#nn.Dropout(0.90), # Add an extra DO if necess.
20+
nn.ReLU(),
21+
nn.Linear(self.num_bottleneck,self.num_class),
22+
)
23+
def forward(self, input):
24+
input = input.view(input.size(0), self.num_frames*self.img_feature_dim)
25+
input = self.classifier(input)
26+
return input
27+
28+
29+
def return_MLP(relation_type, img_feature_dim, num_frames, num_class):
30+
MLPmodel = MLPmodule(img_feature_dim, num_frames, num_class)
31+
32+
return MLPmodel

README.md

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,67 @@
1-
# MFF-pytorch
2-
Motion Fused Frames implementation in PyTorch
1+
# Motion Fused Frames (MFFs)
2+
3+
Pytorch implementation of Motion Fused Frames, built on top of the codebase [TSN-pytorch](https://github.com/yjxiong/temporal-segment-networks).
4+
5+
<p align="center"><img src="https://github.com/okankop/MFF-pytorch/blob/master/images/motion_fused_frames.jpg" align="middle" width="500" title="Motion Fused Frames" /></p>
6+
7+
**Note**: always use `git clone --recursive https://github.com/okankop/MFF-pytorch` to clone this project
8+
Otherwise you will not be able to use the inception series CNN architecture.
9+
10+
### Dataset Preparation
11+
Download the [jester dataset](https://www.twentybn.com/datasets/something-something) or [NVIDIA dataset](http://research.nvidia.com/publication/online-detection-and-classification-dynamic-hand-gestures-recurrent-3d-convolutional) or [ChaLearn LAP IsoGD dataset](http://www.cbsr.ia.ac.cn/users/jwan/database/isogd.html). Decompress them into the same folder and use [process_dataset.py](process_dataset.py) to generate the index files for train, val, and test split. Poperly set up the train, validatin, and category meta files in [datasets_video.py](datasets_video.py). Finally, use directory [flow_computation](flow_computation) to calculate the optical flow images using Brox method.
12+
13+
Assume the structure of data directories is the following:
14+
15+
```misc
16+
~/MFF-pytorch/
17+
datasets/
18+
jester/
19+
rgb/
20+
.../ (directories of video samples)
21+
.../ (jpg color frames)
22+
flow/
23+
u/
24+
.../ (directories of video samples)
25+
.../ (jpg optical-flow-u frames)
26+
v/
27+
.../ (directories of video samples)
28+
.../ (jpg optical-flow-v frames)
29+
model/
30+
.../(saved models for the last checkpoint and best model)
31+
```
32+
33+
34+
### Running the Code
35+
You can simply run 'python main.py' to start a training with the default parameters. Followings are some examples for training under different scenarios:
36+
37+
* Train 4-segment network with 3 flow, 1 color frames (4-MFFs-3f1c architecture)
38+
```bash
39+
python main.py jester RGBFlow --arch BNInception --num_segments 4 \
40+
--consensus_type MLP --num_motion 3 --batch-size 32
41+
```
42+
43+
* Train resuming the last checkpoint (4-MFFs-3f1c architecture)
44+
```bash
45+
python main.py jester RGBFlow --resume=<path-to-last-checkpoint> --arch BNInception \
46+
--consensus_type MLP --num_segments 4 --num_motion 3 --batch-size 32
47+
```
48+
49+
* The command to test trained model (8-MFFs-3f1c architecture)
50+
51+
```bash
52+
python test_models.py jester RGBFlow model/MFF_jester_RGBFlow_BNInception_segment8_3f1c_best.pth.tar --arch BNInception --consensus_type MLP --test_crops 1 --num_motion 3 --test_segments 8
53+
```
54+
55+
All GPUs is used for the training. If you want a part of GPUs, use CUDA_VISIBLE_DEVICES=...
56+
57+
### Citation
58+
O. Köpüklü, N. Köse, G. Rigoll. Motion Fused Frames: Data Level Fusion Strategy for Hand Gesture Recognition, 2018 [PDF]
59+
```
60+
@article{kopuklu2018motion,
61+
title = {Motion Fused Frames: Data Level Fusion Strategy for Hand Gesture Recognition},
62+
author = {K\"op\"ukl\"u, Okan and K\"ose, Neslihan and Rigoll, Gerhard},
63+
}
64+
```
65+
66+
### Acknowledgement
67+
We thank Yuanjun Xiong for releasing [TSN-Pytorch codebase](https://github.com/yjxiong/temporal-segment-networks), which we build our work on top. We also thank Bolei Zhou for the insprational work [Temporal Segment Networks](https://arxiv.org/pdf/1711.08496.pdf), from which we imported [process_dataset.py](https://github.com/metalbubble/TRN-pytorch/blob/master/process_dataset.py) to our project.

__pycache__/MLPmodule.cpython-36.pyc

1.24 KB
Binary file not shown.

__pycache__/TRNmodule.cpython-36.pyc

5.89 KB
Binary file not shown.

__pycache__/dataset.cpython-36.pyc

5.56 KB
Binary file not shown.
2.11 KB
Binary file not shown.

__pycache__/models.cpython-36.pyc

10.7 KB
Binary file not shown.

__pycache__/opts.cpython-36.pyc

2.4 KB
Binary file not shown.

__pycache__/transforms.cpython-36.pyc

14.7 KB
Binary file not shown.

dataset.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import torch.utils.data as data
2+
3+
import random
4+
from PIL import Image
5+
import os
6+
import os.path
7+
import numpy as np
8+
from numpy.random import randint
9+
10+
class VideoRecord(object):
11+
def __init__(self, row):
12+
self._data = row
13+
14+
@property
15+
def path(self):
16+
return self._data[0]
17+
18+
@property
19+
def num_frames(self):
20+
return int(self._data[1])
21+
22+
@property
23+
def label(self):
24+
return int(self._data[2])
25+
26+
27+
class TSNDataSet(data.Dataset):
28+
def __init__(self, root_path, list_file,
29+
num_segments=3, new_length=1, modality='RGB',
30+
image_tmpl='img_{:05d}.jpg', transform=None,
31+
force_grayscale=False, random_shift=True,
32+
test_mode=False, dataset='jester'):
33+
34+
self.root_path = root_path
35+
self.list_file = list_file
36+
self.num_segments = num_segments
37+
self.new_length = new_length
38+
self.modality = modality
39+
self.image_tmpl = image_tmpl
40+
self.transform = transform
41+
self.random_shift = random_shift
42+
self.test_mode = test_mode
43+
self.dataset = dataset
44+
45+
if self.modality == 'RGBDiff' or self.modality == 'RGBFlow':
46+
self.new_length += 1# Diff needs one more image to calculate diff
47+
48+
self._parse_list()
49+
50+
def _load_image(self, directory, idx, isLast=False):
51+
if self.modality == 'RGB' or self.modality == 'RGBDiff':
52+
try:
53+
return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
54+
except Exception:
55+
print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
56+
return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
57+
58+
elif self.modality == 'Flow':
59+
try:
60+
idx_skip = 1 + (idx-1)*5
61+
flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx_skip))).convert('RGB')
62+
except Exception:
63+
print('error loading flow file:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx_skip)))
64+
flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
65+
# the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
66+
flow_x, flow_y, _ = flow.split()
67+
x_img = flow_x.convert('L')
68+
y_img = flow_y.convert('L')
69+
return [x_img, y_img]
70+
71+
elif self.modality == 'RGBFlow':
72+
if isLast:
73+
return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))).convert('RGB')]
74+
else:
75+
x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(idx))).convert('L')
76+
y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))).convert('L')
77+
return [x_img, y_img]
78+
79+
80+
def _parse_list(self):
81+
# check the frame number is large >3:
82+
# usualy it is [video_id, num_frames, class_idx]
83+
tmp = [x.strip().split(' ') for x in open(self.list_file)]
84+
tmp = [item for item in tmp if int(item[1])>=3]
85+
self.video_list = [VideoRecord(item) for item in tmp]
86+
print('video number:%d'%(len(self.video_list)))
87+
88+
def _sample_indices(self, record):
89+
"""
90+
91+
:param record: VideoRecord
92+
:return: list
93+
"""
94+
average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
95+
96+
if average_duration > 0:
97+
offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, size=self.num_segments)
98+
elif record.num_frames > self.num_segments:
99+
offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
100+
else:
101+
offsets = np.zeros((self.num_segments,))
102+
return offsets + 1
103+
104+
def _get_val_indices(self, record):
105+
if record.num_frames > self.num_segments + self.new_length - 1:
106+
tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
107+
offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
108+
else:
109+
offsets = np.zeros((self.num_segments,))
110+
return offsets + 1
111+
112+
def _get_test_indices(self, record):
113+
tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
114+
offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
115+
return offsets + 1
116+
117+
def __getitem__(self, index):
118+
record = self.video_list[index]
119+
# check this is a legit video folder
120+
if self.modality == 'RGBFlow':
121+
while not os.path.exists(os.path.join(self.root_path, "rgb", record.path, self.image_tmpl.format(1))):
122+
index = np.random.randint(len(self.video_list))
123+
record = self.video_list[index]
124+
else:
125+
while not os.path.exists(os.path.join(self.root_path, record.path, self.image_tmpl.format(1))):
126+
index = np.random.randint(len(self.video_list))
127+
record = self.video_list[index]
128+
129+
if not self.test_mode:
130+
segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
131+
else:
132+
segment_indices = self._get_test_indices(record)
133+
134+
return self.get(record, segment_indices)
135+
136+
def get(self, record, indices):
137+
images = list()
138+
for seg_ind in indices:
139+
p = int(seg_ind)
140+
for i in range(self.new_length):
141+
if self.modality == 'RGBFlow':
142+
if i == self.new_length - 1:
143+
seg_imgs = self._load_image(record.path, p, True)
144+
else:
145+
if p == record.num_frames:
146+
seg_imgs = self._load_image(record.path, p-1)
147+
else:
148+
seg_imgs = self._load_image(record.path, p)
149+
else:
150+
seg_imgs = self._load_image(record.path, p)
151+
152+
images.extend(seg_imgs)
153+
if p < record.num_frames:
154+
p += 1
155+
156+
process_data = self.transform(images)
157+
return process_data, record.label
158+
159+
def __len__(self):
160+
return len(self.video_list)

0 commit comments

Comments
 (0)