run.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. import cv2
  2. #Import Neural Network Model
  3. from gan import DataLoader, DeepModel, tensor2im
  4. #OpenCv Transform:
  5. from opencv_transform.mask_to_maskref import create_maskref
  6. from opencv_transform.maskdet_to_maskfin import create_maskfin
  7. from opencv_transform.dress_to_correct import create_correct
  8. from opencv_transform.nude_to_watermark import create_watermark
  9. """
  10. run.py
  11. This script manage the entire transormation.
  12. Transformation happens in 6 phases:
  13. 0: dress -> correct [opencv] dress_to_correct
  14. 1: correct -> mask: [GAN] correct_to_mask
  15. 2: mask -> maskref [opencv] mask_to_maskref
  16. 3: maskref -> maskdet [GAN] maskref_to_maskdet
  17. 4: maskdet -> maskfin [opencv] maskdet_to_maskfin
  18. 5: maskfin -> nude [GAN] maskfin_to_nude
  19. 6: nude -> watermark [opencv] nude_to_watermark
  20. """
  21. phases = ["dress_to_correct", "correct_to_mask", "mask_to_maskref", "maskref_to_maskdet", "maskdet_to_maskfin", "maskfin_to_nude", "nude_to_watermark"]
  22. class Options():
  23. #Init options with default values
  24. def __init__(self):
  25. # experiment specifics
  26. self.norm = 'batch' #instance normalization or batch normalization
  27. self.use_dropout = False #use dropout for the generator
  28. self.data_type = 32 #Supported data type i.e. 8, 16, 32 bit
  29. # input/output sizes
  30. self.batchSize = 1 #input batch size
  31. self.input_nc = 3 # of input image channels
  32. self.output_nc = 3 # of output image channels
  33. # for setting inputs
  34. self.serial_batches = True #if true, takes images in order to make batches, otherwise takes them randomly
  35. self.nThreads = 1 ## threads for loading data (???)
  36. self.max_dataset_size = 1 #Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.
  37. # for generator
  38. self.netG = 'global' #selects model to use for netG
  39. self.ngf = 64 ## of gen filters in first conv layer
  40. self.n_downsample_global = 4 #number of downsampling layers in netG
  41. self.n_blocks_global = 9 #number of residual blocks in the global generator network
  42. self.n_blocks_local = 0 #number of residual blocks in the local enhancer network
  43. self.n_local_enhancers = 0 #number of local enhancers to use
  44. self.niter_fix_global = 0 #number of epochs that we only train the outmost local enhancer
  45. #Phase specific options
  46. self.checkpoints_dir = ""
  47. self.dataroot = ""
  48. #Changes options accordlying to actual phase
  49. def updateOptions(self, phase):
  50. if phase == "correct_to_mask":
  51. self.checkpoints_dir = "checkpoints/cm.lib"
  52. elif phase == "maskref_to_maskdet":
  53. self.checkpoints_dir = "checkpoints/mm.lib"
  54. elif phase == "maskfin_to_nude":
  55. self.checkpoints_dir = "checkpoints/mn.lib"
  56. # process(cv_img, mode)
  57. # return:
  58. # watermark image
  59. def process(cv_img):
  60. #InMemory cv2 images:
  61. dress = cv_img
  62. correct = None
  63. mask = None
  64. maskref = None
  65. maskfin = None
  66. maskdet = None
  67. nude = None
  68. watermark = None
  69. for index, phase in enumerate(phases):
  70. print("Executing phase: " + phase)
  71. #GAN phases:
  72. if (phase == "correct_to_mask") or (phase == "maskref_to_maskdet") or (phase == "maskfin_to_nude"):
  73. #Load global option
  74. opt = Options()
  75. #Load custom phase options:
  76. opt.updateOptions(phase)
  77. #Load Data
  78. if (phase == "correct_to_mask"):
  79. data_loader = DataLoader(opt, correct)
  80. elif (phase == "maskref_to_maskdet"):
  81. data_loader = DataLoader(opt, maskref)
  82. elif (phase == "maskfin_to_nude"):
  83. data_loader = DataLoader(opt, maskfin)
  84. dataset = data_loader.load_data()
  85. #Create Model
  86. model = DeepModel()
  87. model.initialize(opt)
  88. #Run for every image:
  89. for i, data in enumerate(dataset):
  90. generated = model.inference(data['label'], data['inst'])
  91. im = tensor2im(generated.data[0])
  92. #Save Data
  93. if (phase == "correct_to_mask"):
  94. mask = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
  95. elif (phase == "maskref_to_maskdet"):
  96. maskdet = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
  97. elif (phase == "maskfin_to_nude"):
  98. nude = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
  99. #Correcting:
  100. elif (phase == 'dress_to_correct'):
  101. correct = create_correct(dress)
  102. #mask_ref phase (opencv)
  103. elif (phase == "mask_to_maskref"):
  104. maskref = create_maskref(mask, correct)
  105. #mask_fin phase (opencv)
  106. elif (phase == "maskdet_to_maskfin"):
  107. maskfin = create_maskfin(maskref, maskdet)
  108. #nude_to_watermark phase (opencv)
  109. elif (phase == "nude_to_watermark"):
  110. watermark = create_watermark(nude)
  111. return watermark