Skip to content

Resnet graph regression #213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions Classification/cnns/align.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
rm -rf core.*
rm -rf ./output/snapshots/*

if [ -n "$1" ]; then
NUM_EPOCH=$1
else
NUM_EPOCH=50
fi
echo NUM_EPOCH=$NUM_EPOCH

# training with imagenet
if [ -n "$2" ]; then
DATA_ROOT=$2
else
DATA_ROOT=/dataset/ImageNet/ofrecord
fi
echo DATA_ROOT=$DATA_ROOT

LOG_FOLDER=../logs
mkdir -p $LOG_FOLDER
LOGFILE=$LOG_FOLDER/resnet_training.log

export PYTHONUNBUFFERED=1
echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED
export NCCL_LAUNCH_MODE=PARALLEL
echo NCCL_LAUNCH_MODE=$NCCL_LAUNCH_MODE

#--momentum=0.875 \
python3 of_cnn_train_val.py \
--train_data_dir=$DATA_ROOT/train \
--train_data_part_num=256 \
--val_data_dir=$DATA_ROOT/validation \
--val_data_part_num=256 \
--num_nodes=1 \
--model_load_dir=/ssd/xiexuan/models/resnet50/init_ckpt \
--gpu_num_per_node=1 \
--optimizer="sgd" \
--momentum=0.0 \
--lr_decay="none" \
--label_smoothing=0.1 \
--learning_rate=0.1 \
--loss_print_every_n_iter=1 \
--batch_size_per_device=64 \
--val_batch_size_per_device=64 \
--channel_last=False \
--fuse_bn_relu=False \
--fuse_bn_add_relu=False \
--nccl_fusion_threshold_mb=16 \
--nccl_fusion_max_ops=24 \
--gpu_image_decoder=True \
--num_epoch=$NUM_EPOCH \
--model="resnet50" 2>&1 | tee ${LOGFILE}
# --use_fp16 \
#--pad_output \

echo "Writting log to ${LOGFILE}"
3 changes: 3 additions & 0 deletions Classification/cnns/job_function_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def _default_config(args):
if args.use_xla:
config.use_xla_jit(True)
config.enable_fuse_add_to_output(True)
config.cudnn_conv_force_fwd_algo(0)
config.cudnn_conv_force_bwd_data_algo(1)
config.cudnn_conv_force_bwd_filter_algo(1)
return config


Expand Down
18 changes: 14 additions & 4 deletions Classification/cnns/of_cnn_train_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import alexnet_model
import inception_model
import mobilenet_v2_model
from util import build_watch_cb, build_watch_diff_cb

parser = configs.get_parser()
args = parser.parse_args()
Expand All @@ -51,7 +52,7 @@


flow.config.gpu_device_num(args.gpu_num_per_node)
# flow.config.enable_debug_mode(True)
flow.config.enable_debug_mode(True)

if args.use_fp16 and args.num_nodes * args.gpu_num_per_node > 1:
flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
Expand Down Expand Up @@ -84,12 +85,15 @@ def TrainNet():
if args.train_data_dir:
assert os.path.exists(args.train_data_dir)
print("Loading data from {}".format(args.train_data_dir))
(labels, images) = ofrecord_util.load_imagenet_for_training(args)
#(labels, images) = ofrecord_util.load_imagenet_for_training(args)
(labels, images) = ofrecord_util.load_imagenet_for_validation(args)

else:
print("Loading synthetic data.")
(labels, images) = ofrecord_util.load_synthetic(args)
logits = model_dict[args.model](images, args)
flow.watch(logits, build_watch_cb('logits'))
flow.watch_diff(logits, build_watch_diff_cb('logits_grad'))
if args.label_smoothing > 0:
one_hot_labels = label_smoothing(
labels, args.num_classes, args.label_smoothing, logits.dtype
Expand All @@ -104,7 +108,7 @@ def TrainNet():

loss = flow.math.reduce_mean(loss)
predictions = flow.nn.softmax(logits)
outputs = {"loss": loss, "predictions": predictions, "labels": labels}
outputs = {"loss": loss, "predictions": predictions, "labels": labels, 'images': images, 'logits': logits}

# set up warmup,learning rate and optimizer
optimizer_util.set_up_optimizer(loss, args)
Expand Down Expand Up @@ -144,7 +148,13 @@ def main():
loss_key="loss",
)
for i in range(epoch_size):
TrainNet().async_get(metric.metric_cb(epoch, i))
# TrainNet().async_get(metric.metric_cb(epoch, i))
a = TrainNet().get()
# snapshot.save("epoch_{}_iter{}".format(epoch, i))
print('loss:', a['loss'].numpy())
if i>=100:
break
break

if args.val_data_dir:
metric = Metric(
Expand Down
7 changes: 6 additions & 1 deletion Classification/cnns/optimizer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ def set_up_optimizer(loss, args):
staircase=False,
warmup=warmup,
)
elif args.lr_decay == "none":
lr_scheduler = flow.optimizer.PiecewiseConstantScheduler(
boundaries=[],
values=[args.learning_rate],
)
else:
lr_scheduler = flow.optimizer.PiecewiseScalingScheduler(
base_lr=args.learning_rate,
Expand All @@ -134,7 +139,7 @@ def set_up_optimizer(loss, args):
print("Optimizer: SGD")
flow.optimizer.SGD(
lr_scheduler,
momentum=args.momentum if args.momentum > 0 else None,
momentum=args.momentum if args.momentum > 0 else 0.0,
grad_clipping=grad_clipping,
loss_scale_policy=loss_scale_policy,
).minimize(loss)
Expand Down
83 changes: 44 additions & 39 deletions Classification/cnns/resnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import oneflow.compatible.single_client as flow
from util import build_watch_cb, build_watch_diff_cb

BLOCK_COUNTS = [3, 4, 6, 3]
BLOCK_FILTERS = [256, 512, 1024, 2048]
Expand Down Expand Up @@ -50,14 +51,17 @@ def _conv2d(
else:
shape = (filters, input.shape[1], kernel_size, kernel_size)
weight = flow.get_variable(
name + "-weight",
name + ".weight",
shape=shape,
dtype=input.dtype,
initializer=self.weight_initializer,
regularizer=self.weight_regularizer,
model_name="weight",
trainable=self.trainable,
)
if 'conv1' == name:
flow.watch(weight, build_watch_cb('conv1_weight'))
flow.watch_diff(weight, build_watch_diff_cb('conv1_weight_grad'))

return flow.nn.conv2d(
input,
Expand Down Expand Up @@ -113,7 +117,7 @@ def _batch_norm_relu(self, inputs, name=None, last=False):
name=name + "_bn_relu",
)
else:
return flow.nn.relu(self._batch_norm(inputs, name + "_bn", last=last))
return flow.nn.relu(self._batch_norm(inputs, name, last=last))

def _batch_norm_add_relu(self, inputs, addend, name=None, last=False):
if self.fuse_bn_add_relu:
Expand All @@ -139,7 +143,7 @@ def _batch_norm_add_relu(self, inputs, addend, name=None, last=False):
)
else:
return flow.nn.relu(
self._batch_norm(inputs, name + "_bn", last=last) + addend
self._batch_norm(inputs, name, last=last) + addend
)

def conv2d_affine(self, input, name, filters, kernel_size, strides):
Expand All @@ -150,37 +154,37 @@ def conv2d_affine(self, input, name, filters, kernel_size, strides):
def bottleneck_transformation(
self, input, block_name, filters, filters_inner, strides
):
a = self.conv2d_affine(input, block_name + "_branch2a", filters_inner, 1, 1)
a = self._batch_norm_relu(a, block_name + "_branch2a")
a = self.conv2d_affine(input, block_name + ".conv1", filters_inner, 1, 1)
a = self._batch_norm_relu(a, block_name + ".bn1")

b = self.conv2d_affine(a, block_name + "_branch2b", filters_inner, 3, strides)
b = self._batch_norm_relu(b, block_name + "_branch2b")
b = self.conv2d_affine(a, block_name + ".conv2", filters_inner, 3, strides)
b = self._batch_norm_relu(b, block_name + ".bn2")

c = self.conv2d_affine(b, block_name + "_branch2c", filters, 1, 1)
c = self.conv2d_affine(b, block_name + ".conv3", filters, 1, 1)
return c

def residual_block(self, input, block_name, filters, filters_inner, strides_init):
if strides_init != 1 or block_name == "res2_0":
if strides_init != 1 or block_name == "layer1.0":
shortcut = self.conv2d_affine(
input, block_name + "_branch1", filters, 1, strides_init
input, block_name + ".downsample.0", filters, 1, strides_init
)
shortcut = self._batch_norm(shortcut, block_name + "_branch1_bn")
shortcut = self._batch_norm(shortcut, block_name + ".downsample.1")
else:
shortcut = input

bottleneck = self.bottleneck_transformation(
input, block_name, filters, filters_inner, strides_init,
)
return self._batch_norm_add_relu(
bottleneck, shortcut, block_name + "_branch2c", last=True
bottleneck, shortcut, block_name + ".bn3", last=True
)

def residual_stage(
self, input, stage_name, counts, filters, filters_inner, stride_init=2
):
output = input
for i in range(counts):
block_name = "%s_%d" % (stage_name, i)
block_name = "%s.%d" % (stage_name, i)
output = self.residual_block(
output, block_name, filters, filters_inner, stride_init if i == 0 else 1
)
Expand All @@ -192,7 +196,7 @@ def resnet_conv_x_body(self, input):
for i, (counts, filters, filters_inner) in enumerate(
zip(BLOCK_COUNTS, BLOCK_FILTERS, BLOCK_FILTERS_INNER)
):
stage_name = "res%d" % (i + 2)
stage_name = "layer%d" % (i + 1)
output = self.residual_stage(
output, stage_name, counts, filters, filters_inner, 1 if i == 0 else 2
)
Expand All @@ -201,7 +205,7 @@ def resnet_conv_x_body(self, input):

def resnet_stem(self, input):
conv1 = self._conv2d("conv1", input, 64, 7, 2)
conv1_bn = self._batch_norm_relu(conv1, "conv1")
conv1_bn = self._batch_norm_relu(conv1, "bn1")
pool1 = flow.nn.max_pool2d(
conv1_bn,
ksize=3,
Expand Down Expand Up @@ -232,28 +236,29 @@ def resnet50(images, args, trainable=True, training=True):
else:
paddings = ((0, 0), (0, 1), (0, 0), (0, 0))
images = flow.pad(images, paddings=paddings)
with flow.scope.namespace("Resnet"):
stem = builder.resnet_stem(images)
body = builder.resnet_conv_x_body(stem)
pool5 = flow.nn.avg_pool2d(
body,
ksize=7,
strides=1,
padding="VALID",
data_format=builder.data_format,
name="pool5",
)
fc1001 = flow.layers.dense(
flow.reshape(pool5, (pool5.shape[0], -1)),
units=1000,
use_bias=True,
kernel_initializer=flow.variance_scaling_initializer(
2, "fan_in", "random_normal"
),
bias_initializer=flow.zeros_initializer(),
kernel_regularizer=weight_regularizer,
bias_regularizer=weight_regularizer,
trainable=trainable,
name="fc1001",
)
# with flow.scope.namespace("resnet50"):
stem = builder.resnet_stem(images)
body = builder.resnet_conv_x_body(stem)
pool5 = flow.nn.avg_pool2d(
body,
ksize=7,
strides=1,
padding="VALID",
data_format=builder.data_format,
name="avgpool",
)
fc1001 = flow.layers.dense(
flow.reshape(pool5, (pool5.shape[0], -1)),
units=1000,
use_bias=True,
kernel_initializer=flow.variance_scaling_initializer(
2, "fan_in", "random_normal"
),
bias_initializer=flow.zeros_initializer(),
kernel_regularizer=weight_regularizer,
bias_regularizer=weight_regularizer,
trainable=trainable,
name="fc",
)
return fc1001

30 changes: 29 additions & 1 deletion Classification/cnns/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def __init__(self, model_save_dir, model_load_dir):
if model_load_dir:
assert os.path.isdir(model_load_dir)
print("Restoring model from {}.".format(model_load_dir))
flow.load_variables(flow.checkpoint.get(model_load_dir))
flow.load_variables(flow.checkpoint.get(model_load_dir), ignore_mismatch=False)
# flow.checkpoint.save('loaded_init_ckpt')
else:
# flow.checkpoint.save("initial_model")
print("Init model on demand.")
Expand Down Expand Up @@ -84,6 +85,15 @@ def match_top_k(predictions, labels, top_k=1):
return num_matched, match_array.shape[0]


def dump_outputs(outputs, step, dump_dir='output'):
for k, v in outputs.items():
root = os.path.join(dump_dir, str(step))
if not os.path.isdir(root):
os.makedirs(root)
path = os.path.join(root, k)
np.save(path, v.numpy())


class Metric(object):
def __init__(
self,
Expand Down Expand Up @@ -142,6 +152,7 @@ def callback(outputs):
self.num_samples += num_samples

if (step + 1) % self.calculate_batches == 0:
dump_outputs(outputs, step)
throughput = self.num_samples / self.timer.split()
if self.prediction_key:
top_1_accuracy = self.top_1_num_matched / self.num_samples
Expand Down Expand Up @@ -180,3 +191,20 @@ def callback(outputs):
self._clear()

return callback


from oneflow.compatible.single_client import typing as tp

def build_watch_cb(name, iter=0, root='output'):
path = os.path.join(root, str(iter), f'{name}.npy')
def cb(blob: tp.Numpy):
np.save(path, blob)
return cb


def build_watch_diff_cb(name, iter=0, root='output'):
path = os.path.join(root, str(iter), f'{name}_grad.npy')
def cb(blob: tp.Numpy):
np.save(path, blob)
return cb