# NOTE(review): the five lines below are web-scrape / document-site navigation
# artifacts, not Python source; commented out so they no longer break parsing.
# Professional Documents
# Culture Documents
# Main Whole
# Main Whole
# 0 is needed
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os,time,cv2,scipy.io
import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np
import utils as utils
import myflowlib as flowlib
import flow_warp as flow_warp_op
import scipy.misc as sic
import subprocess
import network as net
import loss as loss
import argparse
from sklearn.neighbors import NearestNeighbors
# ---------------------------------------------------------------------------
# Command-line configuration.
# FIX(review): several argument strings/calls were split mid-literal by text
# extraction ("Test\ndir path", "Restore\ncheckpoint") — re-joined here.
# NOTE(review): --train_root/--test_root help texts both read "Test dir path"
# in the original; kept verbatim, though presumably copy-paste slips.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("--model", default='Result_whole', type=str, help="Model Name")
parser.add_argument("--div_num", default=4, type=int, help="diverse num")
parser.add_argument("--save_freq", default=1, type=int, help="save frequency")
parser.add_argument("--test_dir", default='./data/test/JPEGImages/480p/cows',
                    type=str, help="Test dir path")
parser.add_argument("--train_root", default="./data/train/JPEGImages/480p/",
                    type=str, help="Test dir path")
parser.add_argument("--test_root", default="./data/test/JPEGImages/480p/",
                    type=str, help="Test dir path")
parser.add_argument("--imgs_dir", default='../data/Imagenet', type=str,
                    help="Test dir path")
parser.add_argument("--is_training", default=1, type=int, help="Training or test")
parser.add_argument("--continue_training", default=1, type=int,
                    help="Restore checkpoint")
ARGS = parser.parse_args()
print(ARGS)

# Unpack parsed options into the module-level names the rest of the script uses.
model = ARGS.model                          # checkpoint / output directory name
div_num = ARGS.div_num                      # number of diverse proposals per frame
save_freq = ARGS.save_freq                  # validate/save every `save_freq` epochs
test_dir = ARGS.test_dir
train_root = [ARGS.train_root]              # wrapped in a list: utils.read_image_path takes a list of roots
test_root = [ARGS.test_root]
is_training = ARGS.is_training              # 1 = train, 0 = inference branch at the bottom of the file
continue_training = ARGS.continue_training
imgs_dir = ARGS.imgs_dir
num_frame = 2  # number of frames read in per training sample

# TensorFlow 1.x session; allow_growth avoids grabbing all GPU memory upfront.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Lists of training / test image paths (helper from utils.py — not in this chunk).
train_low = utils.read_image_path(train_root)
test_low = utils.read_image_path(test_root)

# Graph inputs.  Spatial dims are dynamic (None); channel counts scale with
# num_frame.  FIX(review): the shape expressions below had been split
# mid-expression ("2*(num_frame-\n1)") by text extraction — re-joined.
input_idx = tf.placeholder(tf.int32, shape=[None, 5 * num_frame])                      # 5 KNN neighbour indices per frame
input_i = tf.placeholder(tf.float32, shape=[None, None, None, 1 * num_frame])          # grayscale input frames
input_target = tf.placeholder(tf.float32, shape=[None, None, None, 3 * num_frame])     # RGB targets
input_flow_forward = tf.placeholder(tf.float32, shape=[None, None, None, 2 * (num_frame - 1)])
input_flow_backward = tf.placeholder(tf.float32, shape=[None, None, None, 2 * (num_frame - 1)])
gray_flow_forward = tf.placeholder(tf.float32, shape=[None, None, None, 2 * (num_frame - 1)])
gray_flow_backward = tf.placeholder(tf.float32, shape=[None, None, None, 2 * (num_frame - 1)])
c0 = tf.placeholder(tf.float32, shape=[None, None, None, 3])  # RefineNet input, frame 0
c1 = tf.placeholder(tf.float32, shape=[None, None, None, 3])  # RefineNet input, frame 1
# ------------------------------- Loss graph --------------------------------
# NOTE(review): Y0/Y1, C0/C1 (presumably the network's colorization outputs,
# 12 channels = div_num * 3 per frame) and occlusion_mask are NOT defined in
# this file chunk — confirm against the full source (network.py / missing
# section) before relying on this.
lossDict = {}
objDict = {}

