@@ -33,14 +33,12 @@ __global__ void trilinear_fw_kernel(
3333
3434
3535torch::Tensor trilinear_fw_cu (
36- torch::Tensor feats,
37- torch::Tensor points
36+ const torch::Tensor feats,
37+ const torch::Tensor points
3838){
3939 const int N = feats.size (0 ), F = feats.size (2 );
4040
4141 torch::Tensor feat_interp = torch::zeros ({N, F}, feats.options ());
42- torch::Tensor feat_interp2 = torch::zeros ({N, F}, feats.options ());
43-
4442
4543 const dim3 threads (16 , 16 );
4644 const dim3 blocks ((N+threads.x -1 )/threads.x , (F+threads.y -1 )/threads.y );
@@ -55,4 +53,63 @@ torch::Tensor trilinear_fw_cu(
5553 }));
5654
5755 return feat_interp;
56+ }
57+
58+
59+ template <typename scalar_t >
60+ __global__ void trilinear_bw_kernel (
61+ const torch::PackedTensorAccessor<scalar_t , 2 , torch::RestrictPtrTraits, size_t > dL_dfeat_interp,
62+ const torch::PackedTensorAccessor<scalar_t , 3 , torch::RestrictPtrTraits, size_t > feats,
63+ const torch::PackedTensorAccessor<scalar_t , 2 , torch::RestrictPtrTraits, size_t > points,
64+ torch::PackedTensorAccessor<scalar_t , 3 , torch::RestrictPtrTraits, size_t > dL_dfeats
65+ ){
66+ const int n = blockIdx .x * blockDim .x + threadIdx .x ;
67+ const int f = blockIdx .y * blockDim .y + threadIdx .y ;
68+
69+ if (n>=feats.size (0 ) || f>=feats.size (2 )) return ;
70+
71+ // point -1~1
72+ const scalar_t u = (points[n][0 ]+1 )/2 ;
73+ const scalar_t v = (points[n][1 ]+1 )/2 ;
74+ const scalar_t w = (points[n][2 ]+1 )/2 ;
75+
76+ const scalar_t a = (1 -v)*(1 -w);
77+ const scalar_t b = (1 -v)*w;
78+ const scalar_t c = v*(1 -w);
79+ const scalar_t d = 1 -a-b-c;
80+
81+ dL_dfeats[n][0 ][f] = (1 -u)*a*dL_dfeat_interp[n][f];
82+ dL_dfeats[n][1 ][f] = (1 -u)*b*dL_dfeat_interp[n][f];
83+ dL_dfeats[n][2 ][f] = (1 -u)*c*dL_dfeat_interp[n][f];
84+ dL_dfeats[n][3 ][f] = (1 -u)*d*dL_dfeat_interp[n][f];
85+ dL_dfeats[n][4 ][f] = u*a*dL_dfeat_interp[n][f];
86+ dL_dfeats[n][5 ][f] = u*b*dL_dfeat_interp[n][f];
87+ dL_dfeats[n][6 ][f] = u*c*dL_dfeat_interp[n][f];
88+ dL_dfeats[n][7 ][f] = u*d*dL_dfeat_interp[n][f];
89+ }
90+
91+
92+ torch::Tensor trilinear_bw_cu (
93+ const torch::Tensor dL_dfeat_interp,
94+ const torch::Tensor feats,
95+ const torch::Tensor points
96+ ){
97+ const int N = feats.size (0 ), F = feats.size (2 );
98+
99+ torch::Tensor dL_dfeats = torch::zeros ({N, 8 , F}, feats.options ());
100+
101+ const dim3 threads (16 , 16 );
102+ const dim3 blocks ((N+threads.x -1 )/threads.x , (F+threads.y -1 )/threads.y );
103+
104+ AT_DISPATCH_FLOATING_TYPES (feats.type (), " trilinear_bw_cu" ,
105+ ([&] {
106+ trilinear_bw_kernel<scalar_t ><<<blocks, threads>>> (
107+ dL_dfeat_interp.packed_accessor <scalar_t , 2 , torch::RestrictPtrTraits, size_t >(),
108+ feats.packed_accessor <scalar_t , 3 , torch::RestrictPtrTraits, size_t >(),
109+ points.packed_accessor <scalar_t , 2 , torch::RestrictPtrTraits, size_t >(),
110+ dL_dfeats.packed_accessor <scalar_t , 3 , torch::RestrictPtrTraits, size_t >()
111+ );
112+ }));
113+
114+ return dL_dfeats;
58115}
0 commit comments