forked from pytorch/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
34 lines (24 loc) · 783 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
# Preamble: load the compiled custom-operator library, call the operator once
# on random data, and print the result between BEGIN/END marker lines.
print("BEGIN preamble")
# The path is relative to the current working directory. The library has to be
# loaded before torch.ops.my_ops.warp_perspective is referenced on the next
# statement.
torch.ops.load_library("build/libwarp_perspective.so")
# Arguments are a random 32x32 tensor (normal) and a random 3x3 tensor
# (uniform). NOTE(review): the 3x3 argument is presumably a perspective
# transform matrix -- confirm against the operator's own definition.
print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))
print("END preamble")
# BEGIN compute
def compute(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Return the matrix product of ``x`` and ``y`` plus ``relu(z)``.

    Args:
        x: Left operand of the matrix product.
        y: Right operand of the matrix product.
        z: Tensor passed through ReLU and then added to the product; it
            must be addable to ``x.matmul(y)``.

    Returns:
        ``x.matmul(y) + torch.relu(z)``.
    """
    # Kept as a single expression with no intermediate locals so the sequence
    # of recorded operations (matmul, relu, add) stays exactly as written.
    return x.matmul(y) + torch.relu(z)
# END compute
# Trace `compute` with example tensors and print the recorded graph between
# BEGIN/END marker lines.
print("BEGIN trace")
# Example inputs in argument order (x, y, z): a (4, 8) by (8, 5) product gives
# (4, 5), which matches the shape used for z.
inputs = [torch.randn(*shape) for shape in ((4, 8), (8, 5), (4, 5))]
trace = torch.jit.trace(compute, inputs)
print(trace.graph)
print("END trace")
# BEGIN compute2
def compute(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Warp ``x`` with the custom op, then return ``x.matmul(y) + relu(z)``.

    This definition rebinds the module-level name ``compute``; any earlier
    function of the same name is replaced from this point on.

    Requires ``torch.ops.my_ops.warp_perspective`` to be available, i.e. the
    shared library that provides it must already have been loaded.

    Args:
        x: Tensor passed to ``warp_perspective`` together with a 3x3
            identity matrix; the op's result is the left matmul operand.
        y: Right operand of the matrix product.
        z: Tensor passed through ReLU and then added to the product.

    Returns:
        ``warped.matmul(y) + torch.relu(z)`` where ``warped`` is the custom
        op's output.
    """
    # NOTE(review): the op's output shape is decided by the external library.
    # The call site in this file pairs a 4-row x with an 8-row z, which
    # suggests the op returns an 8-row tensor -- confirm against the op.
    # The result is deliberately rebound to `x` rather than a new local name.
    x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
    return x.matmul(y) + torch.relu(z)
# END compute2
# Trace the current `compute` with example tensors and print the recorded
# graph between BEGIN/END marker lines.
print("BEGIN trace2")
# Example inputs in argument order (x, y, z). Note z is (8, 5) here while x is
# (4, 8); NOTE(review): this presumably matches the row count of the custom
# op's output -- confirm against the op.
inputs = [torch.randn(*shape) for shape in ((4, 8), (8, 5), (8, 5))]
trace = torch.jit.trace(compute, inputs)
print(trace.graph)
print("END trace2")