Skip to content
This repository was archived by the owner on Jan 29, 2025. It is now read-only.

Commit bc5396c

Browse files
togashidmkillianmuldoon
authored andcommitted
Create a middleware handler
1 parent 9f0c9ea commit bc5396c

File tree

1 file changed

+47
-17
lines changed

1 file changed

+47
-17
lines changed

pkg/scheduler/scheduler.go

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -163,19 +163,6 @@ func (m MetricsExtender) getPolicyFromPod(pod *v1.Pod) (telemetrypolicy.TASPolic
163163
//prescheduleChecks performs checks to ensure a pod is suitable for the extender.
164164
//this method will return pods as supplied if they have no declared policy
165165
func (m MetricsExtender) prescheduleChecks(w http.ResponseWriter, r *http.Request) (ExtenderArgs, http.ResponseWriter, error) {
166-
if r.Method != http.MethodPost {
167-
w.WriteHeader(http.StatusMethodNotAllowed)
168-
return ExtenderArgs{}, w, errors.New("method Type not POST")
169-
}
170-
if r.ContentLength > 1*1000*1000*1000 {
171-
w.WriteHeader(http.StatusInternalServerError)
172-
return ExtenderArgs{}, w, errors.New("request size too large")
173-
}
174-
requestContentType := r.Header.Get("Content-Type")
175-
if requestContentType != "application/json" {
176-
w.WriteHeader(http.StatusNotFound)
177-
return ExtenderArgs{}, w, errors.New("request content type not application/json")
178-
}
179166
extenderArgs, err := m.decodeExtenderRequest(r)
180167
if err != nil {
181168
log.Printf("cannot decode request %v", err)
@@ -256,9 +243,52 @@ func (m MetricsExtender) Filter(w http.ResponseWriter, r *http.Request) {
256243
}
257244
m.writeFilterResponse(w, filteredNodes)
258245
}
246+
//postOnly check if the method type is POST
247+
func postOnly(next http.HandlerFunc) http.HandlerFunc {
248+
return func(w http.ResponseWriter, r *http.Request) {
249+
if r.Method != http.MethodPost {
250+
w.WriteHeader(http.StatusMethodNotAllowed)
251+
log.Print("method Type not POST")
252+
return
253+
}
254+
next.ServeHTTP(w, r)
255+
}
256+
}
257+
258+
//contentLength check the if the request size is adequate
259+
func contentLength(next http.HandlerFunc) http.HandlerFunc {
260+
return func(w http.ResponseWriter, r *http.Request) {
261+
if r.ContentLength > 1*1000*1000*1000 {
262+
w.WriteHeader(http.StatusInternalServerError)
263+
log.Print("request size too large")
264+
return
265+
}
266+
next.ServeHTTP(w, r)
267+
}
268+
}
269+
270+
//requestContentType verify the content type of the request
271+
func requestContentType(next http.HandlerFunc) http.HandlerFunc {
272+
return func(w http.ResponseWriter, r *http.Request) {
273+
requestContentType := r.Header.Get("Content-Type")
274+
if requestContentType != "application/json" {
275+
w.WriteHeader(http.StatusNotFound)
276+
log.Print("request size too large")
277+
return
278+
}
279+
next.ServeHTTP(w, r)
280+
}
281+
}
282+
283+
//handlerWithMiddleware is handler wrapped with middleware to serve the prechecks at endpoint
284+
func handlerWithMiddleware(handle http.HandlerFunc) http.HandlerFunc {
285+
return requestContentType(
286+
contentLength(
287+
postOnly(handle)))
288+
}
259289

260290
//error handler deals with requests sent to an invalid endpoint and returns a 404.
261-
func (m MetricsExtender) errorHandler(w http.ResponseWriter, r *http.Request) {
291+
func errorHandler(w http.ResponseWriter, r *http.Request) {
262292
log.Print("unknown path")
263293
w.Header().Add("Content-Type", "application/json")
264294
w.WriteHeader(http.StatusNotFound)
@@ -279,9 +309,9 @@ func checkSymLinks(filename string) error {
279309
// StartServer starts the HTTP server needed for scheduler.
280310
// It registers the handlers and checks for existing telemetry policies.
281311
func (m MetricsExtender) StartServer(port string, certFile string, keyFile string, caFile string, unsafe bool) {
282-
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { m.errorHandler(w, r) })
283-
http.HandleFunc("/scheduler/prioritize", func(w http.ResponseWriter, r *http.Request) { m.Prioritize(w, r) })
284-
http.HandleFunc("/scheduler/filter", func(w http.ResponseWriter, r *http.Request) { m.Filter(w, r) })
312+
http.HandleFunc("/", handlerWithMiddleware(errorHandler))
313+
http.HandleFunc("/scheduler/prioritize", handlerWithMiddleware(m.Prioritize))
314+
http.HandleFunc("/scheduler/filter", handlerWithMiddleware(m.Filter))
285315
var err error
286316
if unsafe {
287317
log.Printf("Extender Listening on HTTP %v", port)

0 commit comments

Comments
 (0)