# Occlusion mask and backward-warped frame-0 output, for the temporal term.
objDict["mask"], _ = occlusion_mask(Y0, Y1, input_flow_backward[:, :, :, 0:2])
objDict["warped"] = flow_warp_op.flow_warp(C0, input_flow_backward[:, :, :, 0:2])

# Rank-diversity loss: each frame's div_num proposals vs. the tiled RGB target.
lossDict["RankDiv_im1"] = loss.RankDiverse_loss(
    C0, tf.tile(input_target[:, :, :, 0:3], [1, 1, 1, div_num]), div_num)
lossDict["RankDiv_im2"] = loss.RankDiverse_loss(
    C1, tf.tile(input_target[:, :, :, 3:6], [1, 1, 1, div_num]), div_num)
lossDict["RankDiv"] = lossDict["RankDiv_im1"] + lossDict["RankDiv_im2"]

# KNN ("bilateral") loss summed over the 4 proposals of each frame.
lossDict['Bilateral_im1'] = sum(
    [loss.KNN_loss(C0[:, :, :, 3 * i:3 * i + 3], input_idx[:, 0:5]) for i in range(4)])
lossDict['Bilateral_im2'] = sum(
    [loss.KNN_loss(C1[:, :, :, 3 * i:3 * i + 3], input_idx[:, 5:10]) for i in range(4)])
lossDict['Bilateral'] = lossDict['Bilateral_im2'] + lossDict['Bilateral_im1']

# Temporal consistency: masked L1 between warped frame-0 and frame-1, weight 5.
lossDict["temporal"] = tf.reduce_mean(
    tf.multiply(tf.abs(objDict["warped"] - C1),
                tf.tile(objDict["mask"], [1, 1, 1, 4]))) * 5

# FIX(review): in the extracted text "+lossDict['temporal']" sat on its own
# line, i.e. a no-op statement that silently dropped the temporal term from
# the total; re-joined (Bilateral term deliberately left commented out).
lossDict["total"] = lossDict["RankDiv"] + lossDict["temporal"]  # +lossDict['Bilateral']

# Horizontal panels of the 4 proposals (order 0,3,1,2) for visualization dumps.
objDict["prediction_0"] = tf.concat(
    [C0[:, :, :, 0:3], C0[:, :, :, 9:12], C0[:, :, :, 3:6], C0[:, :, :, 6:9]], axis=2)
objDict["prediction_1"] = tf.concat(
    [C1[:, :, :, 0:3], C1[:, :, :, 9:12], C1[:, :, :, 3:6], C1[:, :, :, 6:9]], axis=2)

#-------------RefineNet---------------#
# Confidence maps from occlusion checks on color (c0/c1) vs. gray (input_i)
# frames; a pixel is "low confidence" where the gray-consistency map exceeds
# the color-consistency map.
cmap_C, warp_C0 = occlusion_mask(c0, c1, gray_flow_backward[:, :, :, 0:2])
cmap_X, warp_X0 = occlusion_mask(
    tf.tile(input_i[:, :, :, 0:1], [1, 1, 1, 3]),
    tf.tile(input_i[:, :, :, 1:2], [1, 1, 1, 3]),
    gray_flow_backward[:, :, :, 0:2])
low_conf_mask = tf.cast(tf.greater(cmap_X - cmap_C, 0), tf.float32)
# Optimizers.  FIX(review): identifiers below were split across lines by text
# extraction ("var_lis/t", "RankDiv_im/1", "va/r_list") — re-joined.
# opt        : all trainable vars on the total loss.
# opt2       : all trainable vars on a 0.2-weighted frame-1 rank loss.
# opt_refine : only RefineNet vars (name prefix 'VCRN') on temporal_g_loss.
# NOTE(review): `temporal_g_loss` is not defined in this chunk — presumably
# built by the missing network/GAN section; verify against the full source.
opt = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(
    lossDict["total"],
    var_list=[var for var in tf.trainable_variables()])
