-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdmr_data.py
116 lines (94 loc) · 3.48 KB
/
dmr_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = "Christian Heider Nielsen"
__doc__ = r"""
Created on 09/10/2019
"""
from typing import Any
import numpy
import torch
import torch.utils.data
from draugr.torch_utilities import channel_transform, to_tensor
from neodroid.environments.droid_environment import DictUnityEnvironment
from torch.nn.functional import binary_cross_entropy_with_logits
from warg import NOD
from neodroidvision.segmentation import dice_loss, jaccard_loss
def _segmentation_masks(
    seg_arr: numpy.ndarray, threshold: float = 50
) -> numpy.ndarray:
    """Binarise the first three channels of a channel-last segmentation image.

    :param seg_arr: channel-last image; only channels 0, 1 and 2 are read.
    :param threshold: a pixel is set in a channel's mask when its value is strictly
        greater than this.
    :return: float64 array of shape ``(3, *seg_arr.shape[:-1])`` holding 0/1 masks
        for channels 0, 2 and 1, in that order.
    """
    # NOTE(review): masks are stacked as channels (0, 2, 1), i.e. [red, blue, green]
    # rather than RGB order. Kept unchanged so existing consumers of this target
    # layout are unaffected -- confirm the ordering is intentional.
    return numpy.stack(
        [seg_arr[..., channel] > threshold for channel in (0, 2, 1)]
    ).astype(numpy.float64)


def _normalized_depth(depth_arr: numpy.ndarray) -> numpy.ndarray:
    """Scale the first channel of a channel-last depth image into ``[0, 1]``.

    :param depth_arr: channel-last depth image; only channel 0 is read.
    :return: float64 array of shape ``(1, *depth_arr.shape[:-1])`` equal to channel 0
        divided by 255 and clipped to ``[0, 1]``.
    """
    # Dividing by 255 presumes 8-bit encoded depth -- TODO confirm against the
    # "CompressedDepth" sensor's encoding.
    depth = depth_arr[..., 0].astype(numpy.float64)
    return numpy.clip(depth[numpy.newaxis] / 255.0, 0, 1)


def neodroid_camera_data_iterator(
    env: DictUnityEnvironment, device: torch.device, batch_size: int = 12
) -> Any:
    """Endlessly yield batches of camera observations from a Neodroid environment.

    Each sample is produced by calling ``env.update()`` once and then reading the
    "RGB", "Layer", "CompressedDepth" and "Normal" sensors.

    :param env: environment that is updated and whose sensors are read.
    :param device: device passed to ``to_tensor`` for every yielded tensor.
    :param batch_size: number of samples per yielded batch; must be at least 1.
    :raises ValueError: on the first ``next()`` if ``batch_size`` is less than 1.
    :return: infinite generator of ``(rgb, (masks, depth, normals))`` where
        ``rgb`` and ``normals`` keep the first three channels of their sensors after
        ``hwc_to_chw_tensor``, ``masks`` holds three thresholded segmentation masks
        per sample and ``depth`` is channel 0 of the depth sensor scaled into
        ``[0, 1]``.
    """
    if batch_size < 1:
        # Without this guard the sampling loop never runs and env is never updated.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    while True:
        rgb, mask_responses, depth_responses, normals_responses = [], [], [], []
        for _ in range(batch_size):
            env.update()
            # All four sensors are read after a single update (presumably one
            # simulation step, so the views correspond).
            rgb_arr = env._sensor("RGB")
            seg_arr = env._sensor("Layer")
            depth_arr = env._sensor("CompressedDepth")
            normal_arr = env._sensor("Normal")

            # [:3] keeps only the first three channels (presumably dropping alpha).
            rgb.append(channel_transform.hwc_to_chw_tensor(rgb_arr)[:3, :, :])
            mask_responses.append(_segmentation_masks(seg_arr))
            depth_responses.append(_normalized_depth(depth_arr))
            normals_responses.append(
                channel_transform.hwc_to_chw_tensor(normal_arr)[:3, :, :]
            )

        yield (
            to_tensor(rgb, device=device),
            (
                to_tensor(mask_responses, device=device),
                to_tensor(depth_responses, device=device),
                to_tensor(normals_responses, device=device),
            ),
        )
def calculate_multi_auto_encoder_loss(seg, recon, depth, normals):
    """Combine segmentation, reconstruction, depth and normal losses into one scalar.

    Every argument is a ``(prediction, target)`` pair whose prediction is treated as
    raw logits. Six terms are computed:

    - Dice and Jaccard on the sigmoid of the segmentation logits;
    - binary cross-entropy with logits for reconstruction, segmentation, depth and
      normals.

    The loss is the sum of each term's mean multiplied by ``1 / 6``, i.e. the
    mean of the per-term means.

    :param seg: ``(seg_pred, seg_target)`` segmentation logits and target.
    :param recon: ``(recon_pred, recon_target)`` reconstruction logits and target.
    :param depth: ``(depth_pred, depth_target)`` depth logits and target.
    :param normals: ``(normals_pred, normals_target)`` normal-map logits and target.
    :return: ``NOD`` with ``loss`` (the weighted sum) and ``terms`` (the six terms as
        computed, before ``.mean()``, in the order dice, jaccard, reconstruction BCE,
        segmentation BCE, depth BCE, normals BCE).
    """
    seg_pred, seg_target = seg
    recon_pred, recon_target = recon
    depth_pred, depth_target = depth
    normals_pred, normals_target = normals

    # Overlap losses operate on probabilities, not logits.
    # epsilon=1 presumably acts as additive smoothing -- confirm in
    # neodroidvision.segmentation.
    seg_probabilities = torch.sigmoid(seg_pred)

    # NOTE(review): binary_cross_entropy_with_logits expects targets in [0, 1];
    # confirm the reconstruction, depth and normals targets are normalised.
    terms = (
        dice_loss(seg_probabilities, seg_target, epsilon=1),
        jaccard_loss(seg_probabilities, seg_target, epsilon=1),
        binary_cross_entropy_with_logits(recon_pred, recon_target),
        binary_cross_entropy_with_logits(seg_pred, seg_target),
        binary_cross_entropy_with_logits(depth_pred, depth_target),
        binary_cross_entropy_with_logits(normals_pred, normals_target),
    )

    share = 1 / len(terms)
    loss = 0
    for term in terms:
        # Same left-to-right accumulation (starting from 0) as the builtin sum().
        loss = loss + term.mean() * share
    return NOD(loss=loss, terms=terms)