schema migration and account fixes
very close to rolling this out! just need to address some security concerns first
This commit is contained in:
parent
5566a795da
commit
570cdf6ce2
20 changed files with 641 additions and 392 deletions
|
@ -5,7 +5,6 @@ import (
|
|||
"arimelody-web/model"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
|
@ -17,6 +16,9 @@ func GetAccount(db *sqlx.DB, username string) (*model.Account, error) {
|
|||
|
||||
err := db.Get(&account, "SELECT * FROM account WHERE username=$1", username)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -28,6 +30,9 @@ func GetAccountByEmail(db *sqlx.DB, email string) (*model.Account, error) {
|
|||
|
||||
err := db.Get(&account, "SELECT * FROM account WHERE email=$1", email)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -41,7 +46,7 @@ func GetAccountByToken(db *sqlx.DB, token string) (*model.Account, error) {
|
|||
|
||||
err := db.Get(&account, "SELECT account.* FROM account JOIN token ON id=account WHERE token=$1", token)
|
||||
if err != nil {
|
||||
if err.Error() == "sql: no rows in result set" {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
|
@ -50,24 +55,28 @@ func GetAccountByToken(db *sqlx.DB, token string) (*model.Account, error) {
|
|||
return &account, nil
|
||||
}
|
||||
|
||||
func GetAccountByRequest(db *sqlx.DB, r *http.Request) (*model.Account, error) {
|
||||
func GetTokenFromRequest(db *sqlx.DB, r *http.Request) string {
|
||||
tokenStr := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
|
||||
if tokenStr == "" {
|
||||
cookie, err := r.Cookie(global.COOKIE_TOKEN)
|
||||
if err != nil {
|
||||
// not logged in
|
||||
return nil, nil
|
||||
}
|
||||
tokenStr = cookie.Value
|
||||
if len(tokenStr) > 0 {
|
||||
return tokenStr
|
||||
}
|
||||
|
||||
cookie, err := r.Cookie(global.COOKIE_TOKEN)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
}
|
||||
|
||||
func GetAccountByRequest(db *sqlx.DB, r *http.Request) (*model.Account, error) {
|
||||
tokenStr := GetTokenFromRequest(db, r)
|
||||
|
||||
token, err := GetToken(db, tokenStr)
|
||||
if err != nil {
|
||||
if strings.HasPrefix(err.Error(), "sql: no rows") {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New(fmt.Sprintf("GetToken: %s", err.Error()))
|
||||
return nil, errors.New("GetToken: " + err.Error())
|
||||
}
|
||||
|
||||
// does user-agent match the token?
|
||||
|
@ -83,42 +92,36 @@ func GetAccountByRequest(db *sqlx.DB, r *http.Request) (*model.Account, error) {
|
|||
}
|
||||
|
||||
func CreateAccount(db *sqlx.DB, account *model.Account) error {
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO account (username, password, email, avatar_url) " +
|
||||
"VALUES ($1, $2, $3, $4)",
|
||||
account.Username,
|
||||
account.Password,
|
||||
account.Email,
|
||||
account.AvatarURL)
|
||||
err := db.Get(
|
||||
&account.ID,
|
||||
"INSERT INTO account (username, password, email, avatar_url) " +
|
||||
"VALUES ($1, $2, $3, $4) " +
|
||||
"RETURNING id",
|
||||
account.Username,
|
||||
account.Password,
|
||||
account.Email,
|
||||
account.AvatarURL,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func UpdateAccount(db *sqlx.DB, account *model.Account) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE account " +
|
||||
"SET username=$2, password=$3, email=$4, avatar_url=$5) " +
|
||||
"WHERE id=$1",
|
||||
account.ID,
|
||||
account.Username,
|
||||
account.Password,
|
||||
account.Email,
|
||||
account.AvatarURL)
|
||||
"UPDATE account " +
|
||||
"SET username=$2, password=$3, email=$4, avatar_url=$5) " +
|
||||
"WHERE id=$1",
|
||||
account.ID,
|
||||
account.Username,
|
||||
account.Password,
|
||||
account.Email,
|
||||
account.AvatarURL,
|
||||
)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteAccount(db *sqlx.DB, accountID string) error {
|
||||
_, err := db.Exec("DELETE FROM account WHERE id=$1", accountID)
|
||||
func DeleteAccount(db *sqlx.DB, username string) error {
|
||||
_, err := db.Exec("DELETE FROM account WHERE username=$1", username)
|
||||
return err
|
||||
}
|
||||
|
||||
var inviteChars = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
func GenerateInviteCode(length int) []byte {
|
||||
code := []byte{}
|
||||
for i := 0; i < length; i++ {
|
||||
code = append(code, inviteChars[rand.Intn(len(inviteChars) - 1)])
|
||||
}
|
||||
return code
|
||||
}
|
||||
|
|
67
controller/invite.go
Normal file
67
controller/invite.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"arimelody-web/model"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
var inviteChars = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
func GetInvite(db *sqlx.DB, code string) (*model.Invite, error) {
|
||||
invite := model.Invite{}
|
||||
|
||||
err := db.Get(&invite, "SELECT * FROM invite WHERE code=$1", code)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "no rows") {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
func CreateInvite(db *sqlx.DB, length int, lifetime time.Duration) (*model.Invite, error) {
|
||||
invite := model.Invite{
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(lifetime),
|
||||
}
|
||||
|
||||
code := []byte{}
|
||||
for i := 0; i < length; i++ {
|
||||
code = append(code, inviteChars[rand.Intn(len(inviteChars) - 1)])
|
||||
}
|
||||
invite.Code = string(code)
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO invite (code, created_at, expires_at) " +
|
||||
"VALUES ($1, $2, $3)",
|
||||
invite.Code,
|
||||
invite.CreatedAt,
|
||||
invite.ExpiresAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &invite, nil
|
||||
}
|
||||
|
||||
func DeleteInvite(db *sqlx.DB, code string) error {
|
||||
_, err := db.Exec("DELETE FROM invite WHERE code=$1", code)
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteAllInvites(db *sqlx.DB) error {
|
||||
_, err := db.Exec("DELETE FROM invite")
|
||||
return err
|
||||
}
|
||||
|
||||
func DeleteExpiredInvites(db *sqlx.DB) error {
|
||||
_, err := db.Exec("DELETE FROM invite WHERE expires_at<current_timestamp")
|
||||
return err
|
||||
}
|
86
controller/migrator.go
Normal file
86
controller/migrator.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
const DB_VERSION int = 2
|
||||
|
||||
func CheckDBVersionAndMigrate(db *sqlx.DB) {
|
||||
db.MustExec("CREATE SCHEMA IF NOT EXISTS arimelody")
|
||||
db.MustExec("SET search_path TO arimelody, public")
|
||||
db.MustExec(
|
||||
"CREATE TABLE IF NOT EXISTS arimelody.schema_version (" +
|
||||
"version INTEGER PRIMARY KEY, " +
|
||||
"applied_at TIMESTAMP DEFAULT current_timestamp)",
|
||||
)
|
||||
|
||||
oldDBVersion := 0
|
||||
|
||||
err := db.Get(&oldDBVersion, "SELECT MAX(version) FROM schema_version")
|
||||
if err != nil { panic(err) }
|
||||
|
||||
for oldDBVersion < DB_VERSION {
|
||||
switch oldDBVersion {
|
||||
case 0:
|
||||
// default case; assume no database exists
|
||||
ApplyMigration(db, "000-init")
|
||||
oldDBVersion = DB_VERSION
|
||||
|
||||
case 1:
|
||||
// the irony is i actually have to awkwardly shove schema_version
|
||||
// into the old database in order for this to work LOL
|
||||
ApplyMigration(db, "001-pre-versioning")
|
||||
oldDBVersion = 2
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Printf("Database schema up to date.\n")
|
||||
}
|
||||
|
||||
func ApplyMigration(db *sqlx.DB, scriptFile string) {
|
||||
fmt.Printf("Applying schema migration %s...\n", scriptFile)
|
||||
|
||||
bytes, err := os.ReadFile("schema_migration/" + scriptFile + ".sql")
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "FATAL: Failed to open schema file \"%s\": %v\n", scriptFile, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
script := string(bytes)
|
||||
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "FATAL: Failed to begin migration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(script)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
fmt.Fprintf(os.Stderr, "FATAL: Failed to apply migration: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO schema_version (version, applied_at) " +
|
||||
"VALUES ($1, $2)",
|
||||
DB_VERSION,
|
||||
time.Now(),
|
||||
)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
fmt.Fprintf(os.Stderr, "FATAL: Failed to update schema version: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "FATAL: Failed to commit transaction: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue