refactor(gateway): 重构 main.go 文件

- 移除日志初始化代码,改为使用 utils.InitLogger()
- 删除用户模型定义,移至 models 包
- 抽离路由处理逻辑到 handlers 包
- 使用 middleware 包中的 AuthRequired 中间件
- 优化数据库连接和迁移逻辑
- 简化 main 函数,提高代码可读性和维护性
This commit is contained in:
高手 2025-02-15 20:16:44 +08:00
parent 4446e5aeaa
commit 5cb134fa9d
7 changed files with 271 additions and 248 deletions

147
gateway/handlers/auth.go Normal file
View File

@ -0,0 +1,147 @@
package handlers
import (
"net/http"
"strconv"
"strings"
"gateway/models"
"gateway/utils"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt"
)
func GetLogin(c *gin.Context) {
c.HTML(http.StatusOK, "login.html", nil)
}
func PostLogin(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
username := c.PostForm("username")
password := c.PostForm("password")
var user models.User
if err := db.Where("mobile = ?", username).First(&user).Error; err != nil {
c.HTML(http.StatusUnauthorized, "login.html", gin.H{"error": "用户不存在或密码错误"})
return
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
c.HTML(http.StatusUnauthorized, "login.html", gin.H{"error": "用户不存在或密码错误"})
return
}
session := sessions.Default(c)
session.Set("user", user.ID)
if err := session.Save(); err != nil {
utils.Logger.Errorf("Session保存失败: %v", err)
c.HTML(http.StatusInternalServerError, "login.html", gin.H{"error": "登录状态保存失败"})
return
}
c.Redirect(http.StatusSeeOther, "/")
}
}
func GetRegister(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
regions, err := getRegions(db)
if err != nil {
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "系统错误"})
return
}
c.HTML(http.StatusOK, "register.html", gin.H{"regions": regions})
}
}
func PostRegister(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
regionStr, exists := c.GetPostForm("region")
if !exists {
handleRegisterError(c, db, "请选择所在地区")
return
}
regionID, err := strconv.ParseUint(regionStr, 10, 32)
if err != nil {
handleRegisterError(c, db, "无效的地区参数")
return
}
user := models.User{
FullName: c.PostForm("fullname"),
Mobile: c.PostForm("mobile"),
RegionID: uint(regionID),
}
var region models.Region
if err := db.First(&region, user.RegionID).Error; err != nil {
handleRegisterError(c, db, "请选择有效地区")
return
}
if len(user.Mobile) != 11 {
handleRegisterError(c, db, "手机号格式不正确")
return
}
password := c.PostForm("password")
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
utils.Logger.Errorf("密码加密失败: %v", err)
handleRegisterError(c, db, "注册失败")
return
}
user.Password = string(hashedPassword)
if err := db.Create(&user).Error; err != nil {
utils.Logger.Errorf("用户创建失败: %v", err)
errorMsg := "注册失败"
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
errorMsg = "该手机号已注册"
}
handleRegisterError(c, db, errorMsg)
return
}
c.Redirect(http.StatusSeeOther, "/login")
}
}
func Logout(c *gin.Context) {
session := sessions.Default(c)
session.Clear()
if err := session.Save(); err != nil {
utils.Logger.Errorf("退出登录失败: %v", err)
c.HTML(http.StatusInternalServerError, "error.html", gin.H{"error": "退出登录失败"})
return
}
c.Redirect(http.StatusSeeOther, "/login")
}
func getRegions(db *gorm.DB) ([]models.Region, error) {
var regions []models.Region
if err := db.Find(&regions).Error; err != nil {
utils.Logger.Errorf("获取地区数据失败: %v", err)
return nil, err
}
return regions, nil
}
func handleRegisterError(c *gin.Context, db *gorm.DB, errorMessage string) {
regions, err := getRegions(db)
if err != nil {
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "系统错误"})
return
}
c.HTML(http.StatusBadRequest, "register.html", gin.H{
"error": errorMessage,
"regions": regions,
"form": gin.H{
"fullname": c.PostForm("fullname"),
"mobile": c.PostForm("mobile"),
},
})
}

View File

@ -0,0 +1,27 @@
package handlers
import (
"net/http"
"os"
"gateway/utils"
"github.com/gin-gonic/gin"
)
// 处理首页请求
func ServeIndex(c *gin.Context) {
utils.LogAccess(c.ClientIP(), c.Request.URL.Path, c.Request.Method)
http.ServeFile(c.Writer, c.Request, "./static/index.html")
}
// 处理静态文件请求
func ServeStatic(c *gin.Context) {
utils.LogAccess(c.ClientIP(), c.Request.URL.Path, c.Request.Method)
filePath := "./static" + c.Request.URL.Path
if _, err := os.Stat(filePath); err == nil {
http.ServeFile(c.Writer, c.Request, filePath)
} else {
c.AbortWithStatus(http.StatusNotFound)
}
}

View File

