Skip to content

Commit 5e35a2a

Browse files
committed
code also runs for RGB and Flow modalities
1 parent 28a8163 commit 5e35a2a

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

dataset.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,19 @@ def __init__(self, root_path, list_file,
5050
def _load_image(self, directory, idx, isLast=False):
5151
if self.modality == 'RGB' or self.modality == 'RGBDiff':
5252
try:
53-
return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
53+
return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx))).convert('RGB')]
5454
except Exception:
55-
print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
56-
return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
55+
print('error loading image:', os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(idx)))
56+
return [Image.open(os.path.join(self.root_path, "rgb", directory, self.image_tmpl.format(1))).convert('RGB')]
5757

5858
elif self.modality == 'Flow':
5959
try:
60-
idx_skip = 1 + (idx-1)*5
61-
flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx_skip))).convert('RGB')
60+
x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(idx))).convert('L')
61+
y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx))).convert('L')
6262
except Exception:
63-
print('error loading flow file:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx_skip)))
64-
flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
65-
# the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
66-
flow_x, flow_y, _ = flow.split()
67-
x_img = flow_x.convert('L')
68-
y_img = flow_y.convert('L')
63+
print('error loading flow file:', os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(idx)))
64+
x_img = Image.open(os.path.join(self.root_path, "flow/u", directory, self.image_tmpl.format(1))).convert('L')
65+
y_img = Image.open(os.path.join(self.root_path, "flow/v", directory, self.image_tmpl.format(1))).convert('L')
6966
return [x_img, y_img]
7067

7168
elif self.modality == 'RGBFlow':
@@ -122,7 +119,7 @@ def __getitem__(self, index):
122119
index = np.random.randint(len(self.video_list))
123120
record = self.video_list[index]
124121
else:
125-
while not os.path.exists(os.path.join(self.root_path, record.path, self.image_tmpl.format(1))):
122+
while not os.path.exists(os.path.join(self.root_path, "rgb", record.path, self.image_tmpl.format(1))):
126123
index = np.random.randint(len(self.video_list))
127124
record = self.video_list[index]
128125

0 commit comments

Comments
 (0)