terrible no good massive refactor commit (oh yeah and built generic sessions for admin panel)
This commit is contained in:
parent
cee99a6932
commit
45f33b8b46
34 changed files with 740 additions and 654 deletions
|
@ -2,8 +2,6 @@ package controller
|
|||
|
||||
import (
|
||||
"arimelody-web/model"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
@ -21,7 +19,21 @@ func GetAllAccounts(db *sqlx.DB) ([]model.Account, error) {
|
|||
return accounts, nil
|
||||
}
|
||||
|
||||
func GetAccount(db *sqlx.DB, username string) (*model.Account, error) {
|
||||
func GetAccountByID(db *sqlx.DB, id string) (*model.Account, error) {
|
||||
var account = model.Account{}
|
||||
|
||||
err := db.Get(&account, "SELECT * FROM account WHERE id=$1", id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func GetAccountByUsername(db *sqlx.DB, username string) (*model.Account, error) {
|
||||
var account = model.Account{}
|
||||
|
||||
err := db.Get(&account, "SELECT * FROM account WHERE username=$1", username)
|
||||
|
@ -49,12 +61,12 @@ func GetAccountByEmail(db *sqlx.DB, email string) (*model.Account, error) {
|
|||
return &account, nil
|
||||
}
|
||||
|
||||
func GetAccountByToken(db *sqlx.DB, token string) (*model.Account, error) {
|
||||
if token == "" { return nil, nil }
|
||||
func GetAccountBySession(db *sqlx.DB, sessionToken string) (*model.Account, error) {
|
||||
if sessionToken == "" { return nil, nil }
|
||||
|
||||
account := model.Account{}
|
||||
|
||||
err := db.Get(&account, "SELECT account.* FROM account JOIN token ON id=account WHERE token=$1", token)
|
||||
err := db.Get(&account, "SELECT account.* FROM account JOIN token ON id=account WHERE token=$1", sessionToken)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
|
@ -65,7 +77,7 @@ func GetAccountByToken(db *sqlx.DB, token string) (*model.Account, error) {
|
|||
return &account, nil
|
||||
}
|
||||
|
||||
func GetTokenFromRequest(db *sqlx.DB, r *http.Request) string {
|
||||
func GetSessionFromRequest(db *sqlx.DB, r *http.Request) string {
|
||||
tokenStr := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
if len(tokenStr) > 0 {
|
||||
return tokenStr
|
||||
|
@ -78,29 +90,6 @@ func GetTokenFromRequest(db *sqlx.DB, r *http.Request) string {
|
|||
return cookie.Value
|
||||
}
|
||||
|
||||
func GetAccountByRequest(db *sqlx.DB, r *http.Request) (*model.Account, error) {
|
||||
tokenStr := GetTokenFromRequest(db, r)
|
||||
|
||||
token, err := GetToken(db, tokenStr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("GetToken: " + err.Error())
|
||||
}
|
||||
|
||||
// does user-agent match the token?
|
||||
if r.UserAgent() != token.UserAgent {
|
||||
// invalidate the token
|
||||
DeleteToken(db, tokenStr)
|
||||
fmt.Printf("WARN: Attempted use of token by unauthorised User-Agent (Expected `%s`, got `%s`)\n", token.UserAgent, r.UserAgent())
|
||||
// TODO: log unauthorised activity to the user
|
||||
return nil, errors.New("User agent mismatch")
|
||||
}
|
||||
|
||||
return GetAccountByToken(db, tokenStr)
|
||||
}
|
||||
|
||||
func CreateAccount(db *sqlx.DB, account *model.Account) error {
|
||||
err := db.Get(
|
||||
&account.ID,
|
||||
|
|
128
controller/session.go
Normal file
128
controller/session.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"arimelody-web/model"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const TOKEN_LEN = 64
|
||||
|
||||
func CreateSession(db *sqlx.DB, userAgent string) (*model.Session, error) {
|
||||
tokenString := GenerateAlnumString(TOKEN_LEN)
|
||||
|
||||
session := model.Session{
|
||||
Token: string(tokenString),
|
||||
UserAgent: userAgent,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
|
||||
_, err := db.Exec("INSERT INTO session " +
|
||||
"(token, user_agent, created_at, expires_at) VALUES " +
|
||||
"($1, $2, $3, $4)",
|
||||
session.Token,
|
||||
session.UserAgent,
|
||||
session.CreatedAt,
|
||||
session.ExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// func WriteSession(db *sqlx.DB, session *model.Session) error {
|
||||
// _, err := db.Exec(
|
||||
// "UPDATE session " +
|
||||
// "SET account=$2,message=$3,error=$4 " +
|
||||
// "WHERE token=$1",
|
||||
// session.Token,
|
||||
// session.Account.ID,
|
||||
// session.Message,
|
||||
// session.Error,
|
||||
// )
|
||||
// return err
|
||||
// }
|
||||
|
||||
func SetSessionAccount(db *sqlx.DB, session *model.Session, account *model.Account) error {
|
||||
var err error
|
||||
session.Account = account
|
||||
if account == nil {
|
||||
_, err = db.Exec("UPDATE session SET account=NULL WHERE token=$1", session.Token)
|
||||
} else {
|
||||
_, err = db.Exec("UPDATE session SET account=$2 WHERE token=$1", session.Token, account.ID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func SetSessionMessage(db *sqlx.DB, session *model.Session, message string) error {
|
||||
var err error
|
||||
if message == "" {
|
||||
session.Message = sql.NullString{ }
|
||||
_, err = db.Exec("UPDATE session SET message=NULL WHERE token=$1", session.Token)
|
||||
} else {
|
||||
session.Message = sql.NullString{ String: message, Valid: true }
|
||||
_, err = db.Exec("UPDATE session SET message=$2 WHERE token=$1", session.Token, message)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func SetSessionError(db *sqlx.DB, session *model.Session, message string) error {
|
||||
var err error
|
||||
if message == "" {
|
||||
session.Message = sql.NullString{ }
|
||||
_, err = db.Exec("UPDATE session SET error=NULL WHERE token=$1", session.Token)
|
||||
} else {
|
||||
session.Message = sql.NullString{ String: message, Valid: true }
|
||||
_, err = db.Exec("UPDATE session SET error=$2 WHERE token=$1", session.Token, message)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func GetSession(db *sqlx.DB, token string) (*model.Session, error) {
|
||||
type dbSession struct {
|
||||
model.Session
|
||||
AccountID sql.NullString `db:"account"`
|
||||
}
|
||||
|
||||
session := dbSession{}
|
||||
err := db.Get(
|
||||
&session,
|
||||
"SELECT * FROM session WHERE token=$1",
|
||||
token,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if session.AccountID.Valid {
|
||||
session.Account, err = GetAccountByID(db, session.AccountID.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &session.Session, err
|
||||
}
|
||||
|
||||
// func GetAllSessionsForAccount(db *sqlx.DB, accountID string) ([]model.Session, error) {
|
||||
// sessions := []model.Session{}
|
||||
// err := db.Select(&sessions, "SELECT * FROM session WHERE account=$1 AND expires_at>current_timestamp", accountID)
|
||||
// return sessions, err
|
||||
// }
|
||||
|
||||
func DeleteAllSessionsForAccount(db *sqlx.DB, accountID string) error {
|
||||
_, err := db.Exec("DELETE FROM session WHERE account=$1", accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteSession(db *sqlx.DB, token string) error {
|
||||
_, err := db.Exec("DELETE FROM session WHERE token=$1", token)
|
||||
return err
|
||||
}
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"arimelody-web/model"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const TOKEN_LEN = 32
|
||||
|
||||
func CreateToken(db *sqlx.DB, accountID string, userAgent string) (*model.Token, error) {
|
||||
tokenString := GenerateAlnumString(TOKEN_LEN)
|
||||
|
||||
token := model.Token{
|
||||
Token: string(tokenString),
|
||||
AccountID: accountID,
|
||||
UserAgent: userAgent,
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(time.Hour * 24),
|
||||
}
|
||||
|
||||
_, err := db.Exec("INSERT INTO token " +
|
||||
"(token, account, user_agent, created_at, expires_at) VALUES " +
|
||||
"($1, $2, $3, $4, $5)",
|
||||
token.Token,
|
||||
token.AccountID,
|
||||
token.UserAgent,
|
||||
token.CreatedAt,
|
||||
token.ExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func GetToken(db *sqlx.DB, token_str string) (*model.Token, error) {
|
||||
token := model.Token{}
|
||||
err := db.Get(&token, "SELECT * FROM token WHERE token=$1", token_str)
|
||||
return &token, err
|
||||
}
|
||||
|
||||
func GetAllTokensForAccount(db *sqlx.DB, accountID string) ([]model.Token, error) {
|
||||
tokens := []model.Token{}
|
||||
err := db.Select(&tokens, "SELECT * FROM token WHERE account=$1 AND expires_at>current_timestamp", accountID)
|
||||
return tokens, err
|
||||
}
|
||||
|
||||
func DeleteAllTokensForAccount(db *sqlx.DB, accountID string) error {
|
||||
_, err := db.Exec("DELETE FROM token WHERE account=$1", accountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteToken(db *sqlx.DB, token string) error {
|
||||
_, err := db.Exec("DELETE FROM token WHERE token=$1", token)
|
||||
return err
|
||||
}
|
||||
|
|
@ -17,7 +17,7 @@ import (
|
|||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const TOTP_SECRET_LENGTH = 32
|
||||
const TOTP_SECRET_LENGTH = 64
|
||||
const TIME_STEP int64 = 30
|
||||
const CODE_LENGTH = 6
|
||||
|
||||
|
@ -89,6 +89,24 @@ func GetTOTPsForAccount(db *sqlx.DB, accountID string) ([]model.TOTP, error) {
|
|||
return totps, nil
|
||||
}
|
||||
|
||||
func CheckTOTPForAccount(db *sqlx.DB, accountID string, totp string) (*model.TOTP, error) {
|
||||
totps, err := GetTOTPsForAccount(db, accountID)
|
||||
if err != nil {
|
||||
// user has no TOTP methods
|
||||
return nil, err
|
||||
}
|
||||
for _, method := range totps {
|
||||
check := GenerateTOTP(method.Secret, 0)
|
||||
if check == totp {
|
||||
// return the whole TOTP method as it may be useful for logging
|
||||
return &method, nil
|
||||
}
|
||||
}
|
||||
// user failed all TOTP checks
|
||||
// note: this state will still occur even if the account has no TOTP methods.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func GetTOTP(db *sqlx.DB, accountID string, name string) (*model.TOTP, error) {
|
||||
totp := model.TOTP{}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue