Commit fd3986e

remove ddp when using a single gpu
1 parent 18f1d20 commit fd3986e

8 files changed (+212 -176 lines)

dataloader/blender.py

+20 -8

@@ -57,12 +57,12 @@ def train_dataloader(self):
             import torch_xla.core.xla_model as xm
             sampler = SingleImageDDPSampler(
                 self.batch_size, xm.xrt_world_size(), xm.get_ordinal(),
-                len(self.i_train), self.image_len, self.args.i_validation
+                len(self.i_train), self.image_len, self.args.i_validation, True
             )
         else:
             sampler = SingleImageDDPSampler(
                 self.batch_size, None, None, len(self.i_train),
-                self.image_len, self.args.i_validation
+                self.image_len, self.args.i_validation, False
             )
         return DataLoader(
             self.train_dset, batch_sampler=sampler, num_workers=self.args.num_workers,
@@ -73,9 +73,13 @@ def val_dataloader(self):
 
         if self.args.tpu:
             import torch_xla.core.xla_model as xm
-            sampler = DDPSequnetialSampler(self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.val_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.val_dset), True
+            )
         else:
-            sampler = DDPSequnetialSampler(self.chunk, None, None, len(self.val_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, None, None, len(self.val_dset), False
+            )
 
         return DataLoader(
             self.val_dset, batch_size=self.chunk, sampler=sampler,
@@ -86,9 +90,13 @@ def test_dataloader(self):
 
         if self.args.tpu:
             import torch_xla.core.xla_model as xm
-            sampler = DDPSequnetialSampler(self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.test_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.test_dset), True
+            )
         else:
-            sampler = DDPSequnetialSampler(self.chunk, None, None, len(self.test_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, None, None, len(self.test_dset), False
+            )
 
         return DataLoader(
             self.test_dset, batch_size=self.chunk, sampler=sampler,
@@ -98,9 +106,13 @@ def test_dataloader(self):
     def predict_dataloader(self):
         if self.args.tpu:
             import torch_xla.core.xla_model as xm
-            sampler = DDPSequnetialSampler(self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.predict_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.predict_dset), True
+            )
         else:
-            sampler = DDPSequnetialSampler(self.chunk, None, None, len(self.predict_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, None, None, len(self.predict_dset), False
+            )
 
         return DataLoader(
             self.predict_dset, batch_size=self.args.chunk, sampler=sampler,

dataloader/llff.py

+20 -8

@@ -61,11 +61,11 @@ def train_dataloader(self):
             import torch_xla.core.xla_model as xm
             sampler = MultipleImageDDPSampler(
                 self.batch_size, xm.xrt_world_size(), xm.get_ordinal(),
-                len(self.train_dset), self.args.i_validation
+                len(self.train_dset), self.args.i_validation, True
             )
         else:
             sampler = MultipleImageDDPSampler(
-                self.batch_size, None, None, len(self.train_dset), self.args.i_validation
+                self.batch_size, None, None, len(self.train_dset), self.args.i_validation, False
             )
 
         return DataLoader(
@@ -77,9 +77,13 @@ def val_dataloader(self):
 
         if self.args.tpu:
             import torch_xla.core.xla_model as xm
-            sampler = DDPSequnetialSampler(self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.val_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.val_dset), True
+            )
         else:
-            sampler = DDPSequnetialSampler(self.chunk, None, None, len(self.val_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, None, None, len(self.val_dset), False
+            )
 
         return DataLoader(
             self.val_dset, batch_size=self.chunk, sampler=sampler,
@@ -90,9 +94,13 @@ def test_dataloader(self):
 
         if self.args.tpu:
             import torch_xla.core.xla_model as xm
-            sampler = DDPSequnetialSampler(self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.test_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.test_dset), True
+            )
         else:
-            sampler = DDPSequnetialSampler(self.chunk, None, None, len(self.test_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, None, None, len(self.test_dset), False
+            )
 
         return DataLoader(
             self.test_dset, batch_size=self.chunk, sampler=sampler,
@@ -102,9 +110,13 @@ def test_dataloader(self):
     def predict_dataloader(self):
         if self.args.tpu:
             import torch_xla.core.xla_model as xm
-            sampler = DDPSequnetialSampler(self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.predict_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, xm.xrt_world_size(), xm.get_ordinal(), len(self.predict_dset), True
+            )
         else:
-            sampler = DDPSequnetialSampler(self.chunk, None, None, len(self.predict_dset))
+            sampler = DDPSequnetialSampler(
+                self.chunk, None, None, len(self.predict_dset), False
+            )
 
         return DataLoader(
             self.predict_dset, batch_size=self.chunk, sampler=sampler,
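
Both dataloader files above make the same mechanical change: each sampler constructor gains a trailing tpu boolean (True on the TPU path, False otherwise), and the sampler decides internally whether to shard. The reason the flag is needed at all: with num_replicas=None, the old code always fell through to dist.get_world_size(), which raises a RuntimeError when no process group has been initialized, i.e. on a plain single-GPU run. A minimal sketch of the single-GPU case, using the constructor signature from this commit (the sizes are illustrative, not from the diff):

    from dataloader.sampler import DDPSequnetialSampler

    chunk, n_rays = 1024, 640 * 480  # illustrative sizes only

    # On a machine with exactly one visible GPU, tpu=False makes the sampler
    # resolve to rank 0 / num_replicas 1 instead of querying torch.distributed.
    sampler = DDPSequnetialSampler(chunk, None, None, n_rays, False)
    assert (sampler.rank, sampler.num_replicas) == (0, 1)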

dataloader/sampler.py

+42 -30

@@ -7,19 +7,23 @@
 
 
 class DDPSequnetialSampler(SequentialSampler):
-    def __init__(self, batch_size, num_replicas, rank, N_total):
+    def __init__(self, batch_size, num_replicas, rank, N_total, tpu):
         self.data_source=None
         self.batch_size = batch_size
         self.N_total = N_total
         self.drop_last=False
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
+        ngpus = torch.cuda.device_count()
+        if ngpus == 1 and not tpu:
+            rank, num_replicas = 0, 1
+        else:
+            if num_replicas is None:
+                if not dist.is_available():
+                    raise RuntimeError("Requires distributed package to be available")
+                num_replicas = dist.get_world_size()
+            if rank is None:
+                if not dist.is_available():
+                    raise RuntimeError("Requires distributed package to be available")
+                rank = dist.get_rank()
         self.rank = rank
         self.num_replicas = num_replicas
 
@@ -32,20 +36,24 @@ def __len__(self):
 
 
 class SingleImageDDPSampler:
-    def __init__(self, batch_size, num_replicas, rank, N_img, N_pixels, i_validation):
+    def __init__(self, batch_size, num_replicas, rank, N_img, N_pixels, i_validation, tpu):
         self.batch_size = batch_size
         self.N_pixels = N_pixels
         self.N_img = N_img
         self.drop_last = False
         self.i_validation = i_validation
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Require distributed package to be available")
-            rank = dist.get_rank()
+        ngpus = torch.cuda.device_count()
+        if ngpus == 1 and not tpu:
+            rank, num_replicas = 0, 1
+        else:
+            if num_replicas is None:
+                if not dist.is_available():
+                    raise RuntimeError("Requires distributed package to be available")
+                num_replicas = dist.get_world_size()
+            if rank is None:
+                if not dist.is_available():
+                    raise RuntimeError("Require distributed package to be available")
+                rank = dist.get_rank()
         self.rank = rank
         self.num_replicas = num_replicas
 
@@ -66,21 +74,25 @@ def __len__(self):
 
 
 class MultipleImageDDPSampler(DistributedSampler):
-    def __init__(self, batch_size, num_replicas, rank, total_len, i_validation):
+    def __init__(self, batch_size, num_replicas, rank, total_len, i_validation, tpu):
         self.batch_size = batch_size
         self.total_len = total_len
         self.i_validation = i_validation
-        self.drop_last = False
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError(
-                    "Require distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError(
-                    "Require distributed package to be available")
-            rank = dist.get_rank()
+        self.drop_last = False
+        ngpus = torch.cuda.device_count()
+        if ngpus == 1 and not tpu:
+            rank, num_replicas = 0, 1
+        else:
+            if num_replicas is None:
+                if not dist.is_available():
+                    raise RuntimeError(
+                        "Require distributed package to be available")
+                num_replicas = dist.get_world_size()
+            if rank is None:
+                if not dist.is_available():
+                    raise RuntimeError(
+                        "Require distributed package to be available")
+                rank = dist.get_rank()
         self.num_replicas = num_replicas
         self.rank = rank
 
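
All three samplers now duplicate the same fallback. Below is a sketch of that shared logic factored into a single helper; resolve_replica_info is hypothetical (the commit keeps the branch inline in each class), but its body mirrors the diff. Note the guard tests torch.cuda.device_count() == 1 specifically, so a CPU-only run (device count 0) still takes the distributed path:

    import torch
    import torch.distributed as dist

    def resolve_replica_info(num_replicas, rank, tpu):
        # Hypothetical helper; mirrors the branch added to each sampler above.
        # Single-GPU, non-TPU runs skip DDP entirely: one replica, rank 0.
        if torch.cuda.device_count() == 1 and not tpu:
            return 1, 0
        # Otherwise resolve from the initialized process group, as before.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        return num_replicas, rank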

run.py

+2 -2

@@ -69,8 +69,8 @@
         tpu_cores=args.tpu_num if args.tpu else None,
         replace_sampler_ddp=False,
         deterministic=True,
-        strategy=DDPPlugin(
-            find_unused_parameters=False) if not args.tpu else None,
+        strategy=DDPPlugin(find_unused_parameters=False) \
+            if n_gpus > 1 and not args.tpu else None,
         check_val_every_n_epoch=1,
         precision=32,
         num_sanity_val_steps=0,
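
For context, a sketch of what the Trainer receives under each hardware setup after this change; the DDPPlugin expression is from the diff, while the surrounding setup (n_gpus from torch.cuda.device_count(), a use_tpu stand-in for args.tpu, the minimal Trainer call) is illustrative and depends on the PyTorch Lightning version in use:

    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPPlugin

    n_gpus = torch.cuda.device_count()
    use_tpu = False  # stands in for args.tpu

    # DDP is engaged only for multi-GPU, non-TPU runs; a single GPU now
    # trains with strategy=None, i.e. no distributed wrapper at all.
    strategy = DDPPlugin(find_unused_parameters=False) \
        if n_gpus > 1 and not use_tpu else None

    trainer = Trainer(gpus=n_gpus, strategy=strategy)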

scripts/jaxnerf_torch/eval.sh

+32 -32

@@ -1,72 +1,72 @@
 
 # LLFF
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/fern \
-    --expname fern --eval --model jaxnerf_torch
+    --expname fern --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/flower \
-    --expname flower --eval --model jaxnerf_torch
+    --expname flower --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/fortress \
-    --expname fortress --eval --model jaxnerf_torch
+    --expname fortress --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/horns \
-    --expname horns --eval --model jaxnerf_torch
+    --expname horns --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/leaves \
-    --expname leaves --eval --model jaxnerf_torch
+    --expname leaves --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/orchids \
-    --expname orchids --eval --model jaxnerf_torch
+    --expname orchids --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/room \
-    --expname room --eval --model jaxnerf_torch
+    --expname room --eval
 python3 -m run --config configs/jaxnerf_torch/llff.yaml --datadir data/llff/trex \
-    --expname trex --eval --model jaxnerf_torch
+    --expname trex --eval
 
 # LLFF_LARGE
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/fern \
-    --expname fern_large --eval --model jaxnerf_torch
+    --expname fern_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/flower \
-    --expname flower_large --eval --model jaxnerf_torch
+    --expname flower_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/fortress \
-    --expname fortress_large --eval --model jaxnerf_torch
+    --expname fortress_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/horns \
-    --expname horns_large --eval --model jaxnerf_torch
+    --expname horns_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/leaves \
-    --expname leaves_large --eval --model jaxnerf_torch
+    --expname leaves_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/orchids \
-    --expname orchids_large --eval --model jaxnerf_torch
+    --expname orchids_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/room \
-    --expname room_large --eval --model jaxnerf_torch
+    --expname room_large --eval
 python3 -m run --config configs/jaxnerf_torch/llff_large.yaml --datadir data/llff/trex \
-    --expname trex_large --eval --model jaxnerf_torch
+    --expname trex_large --eval
 
 # BLENDER
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/chair \
-    --expname chair --eval --model jaxnerf_torch
+    --expname chair --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/drums \
-    --expname drums --eval --model jaxnerf_torch
+    --expname drums --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/ficus \
-    --expname ficus --eval --model jaxnerf_torch
+    --expname ficus --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/hotdog \
-    --expname hotdog --eval --model jaxnerf_torch
+    --expname hotdog --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/lego \
-    --expname lego --eval --model jaxnerf_torch
+    --expname lego --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/materials \
-    --expname materials --eval --model jaxnerf_torch
+    --expname materials --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/mic \
-    --expname mic --eval --model jaxnerf_torch
+    --expname mic --eval
 python3 -m run --config configs/jaxnerf_torch/blender.yaml --datadir data/blender/ship \
-    --expname ship --eval --model jaxnerf_torch
+    --expname ship --eval
 
 # BLENDER LARGE
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/chair \
-    --expname chair_large --eval --model jaxnerf_torch
+    --expname chair_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/drums \
-    --expname drums_large --eval --model jaxnerf_torch
+    --expname drums_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/ficus \
-    --expname ficus_large --eval --model jaxnerf_torch
+    --expname ficus_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/hotdog \
-    --expname hotdog_large --eval --model jaxnerf_torch
+    --expname hotdog_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/lego \
-    --expname lego_large --eval --model jaxnerf_torch
+    --expname lego_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/materials \
-    --expname materials_large --eval --model jaxnerf_torch
+    --expname materials_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/mic \
-    --expname mic_large --eval --model jaxnerf_torch
+    --expname mic_large --eval
 python3 -m run --config configs/jaxnerf_torch/blender_large.yaml --datadir data/blender/ship \
-    --expname ship_large --eval --model jaxnerf_torch
+    --expname ship_large --eval
