Skip to content

Commit eb9a68e

Browse files
Add ability to watch secondary resources (#117)
1 parent 1cf1078 commit eb9a68e

File tree

3 files changed

+248
-1
lines changed

3 files changed

+248
-1
lines changed

pkg/controller/watch/watch.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package watch
2+
3+
import (
4+
"github.com/mongodb/mongodb-kubernetes-operator/pkg/util/contains"
5+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
6+
"k8s.io/apimachinery/pkg/types"
7+
"k8s.io/client-go/util/workqueue"
8+
"sigs.k8s.io/controller-runtime/pkg/event"
9+
"sigs.k8s.io/controller-runtime/pkg/reconcile"
10+
)
11+
12+
// ResourceWatcher implements handler.EventHandler and is used to trigger reconciliation when
13+
// a watched object changes. It's designed to only be used for a single type of object.
14+
// If multiple types should be watched, one ResourceWatcher for each type should be used.
15+
type ResourceWatcher struct {
16+
watched map[types.NamespacedName][]types.NamespacedName
17+
}
18+
19+
// New will create a new ResourceWatcher with no watched objects.
20+
func New() ResourceWatcher {
21+
return ResourceWatcher{
22+
watched: make(map[types.NamespacedName][]types.NamespacedName),
23+
}
24+
}
25+
26+
// Watch will add a new object to watch.
27+
func (w ResourceWatcher) Watch(watchedName, dependentName types.NamespacedName) {
28+
existing, hasExisting := w.watched[watchedName]
29+
if !hasExisting {
30+
existing = []types.NamespacedName{}
31+
}
32+
33+
// Check if resource is already being watched.
34+
if contains.NamespacedName(existing, dependentName) {
35+
return
36+
}
37+
38+
w.watched[watchedName] = append(existing, dependentName)
39+
}
40+
41+
func (w ResourceWatcher) Create(event event.CreateEvent, queue workqueue.RateLimitingInterface) {
42+
w.handleEvent(event.Meta, queue)
43+
}
44+
45+
func (w ResourceWatcher) Update(event event.UpdateEvent, queue workqueue.RateLimitingInterface) {
46+
w.handleEvent(event.MetaOld, queue)
47+
}
48+
49+
func (w ResourceWatcher) Delete(event event.DeleteEvent, queue workqueue.RateLimitingInterface) {
50+
w.handleEvent(event.Meta, queue)
51+
}
52+
53+
func (w ResourceWatcher) Generic(event event.GenericEvent, queue workqueue.RateLimitingInterface) {
54+
w.handleEvent(event.Meta, queue)
55+
}
56+
57+
// handleEvent is called when an event is received for an object.
58+
// It will check if the object is being watched and trigger a reconciliation for
59+
// the dependent object.
60+
func (w ResourceWatcher) handleEvent(meta metav1.Object, queue workqueue.RateLimitingInterface) {
61+
changedObjectName := types.NamespacedName{
62+
Name: meta.GetName(),
63+
Namespace: meta.GetNamespace(),
64+
}
65+
66+
// Enqueue reconciliation for each dependent object.
67+
for _, reconciledObjectName := range w.watched[changedObjectName] {
68+
queue.Add(reconcile.Request{
69+
NamespacedName: reconciledObjectName,
70+
})
71+
}
72+
}

pkg/controller/watch/watch_test.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package watch
2+
3+
import (
4+
"testing"
5+
6+
"k8s.io/apimachinery/pkg/types"
7+
8+
mdbv1 "github.com/mongodb/mongodb-kubernetes-operator/pkg/apis/mongodb/v1"
9+
10+
"github.com/stretchr/testify/assert"
11+
12+
"sigs.k8s.io/controller-runtime/pkg/controller/controllertest"
13+
14+
corev1 "k8s.io/api/core/v1"
15+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
16+
"k8s.io/client-go/util/workqueue"
17+
18+
"sigs.k8s.io/controller-runtime/pkg/event"
19+
)
20+
21+
func TestWatcher(t *testing.T) {
22+
obj := &corev1.Pod{
23+
ObjectMeta: metav1.ObjectMeta{
24+
Name: "pod",
25+
Namespace: "namespace",
26+
},
27+
}
28+
objNsName := types.NamespacedName{Name: obj.Name, Namespace: obj.Namespace}
29+
30+
mdb1 := mdbv1.MongoDB{
31+
ObjectMeta: metav1.ObjectMeta{
32+
Name: "mdb1",
33+
Namespace: "namespace",
34+
},
35+
}
36+
37+
mdb2 := mdbv1.MongoDB{
38+
ObjectMeta: metav1.ObjectMeta{
39+
Name: "mdb2",
40+
Namespace: "namespace",
41+
},
42+
}
43+
44+
t.Run("Non-watched object", func(t *testing.T) {
45+
watcher := New()
46+
queue := controllertest.Queue{Interface: workqueue.New()}
47+
48+
watcher.Create(event.CreateEvent{
49+
Meta: obj.GetObjectMeta(),
50+
Object: obj,
51+
}, queue)
52+
53+
// Ensure no reconciliation is queued if object is not watched.
54+
assert.Equal(t, 0, queue.Len())
55+
})
56+
57+
t.Run("Multiple objects to reconile", func(t *testing.T) {
58+
watcher := New()
59+
queue := controllertest.Queue{Interface: workqueue.New()}
60+
watcher.Watch(objNsName, mdb1.NamespacedName())
61+
watcher.Watch(objNsName, mdb2.NamespacedName())
62+
63+
watcher.Create(event.CreateEvent{
64+
Meta: obj.GetObjectMeta(),
65+
Object: obj,
66+
}, queue)
67+
68+
// Ensure multiple reconciliations are enqueued.
69+
assert.Equal(t, 2, queue.Len())
70+
})
71+
72+
t.Run("Create event", func(t *testing.T) {
73+
watcher := New()
74+
queue := controllertest.Queue{Interface: workqueue.New()}
75+
watcher.Watch(objNsName, mdb1.NamespacedName())
76+
77+
watcher.Create(event.CreateEvent{
78+
Meta: obj.GetObjectMeta(),
79+
Object: obj,
80+
}, queue)
81+
82+
assert.Equal(t, 1, queue.Len())
83+
})
84+
85+
t.Run("Update event", func(t *testing.T) {
86+
watcher := New()
87+
queue := controllertest.Queue{Interface: workqueue.New()}
88+
watcher.Watch(objNsName, mdb1.NamespacedName())
89+
90+
watcher.Update(event.UpdateEvent{
91+
MetaOld: obj.GetObjectMeta(),
92+
ObjectOld: obj,
93+
MetaNew: obj.GetObjectMeta(),
94+
ObjectNew: obj,
95+
}, queue)
96+
97+
assert.Equal(t, 1, queue.Len())
98+
})
99+
100+
t.Run("Delete event", func(t *testing.T) {
101+
watcher := New()
102+
queue := controllertest.Queue{Interface: workqueue.New()}
103+
watcher.Watch(objNsName, mdb1.NamespacedName())
104+
105+
watcher.Delete(event.DeleteEvent{
106+
Meta: obj.GetObjectMeta(),
107+
Object: obj,
108+
}, queue)
109+
110+
assert.Equal(t, 1, queue.Len())
111+
})
112+
113+
t.Run("Generic event", func(t *testing.T) {
114+
watcher := New()
115+
queue := controllertest.Queue{Interface: workqueue.New()}
116+
watcher.Watch(objNsName, mdb1.NamespacedName())
117+
118+
watcher.Generic(event.GenericEvent{
119+
Meta: obj.GetObjectMeta(),
120+
Object: obj,
121+
}, queue)
122+
123+
assert.Equal(t, 1, queue.Len())
124+
})
125+
}
126+
127+
func TestWatcherAdd(t *testing.T) {
128+
watcher := New()
129+
assert.Empty(t, watcher.watched)
130+
131+
watchedName := types.NamespacedName{Name: "object", Namespace: "namespace"}
132+
133+
mdb1 := mdbv1.MongoDB{
134+
ObjectMeta: metav1.ObjectMeta{
135+
Name: "mdb1",
136+
Namespace: "namespace",
137+
},
138+
}
139+
mdb2 := mdbv1.MongoDB{
140+
ObjectMeta: metav1.ObjectMeta{
141+
Name: "mdb2",
142+
Namespace: "namespace",
143+
},
144+
}
145+
146+
// Ensure single object can be added to empty watchlist.
147+
watcher.Watch(watchedName, mdb1.NamespacedName())
148+
assert.Len(t, watcher.watched, 1)
149+
assert.Equal(t, []types.NamespacedName{mdb1.NamespacedName()}, watcher.watched[watchedName])
150+
151+
// Ensure object can only be watched once.
152+
watcher.Watch(watchedName, mdb1.NamespacedName())
153+
assert.Len(t, watcher.watched, 1)
154+
assert.Equal(t, []types.NamespacedName{mdb1.NamespacedName()}, watcher.watched[watchedName])
155+
156+
// Ensure a single object can be watched for multiple reconciliations.
157+
watcher.Watch(watchedName, mdb2.NamespacedName())
158+
assert.Len(t, watcher.watched, 1)
159+
assert.Equal(t, []types.NamespacedName{
160+
mdb1.NamespacedName(),
161+
mdb2.NamespacedName(),
162+
}, watcher.watched[watchedName])
163+
}

pkg/util/contains/contains.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package contains
22

3-
import mdbv1 "github.com/mongodb/mongodb-kubernetes-operator/pkg/apis/mongodb/v1"
3+
import (
4+
mdbv1 "github.com/mongodb/mongodb-kubernetes-operator/pkg/apis/mongodb/v1"
5+
"k8s.io/apimachinery/pkg/types"
6+
)
47

58
func String(slice []string, s string) bool {
69
for _, elem := range slice {
@@ -19,3 +22,12 @@ func AuthMode(slice []mdbv1.AuthMode, s mdbv1.AuthMode) bool {
1922
}
2023
return false
2124
}
25+
26+
func NamespacedName(nsNames []types.NamespacedName, nsName types.NamespacedName) bool {
27+
for _, elem := range nsNames {
28+
if elem == nsName {
29+
return true
30+
}
31+
}
32+
return false
33+
}

0 commit comments

Comments
 (0)