@ -2,70 +2,31 @@ package main
import (
"net/http"
"os"
"strconv"
"strings"
"time"
"fmt"
"io"
"path/filepath"
"gateway/handlers"
"gateway/middleware"
"gateway/models"
"gateway/utils"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
)
var logger = logrus.New()
// 在全局变量后新增用户模型
type Region struct {
ID uint `gorm:"primary_key"`
Name string `gorm:"not null;unique"`
}
type User struct {
gorm.Model
FullName string `gorm:"not null"`
RegionID uint `gorm:"not null"` // 修改为关联地区ID
Mobile string `gorm:"unique;not null"`
Password string `gorm:"not null"`
Region Region // 关联关系
}
func init() {
// 确保log目录存在
if err := os.MkdirAll("log", 0755); err != nil {
panic(fmt.Sprintf("创建日志目录失败: %v", err))
}
// 生成日志文件名 (格式: log/2024-03-21.log)
logFileName := filepath.Join("log", time.Now().Format("2006-01-02")+".log")
// 打开日志文件
logFile, err := os.OpenFile(logFileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
panic(fmt.Sprintf("打开日志文件失败: %v", err))
}
// 配置日志格式
logger.SetFormatter(&logrus.JSONFormatter{})
// 同时输出到文件和标准输出
logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
}
func main() {
// 初始化日志
utils.InitLogger()
// 初始化数据库
db, err := gorm.Open("sqlite3", "family.db")
if err != nil {
logger.Fatalf("数据库连接失败: %v", err)
utils.Logger.Fatalf("数据库连接失败: %v", err)
}
defer db.Close()
db.AutoMigrate(&Region{}, &User{}) // 同时迁移Region和User表
db.AutoMigrate(&models.Region{}, &models.User{})
// 初始化 Gin 引擎
r := gin.Default()
@ -84,214 +45,24 @@ func main() {
})
r.Use(sessions.Sessions("mysession", store))
// 登录页面
r.GET("/login", func(c *gin.Context) {
c.HTML(http.StatusOK, "login.html", nil)
})
// 处理登录请求
r.POST("/login", func(c *gin.Context) {
username := c.PostForm("username") // 修改表单字段名
password := c.PostForm("password")
var user User
// 电话号作为用户名
if err := db.Where("mobile = ?", username).First(&user).Error; err != nil {
c.HTML(http.StatusUnauthorized, "login.html", gin.H{"error": "用户不存在或密码错误"})
return
}
// 验证密码
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
c.HTML(http.StatusUnauthorized, "login.html", gin.H{"error": "用户不存在或密码错误"})
return
}
// 保存session保持原有逻辑
session := sessions.Default(c)
session.Set("user", user.ID)
if err := session.Save(); err != nil {
logger.Errorf("Session保存失败: %v", err)
c.HTML(http.StatusInternalServerError, "login.html", gin.H{"error": "登录状态保存失败"})
return
}
c.Redirect(http.StatusSeeOther, "/") // 改用303状态码
})
// 在登录路由后新增注册路由
// 注册页面
r.GET("/register", func(c *gin.Context) {
regions, err := getRegions(db)
if err != nil {
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "系统错误"})
return
}
c.HTML(http.StatusOK, "register.html", gin.H{"regions": regions})
})
// 处理注册请求
r.POST("/register", func(c *gin.Context) {
// 获取地区参数(修改这部分)
regionStr, exists := c.GetPostForm("region")
if !exists {
regions, err := getRegions(db)
if err != nil {
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "系统错误"})
return
}
c.HTML(http.StatusBadRequest, "register.html", gin.H{
"error": "请选择所在地区",
"regions": regions,
"form": gin.H{
"fullname": c.PostForm("fullname"),
"mobile": c.PostForm("mobile"),
},
})
return
}
// 转换地区ID为数字
regionID, err := strconv.ParseUint(regionStr, 10, 32)
if err != nil {
regions, err := getRegions(db)
if err != nil {
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "系统错误"})
return
}
c.HTML(http.StatusBadRequest, "register.html", gin.H{
"error": "无效的地区参数",
"regions": regions,
"form": gin.H{
"fullname": c.PostForm("fullname"),
"mobile": c.PostForm("mobile"),
},
})
return
}
user := User{
FullName: c.PostForm("fullname"),
Mobile: c.PostForm("mobile"),
RegionID: uint(regionID), // 使用转换后的ID
}
// 验证地区是否存在
var region Region
if err := db.First(&region, user.RegionID).Error; err != nil {
regions, err := getRegions(db)
if err != nil {
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "系统错误"})
return
}
c.HTML(http.StatusBadRequest, "register.html", gin.H{
"error": "请选择有效地区",
"regions": regions,
"form": gin.H{
"fullname": c.PostForm("fullname"),
"mobile": c.PostForm("mobile"),
},
})
return
}
// 验证手机号格式
if len(user.Mobile) != 11 {
c.HTML(http.StatusBadRequest, "register.html", gin.H{"error": "手机号格式不正确"})
return
}
// 密码加密
password := c.PostForm("password")
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
logger.Errorf("密码加密失败: %v", err)
c.HTML(http.StatusInternalServerError, "register.html", gin.H{"error": "注册失败"})
return
}
user.Password = string(hashedPassword)
// 创建用户
if err := db.Create(&user).Error; err != nil {
logger.Errorf("用户创建失败: %v", err)
errorMsg := "注册失败"
if strings.Contains(err.Error(), "UNIQUE constraint failed") {
errorMsg = "该手机号已注册"
}
c.HTML(http.StatusBadRequest, "register.html", gin.H{"error": errorMsg})
return
}
c.Redirect(http.StatusSeeOther, "/login")
})
// 路由配置
r.GET("/login", handlers.GetLogin)
r.POST("/login", handlers.PostLogin(db))
r.GET("/register", handlers.GetRegister(db))
r.POST("/register", handlers.PostRegister(db))
// 权限校验中间件
authMiddleware := func(c *gin.Context) {
session := sessions.Default(c)
user := session.Get("user")
if user == nil {
c.Redirect(http.StatusFound, "/login")
c.Abort()
return
}
c.Next()
}
r.Use(middleware.AuthRequired())
// 文档页面路由
r.Use(authMiddleware)
r.GET("/", func(c *gin.Context) {
// 记录访问痕迹
logAccess(c)
http.ServeFile(c.Writer, c.Request, "./static/index.html")
})
r.GET("/", handlers.ServeIndex)
// 在权限校验中间件后添加退出路由
r.GET("/logout", func(c *gin.Context) {
session := sessions.Default(c)
session.Clear()
if err := session.Save(); err != nil {
logger.Errorf("退出登录失败: %v", err)
c.HTML(http.StatusInternalServerError, "error.html", gin.H{"error": "退出登录失败"})
return
}
c.Redirect(http.StatusSeeOther, "/login")
})
r.GET("/logout", handlers.Logout)
// 新增通用静态文件路由(放在其他路由之后)
r.NoRoute(func(c *gin.Context) {
logAccess(c)
filePath := "./static" + c.Request.URL.Path
// 检查文件是否存在
if _, err := os.Stat(filePath); err == nil {
http.ServeFile(c.Writer, c.Request, filePath)
} else {
c.AbortWithStatus(http.StatusNotFound)
}
})
r.NoRoute(handlers.ServeStatic)
// 启动服务
r.Run(":7070")
}
// 记录访问痕迹
func logAccess(c *gin.Context) {
ip := c.ClientIP()
path := c.Request.URL.Path
method := c.Request.Method
timestamp := time.Now().Format(time.RFC3339)
logger.WithFields(logrus.Fields{
"ip": ip,
"path": path,
"method": method,
"timestamp": timestamp,
}).Info("Page accessed")
}
func getRegions(db *gorm.DB) ([]Region, error) {
var regions []Region
if err := db.Find(&regions).Error; err != nil {
logger.Errorf("获取地区数据失败: %v", err)
return nil, err
}
return regions, nil
}

