  • Language: Python
  • License: MIT License


👩🏼‍🦰😺[ECCV'22] Official PyTorch Implementation of Sem2NeRF: Converting Single-View Semantic Masks to NeRFs

Sem2NeRF

Official PyTorch implementation of [ECCV 2022] Sem2NeRF: Converting Single-View Semantic Masks to Neural Radiance Fields by Yuedong Chen, Qianyi Wu, Chuanxia Zheng, Tat-Jen Cham and Jianfei Cai.

Abstract

Image translation and manipulation have gained increasing attention along with the rapid development of deep generative models. Although existing approaches have brought impressive results, they mainly operate in 2D space. In light of recent advances in NeRF-based 3D-aware generative models, we introduce a new task, Semantic-to-NeRF translation, which aims to reconstruct a 3D scene modelled by NeRF, conditioned on a single-view semantic mask as input. To kick off this novel task, we propose the Sem2NeRF framework. In particular, Sem2NeRF addresses this highly challenging task by encoding the semantic mask into the latent code that controls the 3D scene representation of a pretrained decoder. To further improve the accuracy of the mapping, we integrate a new region-aware learning strategy into the design of both the encoder and the decoder. We verify the efficacy of the proposed Sem2NeRF and demonstrate that it outperforms several strong baselines on two benchmark datasets.

Recent Updates
  • 14-Jul-2022: released the CatMask dataset with related training scripts.

  • 11-Jul-2022: released Sem2NeRF code and models for the CelebAMask-HQ and CatMask datasets.

  • 22-Mar-2022: initialized the Sem2NeRF repository with a demo and the arXiv manuscript.

Getting Started

Installation

We recommend using Anaconda to create the running environment for the project. All related dependencies are listed in environment/sem2nerf.yml; to set it up, run

git clone https://github.com/donydchen/sem2nerf.git
cd sem2nerf
conda env create -f environment/sem2nerf.yml
conda activate sem2nerf

Note: The above environment contains PyTorch 1.7 with CUDA 11. If it does not work on your machine, please refer to environment/README.md for manual installation and troubleshooting.
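Before running anything else, it can help to confirm that the activated environment actually provides the PyTorch/CUDA combination mentioned above. The snippet below is a minimal sketch that is not part of the repository; `describe_environment` is a hypothetical helper name, and it only imports `torch` if the package is installed.

```python
import importlib.util
import platform


def describe_environment():
    """Report the Python version and, if importable, the installed PyTorch build."""
    info = {"python": platform.python_version(), "torch": None,
            "cuda_build": None, "cuda_available": False}
    # find_spec returns None when torch is not installed, so no ImportError is raised.
    if importlib.util.find_spec("torch") is not None:
        import torch

        info["torch"] = torch.__version__          # e.g. a 1.7.x build for this project
        info["cuda_build"] = torch.version.cuda    # CUDA version torch was compiled with
        info["cuda_available"] = torch.cuda.is_available()
    return info


if __name__ == "__main__":
    for key, value in describe_environment().items():
        print(f"{key}: {value}")
```

If `cuda_available` is False or `cuda_build` does not start with 11, the manual installation notes in environment/README.md are the place to look.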

Download Pretrained Weights

Download the pretrained models from here, and save them to pretrained_models/. Details of files are provided in pretrained_models/README.md.
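A quick way to confirm the downloads landed in the right place is to check for the checkpoint files that the inference commands in this README point to. This is a hypothetical helper, not shipped with the repository. The two file names are taken from the --checkpoint_path values in the inference section, and pretrained_models/README.md remains the authoritative list.

```python
from pathlib import Path

# Checkpoint names copied from the --checkpoint_path flags in this README;
# pretrained_models/README.md may list additional files.
EXPECTED_CHECKPOINTS = (
    "sem2nerf_celebahq_pretrained.pt",
    "sem2nerf_catmask_pretrained.pt",
)


def missing_checkpoints(root="pretrained_models", names=EXPECTED_CHECKPOINTS):
    """Return the expected checkpoint files that are not present under root."""
    root = Path(root)
    return [name for name in names if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    print("all checkpoints found" if not missing else f"missing: {missing}")
```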

Quick Test

We have provided some input semantic masks for a quick test, kindly run

python scripts/inference3d.py --use_merged_labels --infer_paths_conf=data/CelebAMask-HQ/val_paths.txt 

If the environment is set up correctly, this command should run without errors and write some results to the folder out/sem2nerf_qtest. For more details regarding datasets, training and further tuning options for inference, see the following sections.


Datasets

CelebAMask-HQ

  • Download the CelebAMask-HQ dataset and extract it to data/CelebAMask-HQ/. The folder should have the following structure:
data/CelebAMask-HQ/
        |__ CelebA-HQ-img/
        |__ CelebAMask-HQ-mask-anno/
        |__ CelebAMask-HQ-pose-anno.txt
        |__ mask_samples/
        |__ test_paths.txt
        |__ train_paths.txt
        |__ val_paths.txt
  • Preprocess the semantic mask data by running
python scripts/build_celeba_mask.py

This script will save the combined mask labels to data/CelebAMask-HQ/masks for training the networks.
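You can check the expected layout above with a small stdlib script before running preprocessing or training. This is a minimal sketch, not part of the repository. `check_layout` is a hypothetical name, and the script simply mirrors the tree listed above, plus the masks folder written by scripts/build_celeba_mask.py.

```python
from pathlib import Path

# Entries from the expected CelebAMask-HQ layout above; the "masks" folder
# only appears after running scripts/build_celeba_mask.py.
EXPECTED_DIRS = ("CelebA-HQ-img", "CelebAMask-HQ-mask-anno", "mask_samples")
EXPECTED_FILES = (
    "CelebAMask-HQ-pose-anno.txt",
    "test_paths.txt",
    "train_paths.txt",
    "val_paths.txt",
)


def check_layout(root="data/CelebAMask-HQ", preprocessed=False):
    """Return a sorted list of expected entries that are missing under root."""
    root = Path(root)
    dirs = list(EXPECTED_DIRS) + (["masks"] if preprocessed else [])
    missing = [d for d in dirs if not (root / d).is_dir()]
    missing += [f for f in EXPECTED_FILES if not (root / f).is_file()]
    return sorted(missing)


if __name__ == "__main__":
    problems = check_layout(preprocessed=True)
    print("layout OK" if not problems else f"missing: {problems}")
```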

CatMask

Inference

CelebAMask-HQ

Render high-quality images and videos:

python scripts/inference3d.py \
--exp_dir=out/sem2nerf_celebahq_test \
--checkpoint_path=pretrained_models/sem2nerf_celebahq_pretrained.pt \
--data_path=data/CelebAMask-HQ/mask_samples \
--test_output_size=512 \
--pigan_infer_ray_step=72 \
--use_merged_labels \
--use_original_pose \
--latent_mask=8 \
--inject_code_seed=92 \
--render_videos

Use --render_videos to render videos along a predefined camera trajectory. Change inject_code_seed and latent_mask to generate multi-modal results, e.g., --latent_mask=6,7,8 --inject_code_seed=711. More options and their descriptions can be found by running python scripts/inference3d.py --help.
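This README does not define exactly what --latent_mask and --inject_code_seed do. Assuming they follow the style-mixing convention of pixel2style2pixel, which this project cites as its main inspiration, the latent mask lists which per-layer codes get replaced by codes derived from the seed. The toy sketch below illustrates only that idea. It draws the injected codes from a seeded Gaussian instead of a real mapping network, and `parse_latent_mask` and `mix_codes` are hypothetical names.

```python
import random


def parse_latent_mask(spec):
    """Parse a comma-separated index list such as "6,7,8" into [6, 7, 8]."""
    return [int(token) for token in spec.split(",") if token.strip()]


def mix_codes(codes, latent_mask, seed):
    """Swap the per-layer codes at the masked indices for seeded random codes.

    `codes` is a list of per-layer latent vectors (lists of floats). Codes whose
    index is in `latent_mask` are replaced by Gaussian samples from an RNG seeded
    with `seed`, so the same seed always injects the same codes; all other
    layers are copied unchanged.
    """
    rng = random.Random(seed)
    injected = [[rng.gauss(0.0, 1.0) for _ in code] for code in codes]
    mask = set(latent_mask)
    return [injected[i] if i in mask else list(code) for i, code in enumerate(codes)]


if __name__ == "__main__":
    base = [[0.0] * 4 for _ in range(10)]  # ten toy "layers" of 4-d codes
    mixed = mix_codes(base, parse_latent_mask("6,7,8"), seed=711)
    print([i for i, (a, b) in enumerate(zip(base, mixed)) if a != b])
```

Under this assumption, rerunning with the same --inject_code_seed should reproduce the same variation, while a different seed or mask changes only the selected layers.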

CatMask

python scripts/inference3d.py \
--exp_dir=out/sem2nerf_catmask_test \
--dataset_type=catmask_seg_to_3dface \
--pigan_curriculum_type=CatMask \
--checkpoint_path=pretrained_models/sem2nerf_catmask_pretrained.pt \
--data_path=data/CatMask/mask_samples \
--test_output_size=512 \
--pigan_infer_ray_step=72 \
--use_merged_labels \
--use_original_pose \
--latent_mask=7,8 \
--inject_code_seed=390234 \
--render_videos

Training

CelebAMask-HQ

We use eight 32 GB V100 GPUs to train and fine-tune the whole framework for better visual quality. Run the following command to start training:

python -m torch.distributed.launch --nproc_per_node=8 \
scripts/train3d.py \
--exp_dir=out/sem2nerf_celebahq \
--workers=2 \
--batch_size=2 \
--test_output_size=128 \
--train_paths_conf=data/CelebAMask-HQ/train_paths.txt \
--test_paths_conf=data/CelebAMask-HQ/val_paths.txt \
--pigan_steps_conf=configs/pigan_steps/sem2nerf.yaml \
--val_latent_mask=8 \
--train_rand_pose_prob=0.2 \
--use_contour \
--use_merged_labels \
--patch_train \
--start_from_latent_avg
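For reference when you scale the command up or down: with torch.distributed.launch, each spawned process typically loads its own --batch_size samples. The number of samples per optimisation step is therefore the product computed below. This assumes --batch_size is per process. That is the usual convention, but this README does not state it explicitly.

```python
def effective_batch_size(nproc_per_node, batch_size, nnodes=1):
    """Samples consumed per optimisation step, assuming batch_size is per process."""
    return nnodes * nproc_per_node * batch_size


if __name__ == "__main__":
    print(effective_batch_size(8, 2))  # 16 with the 8-GPU command above
    print(effective_batch_size(1, 1))  # 1 with a single GPU and --batch_size=1
```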

If you have only limited GPU resources, e.g., a single GPU, and still want to try the training process, we recommend setting --nproc_per_node=1 --batch_size=1 --dis_lambda=0. If it still does not work, you may consider reducing the decoder patch size by setting resolution_vol: 64 in configs/pigan_steps/sem2nerf.yaml. Note that this may harm the performance.
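You can script the resolution_vol change instead of editing it by hand. The sketch below assumes that configs/pigan_steps/sem2nerf.yaml stores the setting as plain `resolution_vol: <int>` lines, possibly one per curriculum step. This README does not show the file's exact layout, so inspect the result and keep a backup. `set_resolution_vol` and `patch_config` are hypothetical helpers.

```python
import re
from pathlib import Path

# Matches whole lines such as "  resolution_vol: 128"; [ \t] (not \s) keeps the
# pattern from swallowing newlines, so surrounding lines are left untouched.
_RES_VOL = re.compile(r"^([ \t]*resolution_vol:[ \t]*)\d+[ \t]*$", re.MULTILINE)


def set_resolution_vol(text, value=64):
    """Return (new_text, n_replaced) with every resolution_vol entry set to value."""
    return _RES_VOL.subn(lambda m: f"{m.group(1)}{value}", text)


def patch_config(path="configs/pigan_steps/sem2nerf.yaml", value=64):
    """Rewrite the config file in place and return how many entries changed."""
    path = Path(path)
    new_text, count = set_resolution_vol(path.read_text(), value)
    path.write_text(new_text)
    return count
```

A count of 0 means the file does not use the assumed line format, and it should be edited manually.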

Our framework also supports running without the torch.distributed.launch module for easier debugging; simply start the program with something like python scripts/train3d.py --exp_dir=out/sem2nerf_celebahq .... It also supports training on multiple nodes with multiple GPUs; dive into options/train_options.py or drop us a message if you need further instructions in this regard.
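A single script can serve both launch modes because torch.distributed.launch exports environment variables such as WORLD_SIZE and RANK to every process it spawns. A plain `python scripts/train3d.py ...` run has none of them. The sketch below shows one common way to detect this. It is an illustration only, and the actual logic in scripts/train3d.py and options/train_options.py may differ.

```python
import os


def distributed_context(env=None):
    """Return (is_distributed, rank, local_rank, world_size) from launcher env vars.

    The presence of WORLD_SIZE is treated as "started by a distributed launcher";
    missing variables fall back to single-process defaults (rank 0, world size 1).
    """
    env = os.environ if env is None else env
    is_distributed = "WORLD_SIZE" in env
    world_size = int(env.get("WORLD_SIZE", "1"))
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    return is_distributed, rank, local_rank, world_size


if __name__ == "__main__":
    print(distributed_context())  # (False, 0, 0, 1) for a plain python run
```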

CatMask

Configurations are generally similar to those for CelebAMask-HQ, but some options need to be changed accordingly, e.g., dataset_type, pigan_curriculum_type, train_paths_conf, test_paths_conf, label_nc and input_nc. An example is given below:

python -m torch.distributed.launch --nproc_per_node=8 \
scripts/train3d.py \
--exp_dir=out/sem2nerf_catmask \
--dataset_type=catmask_seg_to_3dface \
--pigan_curriculum_type=CatMask \
--train_paths_conf=data/CatMask/train_paths.txt \
--test_paths_conf=data/CatMask/val_paths.txt \
--label_nc=8 \
--input_nc=10 \
--workers=2 \
--batch_size=2 \
--dis_lambda=0.1 \
--w_norm_lambda=0.008 \
--val_latent_mask=8 \
--train_rand_pose_prob=0.5 \
--use_contour \
--use_merged_labels \
--patch_train \
--ray_min_scale=0.08 \
--start_from_latent_avg
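To see exactly which options differ between the two training commands, you can compare the flags programmatically. The snippet below is a small sketch, not part of the repository. It copies the train3d.py arguments from the CelebAMask-HQ and CatMask commands in this README, parses them into dictionaries, and prints every key whose value differs. Beyond the options named above, the comparison also shows flags that appear in only one command. For example, the CatMask command omits --test_output_size and --pigan_steps_conf, presumably relying on the script's defaults.

```python
CELEBAHQ_ARGS = """
--exp_dir=out/sem2nerf_celebahq --workers=2 --batch_size=2 --test_output_size=128
--train_paths_conf=data/CelebAMask-HQ/train_paths.txt
--test_paths_conf=data/CelebAMask-HQ/val_paths.txt
--pigan_steps_conf=configs/pigan_steps/sem2nerf.yaml
--val_latent_mask=8 --train_rand_pose_prob=0.2 --use_contour --use_merged_labels
--patch_train --start_from_latent_avg
""".split()

CATMASK_ARGS = """
--exp_dir=out/sem2nerf_catmask --dataset_type=catmask_seg_to_3dface
--pigan_curriculum_type=CatMask --train_paths_conf=data/CatMask/train_paths.txt
--test_paths_conf=data/CatMask/val_paths.txt --label_nc=8 --input_nc=10
--workers=2 --batch_size=2 --dis_lambda=0.1 --w_norm_lambda=0.008
--val_latent_mask=8 --train_rand_pose_prob=0.5 --use_contour --use_merged_labels
--patch_train --ray_min_scale=0.08 --start_from_latent_avg
""".split()


def parse_flags(argv):
    """Turn ["--key=value", "--switch"] into {"key": "value", "switch": True}."""
    flags = {}
    for arg in argv:
        if not arg.startswith("--"):
            continue  # ignore positional tokens such as a script path
        key, sep, value = arg[2:].partition("=")
        flags[key] = value if sep else True
    return flags


def diff_flags(base, other):
    """Map each differing key to (base_value, other_value); None means absent."""
    keys = sorted(set(base) | set(other))
    return {k: (base.get(k), other.get(k)) for k in keys if base.get(k) != other.get(k)}


if __name__ == "__main__":
    changes = diff_flags(parse_flags(CELEBAHQ_ARGS), parse_flags(CATMASK_ARGS))
    for key, (celeba, cat) in changes.items():
        print(f"{key}: {celeba} -> {cat}")
```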

Misc

Citations

If you use this project for your research, please cite our paper.

@article{chen2022sem2nerf,
    title={Sem2NeRF: Converting Single-View Semantic Masks to Neural Radiance Fields},
    author={Chen, Yuedong and Wu, Qianyi and Zheng, Chuanxia and Cham, Tat-Jen and Cai, Jianfei},
    journal={arXiv preprint arXiv:2203.10821},
    year={2022}
}

Pull Request

You are more than welcome to contribute to this project by sending a pull request.

Related Work

If you are interested in NeRF / neural implicit representations + semantic maps, we also recommend checking out these related works:

  • Object-compositional implicit neural surfaces: [ECCV 2022] ObjectSDF.

  • Digital human animation: [ECCV 2022 oral] SSPNeRF.

Acknowledgments

Our implementation was mainly inspired by pixel2style2pixel, and we also borrowed a lot of code from pi-GAN, GRAF, GIRAFFE and Swin-Transformer. Many thanks to all the above-mentioned projects.
