forked from amueller/introduction_to_ml_with_python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path plot_animal_tree.py
27 lines (25 loc) · 890 Bytes
/
plot_animal_tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from imageio import imread
import matplotlib.pyplot as plt
def plot_animal_tree(ax=None):
    """Draw a small hand-built decision tree that separates four animals.

    The tree is built with graphviz, rendered to a PNG inside a temporary
    directory, read back with ``imread`` and displayed on a matplotlib axes
    with the axis decorations turned off.

    Parameters
    ----------
    ax : matplotlib axes, optional
        Axes to draw into. Defaults to the current axes (``plt.gca()``).
    """
    import tempfile

    import graphviz

    if ax is None:
        ax = plt.gca()

    mygraph = graphviz.Digraph(node_attr={'shape': 'box'},
                               edge_attr={'labeldistance': "10.5"},
                               format="png")

    # Node id -> displayed text. Ids "0"-"2" are yes/no questions,
    # ids "3"-"6" are the leaf animals.
    nodes = [
        ("0", "Has feathers?"),
        ("1", "Can fly?"),
        ("2", "Has fins?"),
        ("3", "Hawk"),
        ("4", "Penguin"),
        ("5", "Dolphin"),
        ("6", "Bear"),
    ]
    # (parent, child, branch label): every question node has a "True"
    # child and a "False" child.
    edges = [
        ("0", "1", "True"),
        ("0", "2", "False"),
        ("1", "3", "True"),
        ("1", "4", "False"),
        ("2", "5", "True"),
        ("2", "6", "False"),
    ]
    for node_id, text in nodes:
        mygraph.node(node_id, text)
    for parent, child, branch in edges:
        mygraph.edge(parent, child, label=branch)

    # Render into a throwaway directory rather than the current working
    # directory, so no "tmp" / "tmp.png" files are left behind or clobbered.
    # The image is read fully into memory before the directory is deleted.
    with tempfile.TemporaryDirectory() as tmpdir:
        png_path = mygraph.render("tmp", directory=tmpdir)
        image = imread(png_path)

    ax.imshow(image)
    ax.set_axis_off()