Skip to content

Commit 66c41ee

Browse files
committed
implement install.py
1 parent ca5b24b commit 66c41ee

File tree

5 files changed

+123
-4
lines changed

5 files changed

+123
-4
lines changed

embodiedscan/datasets/transforms/loading.py

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __repr__(self):
9393
return repr_str
9494

9595

96+
# TODO : refine
9697
@TRANSFORMS.register_module()
9798
class LoadAnnotations3D(LoadAnnotations):
9899
"""Load Annotations3D.

embodiedscan/explorer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
2+
import pickle
23
from typing import List, Union
34

4-
import mmengine
55
import numpy as np
66
import open3d as o3d
77

@@ -68,7 +68,8 @@ def __init__(self,
6868
self.metainfo = None
6969
data_list = []
7070
for file in self.ann_files:
71-
data = mmengine.load(file)
71+
with open(file, 'rb') as f:
72+
data = pickle.load(f)
7273
if self.metainfo is None:
7374
self.metainfo = data['metainfo']
7475
else:

embodiedscan/visualization/continuous_drawer.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
2+
import pickle
23

34
import cv2
4-
import mmengine
55
import numpy as np
66
import open3d as o3d
77

@@ -223,7 +223,8 @@ def begin(self):
223223
print('You can also press Esc to close window immediately,',
224224
'which may result in a segmentation fault.')
225225
self.gt = np.load(self.occ_path)
226-
self.mask = mmengine.load(self.mask_path)
226+
with open(self.mask_path, 'rb') as f:
227+
self.mask = pickle.load(f)
227228

228229
point_cloud_range = [-3.2, -3.2, -1.28 + 0.5, 3.2, 3.2, 1.28 + 0.5]
229230
occ_size = [40, 40, 16]

install.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import argparse
2+
import re
3+
import subprocess
4+
import sys
5+
6+
7+
def run_subprocess(command):
8+
try:
9+
process = subprocess.Popen(command,
10+
stdout=subprocess.PIPE,
11+
stderr=subprocess.PIPE,
12+
universal_newlines=True)
13+
14+
# Read output and error in real-time
15+
for line in process.stdout:
16+
print(line.strip())
17+
for line in process.stderr:
18+
print(line.strip())
19+
20+
# Wait for the subprocess to finish
21+
process.wait()
22+
23+
# Get the return code
24+
return_code = process.returncode
25+
26+
if return_code != 0:
27+
print(f'Command failed with return code {return_code}')
28+
29+
except subprocess.CalledProcessError as e:
30+
print(f'Command failed with return code {e.returncode}')
31+
print('Error output:')
32+
print(e.output.decode())
33+
34+
35+
def pytorch3d_links():
36+
try:
37+
import torch
38+
except ImportError as e:
39+
print('Pytorch is not installed.')
40+
raise e
41+
cuda_version = torch.version.cuda
42+
if cuda_version is None:
43+
print('Pytorch is cpu only.')
44+
raise NotImplementedError
45+
46+
pyt_version_str = torch.__version__.split('+')[0].replace('.', '')
47+
cuda_version_str = torch.version.cuda.replace('.', '')
48+
version_str = ''.join([
49+
f'py3{sys.version_info.minor}_cu', cuda_version_str,
50+
f'_pyt{pyt_version_str}'
51+
])
52+
pytorch3d_links = f'https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html' # noqa: E501
53+
return pytorch3d_links
54+
55+
56+
def mmcv_links():
57+
try:
58+
import torch
59+
except ImportError as e:
60+
print('Pytorch is not installed.')
61+
raise e
62+
cuda_version = torch.version.cuda
63+
if cuda_version is None:
64+
print('Pytorch is cpu only.')
65+
raise NotImplementedError
66+
67+
cuda_version_str = torch.version.cuda.replace('.', '')
68+
pyt_version = torch.__version__.split('+')[0].split('.')
69+
pyt_version_mmcv = pyt_version[0] + '.' + pyt_version[1]
70+
mmcv_links = f'https://download.openmmlab.com/mmcv/dist/cu{cuda_version_str}/torch{pyt_version_mmcv}/index.html' # noqa: E501
71+
return mmcv_links
72+
73+
74+
def install_package(line):
75+
pat = '(' + '|'.join(['>=', '==', '>', '<', '<=', '@']) + ')'
76+
parts = re.split(pat, line, maxsplit=1)
77+
package_name = parts[0].strip()
78+
print('installing', package_name)
79+
if package_name == 'pytorch3d':
80+
links = pytorch3d_links()
81+
run_subprocess(
82+
[sys.executable, '-m', 'pip', 'install', 'pytorch3d', '-f', links])
83+
elif package_name == 'mmcv':
84+
links = mmcv_links()
85+
run_subprocess(
86+
[sys.executable, '-m', 'pip', 'install', line, '-f', links])
87+
else:
88+
run_subprocess([sys.executable, '-m', 'pip', 'install', line])
89+
90+
91+
def install_requires(fname):
92+
with open(fname, 'r') as f:
93+
for line in f.readlines():
94+
line = line.strip()
95+
if line:
96+
install_package(line)
97+
98+
99+
if __name__ == '__main__':
100+
parser = argparse.ArgumentParser(
101+
description='Install Embodiedscan from pre-built package.')
102+
parser.add_argument('mode', default=None)
103+
args = parser.parse_args()
104+
105+
install_requires('requirements/base.txt')
106+
if args.mode == 'visual' or args.mode == 'all':
107+
install_requires('requirements/visual.txt')
108+
109+
if args.mode == 'run' or args.mode == 'all':
110+
install_requires('requirements/run.txt')

requirements/run.txt

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
MinkowskiEngine @ git+https://github.com/NVIDIA/MinkowskiEngine.git
2+
mmcv==2.0.0rc4
3+
mmdet
4+
mmengine
5+
ninja
6+
pytorch3d

0 commit comments

Comments
 (0)