opt2 = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(
    0.2 * lossDict["RankDiv_im1"],
    var_list=[var for var in tf.trainable_variables()])
opt_refine = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(
    temporal_g_loss,
    var_list=[var for var in tf.trainable_variables() if var.name.startswith('VCRN')])

saver = tf.train.Saver(max_to_keep=1000)
sess.run([tf.global_variables_initializer()])

# 5-nearest-neighbour searcher, used to build `input_idx` for the KNN loss.
neigh = NearestNeighbors(n_neighbors=5)
# --------------------------- Training loop ---------------------------------
# NOTE(review): this section is badly mangled by text extraction: loop-body
# indentation is lost, and several pieces appear to be MISSING from this chunk:
#   * the enclosing `for epoch in range(maxepoch):` loop (`epoch` is used but
#     never bound here),
#   * the sess.run([opt, lossDict, ...]) call that produces `out_loss`,
#     C0_im and C1_im,
#   * initialization of the input_list_src/target/flow caches and of
#     idxs1/idxs2 (KNN neighbour indices).
# Kept byte-identical; reconstruct from the original repository before running.
maxepoch=1001
num_train=len(train_low)
print("Number of training images: ", num_train)
# Video VCN
# Running per-epoch accumulators for the console report below.
cnt=0
all_D1, all_D2, all_B1, all_B2, all_T, all_loss = 0,0,0,0,0,0
# Visit training clips in random order; `id` shadows the builtin (original style kept).
for id in np.random.permutation(num_train):
st=time.time()
# Lazily load and cache frames + optical flow for this clip on first visit.
if input_list_src[id] is None:
input_list_src[id], input_list_target[id],
input_list_flow_forward[id], input_list_flow_backward[id] =
prepare_input_w_flow(train_low[id], num_frames=num_frame)
# Skip clips whose images/flows could not be read.
if input_list_src[id] is None or input_list_target[id] is None or
input_list_flow_forward[id] is None:
continue
input_frames_processed = input_list_src[id]
# NOTE(review): orphaned fragment — the sess.run(...) this feed_dict belonged
# to is missing; the stray ")" on the input_idx line closed that call.
feed_dict={input_i:input_list_src[id],input_target:input_list_target[id],\
input_flow_backward:input_list_flow_backward[id],\
input_idx: np.concatenate([idxs1,idxs2],axis=1)})
# Accumulate per-term losses for the running averages printed below.
all_D1 += out_loss["RankDiv_im1"]
all_D2 += out_loss["RankDiv_im2"]
all_B1 += out_loss["Bilateral_im1"]
all_B2 += out_loss["Bilateral_im2"]
all_T += out_loss["temporal"]
all_loss += out_loss["total"]
cnt+=1
# Console report: current vs. running-average value of every loss term.
# NOTE(review): the format string was split across lines by extraction.
print("iter: %d %d %.2fs loss: %.4f %.4f|| (D1) %.4f %.4f (D2) %.4f
%.4f || (B1) %.4f %.4f (B2) %.4f %.4f (T) %.4f %.4f"\
%(epoch,cnt,out_loss["total"],all_loss/cnt, time.time()-st,\
out_loss["RankDiv_im1"], all_D1/cnt, out_loss["RankDiv_im2"],
all_D2/cnt,\
out_loss["Bilateral_im1"], all_B1/cnt,
out_loss["Bilateral_im2"], all_B2/cnt,\
out_loss["temporal"], all_T/cnt))
# Video Refine
# RefineNet pass (enabled after the first epoch): recompute gray flows, then
# one opt_refine step on temporal_g_loss using the VCN outputs C0_im/C1_im.
if epoch > 0:
_, _, gray_list_flow_forward[id], gray_list_flow_backward[id] =
prepare_input_w_flow(train_low[id], num_frames=num_frame)
_, out_loss, final_C1 = sess.run([opt_refine, temporal_g_loss,
final_r1],\
feed_dict={c0:C0_im[:,:,:,0:3], c1: C1_im[:,:,:,0:3],
input_i:input_list_src[id],input_target:input_list_target[id],\
input_flow_backward:input_list_flow_backward[id],\
gray_flow_backward:gray_list_flow_backward[id],\
input_idx: np.concatenate([idxs1,idxs2],axis=1)})
print("iter: %d %d || Refine || loss: %.4f %.4f"%
(epoch,cnt,out_loss,time.time()-st))
# Cap each epoch at 1000 clips.
if cnt>=1000:
break
# Validation
# NOTE(review): still inside the (missing) epoch loop — `epoch`, C0/C1,
# final_r1 come from code not present in this chunk; indentation is lost.
# Every `save_freq` epochs, run the VCN + RefineNet on the test set; on
# epochs not divisible by 25 only the first ~30 clips are evaluated.
if not os.path.isdir("%s/%04d"%(model,epoch)):
os.makedirs("%s/%04d"%(model,epoch))
if epoch % save_freq == 0:
numtest=len(test_low)
all_loss_test=np.zeros(numtest, dtype=float)
for ind in range(numtest):
# Quick validation: stop after 30 clips except on multiples of 25.
if ind>30 and epoch%25>0:
break
# int(ind*60/pow(60,...)) == ind on 25-multiple epochs, else a strided subset.
input_image_src, input_image_target, input_flow_forward_src,
input_flow_backward_src = prepare_input_w_flow(test_low[int(ind*60/pow(60,int(epoch
%25==0)))],num_frames=num_frame)
if input_image_src is None or input_image_target is None or
input_flow_forward_src is None:
print("Not able to read the images/flows.")
flag=True
continue
st=time.time()
# Forward pass of the VCN: proposal panels, raw outputs, warped frame, mask.
C0_imall,C1_imall,C0_im, C1_im,
warped,mask=sess.run([objDict["prediction_0"],objDict["prediction_1"],C0,
C1,objDict["warped"],objDict['mask']],feed_dict={input_i:input_image_src,
input_target:input_image_target,
input_flow_backward:input_flow_backward_src
})
print("test time for %s --> %.3f"%(ind, time.time()-st))
# Reload the same clip to get gray-image flows for the RefineNet pass.
input_image_src, input_image_target, gray_flow_forward_src,
gray_flow_backward_src = prepare_input_w_flow(test_low[int(ind*60/pow(60,int(epoch
%25==0)))],num_frames=num_frame)
h,w = C0_im.shape[1:3]
outputs= []
# Refine each of the 4 colorization proposals independently.
for ref_i in range(4):
output, out_cmap_C, out_cmap_X, out_low_conf_mask =
sess.run([final_r1,
cmap_C,cmap_X,low_conf_mask],feed_dict={c0:C0_im[:,:,:,ref_i*3:ref_i*3+3],
c1:C1_im[:,:,:,ref_i*3:ref_i*3+3], \
input_i:input_image_src,
input_target:input_image_target, \
gray_flow_backward:gray_flow_backward_src,
input_flow_backward:input_flow_backward_src})
outputs.append(output[0,:,:,:])
# Debug
# Checkpoint every validated epoch; keep a numbered copy every 10 epochs.
saver.save(sess,"%s/model.ckpt"%model)
if epoch%10==0:
saver.save(sess,"%s/%04d/model.ckpt"%(model,epoch))
# Inference
# NOTE(review): `else:` of an `if` whose header is missing from this chunk —
# presumably `if is_training:` guarding the training loop above; indentation
# is lost throughout.  C0/C1, final_r1, temporal_g_loss are defined elsewhere.
else:
test_low=utils.get_names(test_dir)
numtest=len(test_low)
print(test_low[0])
out_folder = test_dir.split('/')[-1]
# One running refined frame per proposal; seeded on the first clip (ind == 0).
outputs= [None]*4
for ind in range(numtest):
input_image_src, input_image_target, input_flow_forward_src,
input_flow_backward_src =
prepare_input_w_flow(test_low[ind],num_frames=num_frame,gray=True)
if input_image_src is None or input_image_target is None or
input_flow_forward_src is None:
print("Not able to read the images/flows.")
continue
st=time.time()
# VCN forward pass: proposal panels plus raw per-proposal outputs.
C0_imall,C1_imall,C0_im,
C1_im=sess.run([objDict["prediction_0"],objDict["prediction_1"],C0,
C1],feed_dict={input_i:input_image_src,
input_target:input_image_target,
input_flow_backward:input_flow_backward_src
})
print("test time for %s --> %.3f"%(ind, time.time()-st))
h,w = C0_im.shape[1:3]
print(C0_im.shape)
# Create the output folder tree once per test directory.
if not os.path.isdir("%s/%s" % (model, out_folder)):
os.makedirs("%s/%s/predictions" % (model, out_folder))
os.makedirs("%s/%s/predictions0" % (model, out_folder))
os.makedirs("%s/%s/predictions1" % (model, out_folder))
os.makedirs("%s/%s/predictions2" % (model, out_folder))
os.makedirs("%s/%s/predictions3" % (model, out_folder))
# First clip: refine from the VCN outputs and save both frames.
if ind == 0:
for ref_i in range(4):
# NOTE(review): gray_flow_backward is fed input_flow_backward_src here —
# presumably intended (inputs loaded with gray=True); confirm.
output,_ = sess.run([final_r1,
temporal_g_loss],feed_dict={c0:C0_im[:,:,:,ref_i*3:ref_i*3+3],
c1:C1_im[:,:,:,ref_i*3:ref_i*3+3], \
input_i:input_image_src, input_target:input_image_target, \
gray_flow_backward:input_flow_backward_src,
input_flow_backward:input_flow_backward_src})
outputs[ref_i] = output
sic.imsave("%s/%s/predictions%d/final_%06d.jpg"%(model, out_folder,
ref_i, ind),np.uint8(np.maximum(np.minimum(C0_im[0,:,:,ref_i*3:ref_i*3+3] *
255.0,255.0),0.0)))
sic.imsave("%s/%s/predictions%d/final_%06d.jpg"%(model, out_folder,
ref_i, ind+1),np.uint8(np.maximum(np.minimum(output[0,:,:,:] * 255.0,255.0),0.0)))
sic.imsave("%s/%s/predictions/predictions_%06d.jpg"%(model, out_folder,
ind+1),np.uint8(np.maximum(np.minimum(C1_imall[0,:,:,:] * 255.0,255.0),0.0)))
sic.imsave("%s/%s/predictions/final_%06d.jpg"%(model, out_folder,
ind+1),np.uint8(np.maximum(np.minimum(np.concatenate(outputs,axis=2)[0,:,:,:] *
255.0,255.0),0.0)))
# Later clips: propagate — refine from the previous refined frame (outputs[ref_i]).
else:
for ref_i in range(4):
output,_ = sess.run([final_r1,
temporal_g_loss],feed_dict={c0:outputs[ref_i], c1:C1_im[:,:,:,:3], \
input_i:input_image_src, input_target:input_image_target, \
gray_flow_backward:input_flow_backward_src,
input_flow_backward:input_flow_backward_src})
# NOTE(review): appends to the 4-slot `outputs` list instead of replacing
# outputs[ref_i] — looks like a bug in the original (list grows each clip).
outputs.append(output[0,:,:,:])
sic.imsave("%s/%s/predictions%d/final_%06d.jpg"%(model, out_folder,
ref_i, ind+1),np.uint8(np.maximum(np.minimum(output[0,:,:,:] * 255.0,255.0),0.0)))
sic.imsave("%s/%s/predictions/predictions_%06d.jpg"%(model, out_folder,
ind+1),np.uint8(np.maximum(np.minimum(C1_imall[0,:,:,:] * 255.0,255.0),0.0)))