diff --git a/otp/otp.go b/otp/otp.go index 38f6cdb..ea4e3ad 100644 --- a/otp/otp.go +++ b/otp/otp.go @@ -20,6 +20,7 @@ import ( const ( otpSize = 16 + maxOTPs = 5 // PageLogin is for identifying the login page for parsing & validation PageLogin = "otplogin" @@ -196,14 +197,20 @@ func (o *OTP) AddPost(w http.ResponseWriter, r *http.Request) error { return err } + otpUser := MustBeOTPable(user) + currentOTPs := splitOTPs(otpUser.GetOTPs()) + + if len(currentOTPs) >= maxOTPs { + data := authboss.HTMLData{authboss.DataValidation: fmt.Sprintf("you cannot have more than %d one time passwords", maxOTPs)} + return o.Core.Responder.Respond(w, r, http.StatusOK, PageAdd, data) + } + logger.Infof("generating otp for %s", user.GetPID()) otp, hash, err := generateOTP() if err != nil { return err } - otpUser := MustBeOTPable(user) - currentOTPs := splitOTPs(otpUser.GetOTPs()) currentOTPs = append(currentOTPs, hash) otpUser.PutOTPs(joinOTPs(currentOTPs)) diff --git a/otp/otp_test.go b/otp/otp_test.go index 69f92ee..83e12b5 100644 --- a/otp/otp_test.go +++ b/otp/otp_test.go @@ -459,6 +459,46 @@ func TestAddPost(t *testing.T) { } } +func TestAddPostTooMany(t *testing.T) { + t.Parallel() + + h := testSetup() + uname := "test@test.com" + h.storer.Users[uname] = &mocks.User{ + Email: uname, + OTPs: "2aID,2aID,2aID,2aID,2aID", + } + h.session.ClientValues[authboss.SessionKey] = uname + + r := mocks.Request("POST") + w := h.ab.NewResponse(httptest.NewRecorder()) + + var err error + r, err = h.ab.LoadClientState(w, r) + if err != nil { + t.Fatal(err) + } + + if err := h.otp.AddPost(w, r); err != nil { + t.Fatal(err) + } + + if h.responder.Page != PageAdd { + t.Error("wanted add page, got:", h.responder.Page) + } + if h.responder.Status != http.StatusOK { + t.Error("wanted ok status, got:", h.responder.Status) + } + if len(h.responder.Data[authboss.DataValidation].(string)) == 0 { + t.Error("there should have been a validation error") + } + + otps := splitOTPs(h.storer.Users[uname].OTPs) + if len(otps) != maxOTPs { + t.Error("expected the number of OTPs to be equal to the maximum") + } +} + func TestAddGetUserNotFound(t *testing.T) { t.Parallel()