View File

@ -0,0 +1,21 @@
package middleware
import (
"net/http"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
func AuthRequired() gin.HandlerFunc {
return func(c *gin.Context) {
session := sessions.Default(c)
user := session.Get("user")
if user == nil {
c.Redirect(http.StatusFound, "/login")
c.Abort()
return
}
c.Next()
}
}

6
gateway/models/region.go Normal file
View File

@ -0,0 +1,6 @@
package models
type Region struct {
ID uint `gorm:"primary_key"`
Name string `gorm:"not null;unique"`
}

14
gateway/models/user.go Normal file
View File

@ -0,0 +1,14 @@
package models
import (
"github.com/jinzhu/gorm"
)
type User struct {
gorm.Model
FullName string `gorm:"not null"`
RegionID uint `gorm:"not null"`
Mobile string `gorm:"unique;not null"`
Password string `gorm:"not null"`
Region Region
}

37
gateway/utils/logger.go Normal file
View File

@ -0,0 +1,37 @@
package utils
import (
"fmt"
"io"
"os"
"path/filepath"
"time"
"github.com/sirupsen/logrus"
)
var Logger = logrus.New()
func InitLogger() {
if err := os.MkdirAll("log", 0755); err != nil {
panic(fmt.Sprintf("创建日志目录失败: %v", err))
}
logFileName := filepath.Join("log", time.Now().Format("2006-01-02")+".log")
logFile, err := os.OpenFile(logFileName, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
panic(fmt.Sprintf("打开日志文件失败: %v", err))
}
Logger.SetFormatter(&logrus.JSONFormatter{})
Logger.SetOutput(io.MultiWriter(os.Stdout, logFile))
}
func LogAccess(ip, path, method string) {
Logger.WithFields(logrus.Fields{
"ip": ip,
"path": path,
"method": method,
"timestamp": time.Now().Format(time.RFC3339),
}).Info("Page accessed")
}