add-domain-user-into-ad-security-group-api
接到一个简单的需求:能够方便地将一些用户添加到 AD 安全组中,然后根据规则对该安全组成员的网络访问进行限制。
大概用了2个小时的时间完成整个工作。
主要思路是提供一个 API,通过 add、delete、list 三个操作,分别添加、删除和列出该安全组的成员。
用法(均为 GET 请求):添加:https://server/add?user=user1;删除:https://server/delete?user=user1;列出成员:https://server/list
使用 Go(golang)语言实现。
代码如下:
package main
import (
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/go-ldap/ldap/v3"
)
// Active Directory connection settings.
//
// NOTE(review): the bind password is hard-coded in source. Consider loading it
// from the environment or a secrets store instead of committing it.
const (
	adServer   = "ldaps://domaincontroller:636"               // LDAPS endpoint of the domain controller
	adUser     = "account@example.com"                        // account used to bind
	adPassword = "xxxxxx"                                     // password for adUser
	adGroupDN  = "CN=limited_sc,OU=Groups,dc=example,dc=com" // DN of the security group being managed

	ldapTimeout = 10 * time.Second // timeout used for LDAP connections
)

// Log file location; logDir is a relative path, resolved against the
// working directory.
const (
	logDir  = "logs"
	logFile = "add-netid-into-vpn-group.log"
)

// HTTPS listener settings.
const (
	certFile = "server.crt"
	keyFile  = "server.key"
	port     = ":443"
)
// init prepares logging before main runs. It:
//   - ensures logDir exists,
//   - opens logDir/logFile in append mode, creating it if needed,
//   - routes the standard logger to both stdout and that file.
//
// Any failure in these steps terminates the process via log.Fatalf.
func init() {
	if err := os.MkdirAll(logDir, 0755); err != nil {
		log.Fatalf("Failed to create log directory: %v", err)
	}

	logPath := filepath.Join(logDir, logFile)
	file, err := os.OpenFile(logPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
	if err != nil {
		log.Fatalf("Failed to open log file: %v", err)
	}

	// Previously the file was opened but never used. That is a compile error
	// in Go, and it also meant no log entries reached the file.
	// os.Stdout is an unbuffered *os.File, so console output appears immediately.
	// The file handle is intentionally kept open for the life of the process.
	log.SetOutput(io.MultiWriter(os.Stdout, file))
}
// logAction writes an audit entry for an API action performed by clientIP.
// The elided part (....) presumably formats timestamp, clientIP, action and
// details into logMessage. TODO confirm.
//
// NOTE(review): log.Print(logMessage) is commented out. When log.Writer() is a
// bare *os.File (the standard logger's default), the branch below only calls
// Sync, so the entry is never written unless the elided lines write
// logMessage themselves. Consider always calling log.Output(2, logMessage).
func logAction(clientIP, action, details string) {
....
timestamp, clientIP, action, details)
//log.Print(logMessage)
// Choose how to emit the entry based on the logger's current writer.
if f, ok := log.Writer().(*os.File); ok {
f.Sync() // fsync the file; os.File writes are unbuffered, so this affects durability, not visibility. The error is ignored.
} else {
// Writer is not a plain *os.File (e.g. an io.MultiWriter): log.Output writes logMessage.
log.Output(2, logMessage) // calldepth 2 attributes file:line (if enabled) to logAction's caller
}
}
// allowedNetworks lists the CIDR ranges whose clients may call the API.
// Add further ranges to the mustParseCIDRs call if more networks need access.
var allowedNetworks = mustParseCIDRs("192.168.10.0/23")

// mustParseCIDRs parses each CIDR string once, at package initialisation.
// An invalid entry is a programming error, so it panics at start-up. The old
// behaviour was to log on every request and silently deny everyone.
func mustParseCIDRs(cidrs ...string) []*net.IPNet {
	nets := make([]*net.IPNet, 0, len(cidrs))
	for _, cidr := range cidrs {
		_, n, err := net.ParseCIDR(cidr)
		if err != nil {
			panic("invalid allowed network " + cidr + ": " + err.Error())
		}
		nets = append(nets, n)
	}
	return nets
}

// isIPAllowed reports whether ipStr is a valid IP address contained in one
// of allowedNetworks. Empty or unparsable input is rejected.
// IPv4-mapped IPv6 addresses (e.g. "::ffff:192.168.10.1") match IPv4 ranges.
func isIPAllowed(ipStr string) bool {
	if ipStr == "" {
		return false
	}
	ip := net.ParseIP(ipStr)
	if ip == nil {
		// ipStr may come from a client-supplied header (X-Real-IP). %q escapes
		// newlines and control characters so it cannot forge extra log lines.
		log.Printf("Failed to parse IP: %q", ipStr)
		return false
	}
	for _, n := range allowedNetworks {
		if n.Contains(ip) {
			return true
		}
	}
	return false
}
// dialLDAP opens a connection to adServer.
//
// Establishing the connection is bounded by ldapTimeout, and the same limit
// is applied to every later request on the connection (bind, search, modify).
// This stops a stalled domain controller from hanging an HTTP handler
// indefinitely. The caller must Close the returned connection.
func dialLDAP() (*ldap.Conn, error) {
	dialer := &net.Dialer{Timeout: ldapTimeout}
	conn, err := ldap.DialURL(
		adServer,
		ldap.DialWithDialer(dialer),
		// NOTE(review): InsecureSkipVerify disables certificate validation.
		// The password later sent by Bind is therefore exposed to a
		// man-in-the-middle. Prefer trusting the domain CA via tls.Config.RootCAs.
		ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
	)
	if err != nil {
		// %w keeps the underlying error inspectable with errors.Is/As.
		return nil, fmt.Errorf("failed to connect to AD: %w", err)
	}
	// go-ldap applies no per-request timeout unless one is set explicitly.
	conn.SetTimeout(ldapTimeout)
	return conn, nil
}
// listGroupMembers returns the common names (CN) of the direct members of
// adGroupDN, read from the group's "member" attribute.
//
// On success the returned slice is never nil, so it encodes as [] in JSON
// when the group is empty.
//
// NOTE(review): AD may return a large multi-valued attribute only through
// range retrieval ("member;range=0-1499"). In that case "member" could come
// back empty or truncated for very large groups. Confirm the group size stays
// below the server's MaxValRange.
func listGroupMembers() ([]string, error) {
	conn, err := dialLDAP()
	if err != nil {
		// dialLDAP already prefixes "failed to connect to AD"; re-wrapping
		// produced a duplicated prefix.
		return nil, err
	}
	defer conn.Close()

	if err := conn.Bind(adUser, adPassword); err != nil {
		return nil, fmt.Errorf("failed to bind to AD: %w", err)
	}

	searchRequest := ldap.NewSearchRequest(
		....
		nil,
	)
	result, err := conn.Search(searchRequest)
	if err != nil {
		return nil, fmt.Errorf("failed to search group: %w", err)
	}
	if len(result.Entries) == 0 {
		return nil, fmt.Errorf("group not found")
	}

	memberDNs := result.Entries[0].GetAttributeValues("member")
	members := make([]string, 0, len(memberDNs))
	for _, memberDN := range memberDNs {
		members = append(members, cnFromDN(memberDN))
	}
	return members, nil
}

// cnFromDN returns the unescaped value of dn's leading CN RDN.
// For example, `CN=Doe\, John,OU=Users,DC=example,DC=com` yields "Doe, John";
// naive splitting on "," would have yielded `Doe\`.
// If dn cannot be parsed, or does not start with a CN, the DN is returned
// unchanged so the member is still listed.
func cnFromDN(dn string) string {
	parsed, err := ldap.ParseDN(dn)
	if err != nil || len(parsed.RDNs) == 0 || len(parsed.RDNs[0].Attributes) == 0 {
		return dn
	}
	first := parsed.RDNs[0].Attributes[0]
	if !strings.EqualFold(first.Type, "CN") {
		return dn
	}
	return first.Value
}
// addUserToGroup adds userDN as a value of adGroupDN's "member" attribute.
// It uses a fresh connection from dialLDAP, bound as adUser.
//
// Errors are wrapped with %w (message text unchanged), so callers can inspect
// the underlying LDAP error, e.g. via errors.As into *ldap.Error.
func addUserToGroup(userDN string) error {
	conn, err := dialLDAP()
	if err != nil {
		return err
	}
	defer conn.Close()

	if err := conn.Bind(adUser, adPassword); err != nil {
		return fmt.Errorf("failed to bind to AD: %w", err)
	}

	modifyReq := ldap.NewModifyRequest(adGroupDN, nil)
	modifyReq.Add("member", []string{userDN})
	// Use the plain Modify method, because ModifyWithContext is not available.
	if err := conn.Modify(modifyReq); err != nil {
		return fmt.Errorf("failed to add user to group: %w", err)
	}
	return nil
}
// removeUserFromGroup deletes userDN from adGroupDN's "member" attribute.
// It uses a fresh connection from dialLDAP, bound as adUser.
//
// Errors are wrapped with %w (message text unchanged), so callers can inspect
// the underlying LDAP error, e.g. via errors.As into *ldap.Error.
func removeUserFromGroup(userDN string) error {
	conn, err := dialLDAP()
	if err != nil {
		return err
	}
	defer conn.Close()

	if err := conn.Bind(adUser, adPassword); err != nil {
		return fmt.Errorf("failed to bind to AD: %w", err)
	}

	modifyReq := ldap.NewModifyRequest(adGroupDN, nil)
	modifyReq.Delete("member", []string{userDN})
	// Use the plain Modify method rather than ModifyWithContext.
	if err := conn.Modify(modifyReq); err != nil {
		return fmt.Errorf("failed to remove user from group: %w", err)
	}
	return nil
}
// ErrorResponse is the JSON body sent to clients when a request is rejected
// or an operation fails. Field order determines the key order in the output.
type ErrorResponse struct {
// Error is a short summary, e.g. "Access Denied".
Error string `json:"error"`
// Time is when the error occurred; checkIPMiddleware formats it as
// "2006-01-02 15:04:05" in server local time.
Time string `json:"time"`
// Details carries the specific reason, e.g. the underlying err.Error().
Details string `json:"details"`
}
// checkIPMiddleware returns a Gin middleware that enforces the IP allow-list.
// Requests whose client IP is rejected by isIPAllowed get HTTP 403 with an
// ErrorResponse body and are logged via logAction as ACCESS_DENIED. Allowed
// requests continue down the handler chain.
//
// NOTE(review): the IP is read first from client-supplied proxy headers
// (X-Real-IP per the comment below; the lookup code is elided). Unless a
// trusted reverse proxy always overwrites those headers, any client can send
// e.g. "X-Real-IP: 192.168.10.1" and bypass this check. Prefer relying on
// c.ClientIP() alone, together with engine.SetTrustedProxies in main.
func checkIPMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Try X-Real-IP header first
....
}
}
// If no proxy header produced an address, fall back to gin's ClientIP. It
// uses RemoteAddr unless the peer is a trusted proxy, and gin trusts every
// proxy unless SetTrustedProxies is called.
if ip == "" {
ip = c.ClientIP()
}
ip = strings.TrimSpace(ip)
if !isIPAllowed(ip) {
c.JSON(http.StatusForbidden, ErrorResponse{
Error: "Access Denied",
Time: time.Now().Format("2006-01-02 15:04:05"),
Details: fmt.Sprintf("IP %s is not allowed to access this service", ip),
})
logAction(ip, "ACCESS_DENIED", fmt.Sprintf("Attempted access from unauthorized IP: %s", ip))
// Abort prevents the remaining handlers from running for this request.
c.Abort()
return
}
c.Next()
}
}
// main configures a Gin engine that serves the group-management API over
// HTTPS on port:
//
//	GET /add?user=NAME     add CN=NAME,dc=example,dc=com to adGroupDN
//	GET /delete?user=NAME  remove that DN from adGroupDN
//	GET /list              return the CNs of the group's members as JSON
//
// Every route passes through checkIPMiddleware first. Certificate and key
// paths are resolved in elided code, presumably relative to the executable
// given the os.Executable call. TODO confirm.
//
// NOTE(review): /add and /delete change directory state via GET. Link
// prefetchers, crawlers, or a cross-site <img> loaded by an allowed client
// can trigger them. POST (or DELETE) would be safer.
func main() {
// Release mode suppresses Gin's debug output.
gin.SetMode(gin.ReleaseMode)
// gin.New installs no middleware, so Recovery is added explicitly below.
r := gin.New()
// Recovery turns handler panics into HTTP 500 instead of crashing the server.
r.Use(gin.Recovery())
// Apply the IP allow-list check to all routes.
// NOTE(review): r.SetTrustedProxies is not called in the visible code, so
// c.ClientIP() honours forwarded headers sent by any peer.
r.Use(checkIPMiddleware())
// /add?user=NAME: add NAME to the security group.
r.GET("/add", func(c *gin.Context) {
clientIP := c.ClientIP()
userAccount := c.Query("user")
if userAccount == "" {
....
})
return
}
// NOTE(review): userAccount is untrusted query input interpolated into a DN
// without escaping. Characters such as ',', '+' or '=' can change the DN.
// Escape it (e.g. ldap.EscapeDN, if the go-ldap version provides it) or
// validate it against an allow-list of characters.
userDN := fmt.Sprintf("CN=%s,dc=example,dc=com", userAccount)
....
})
return
}
logAction(clientIP, "ADD_SUCCESS", fmt.Sprintf("user: %s", userAccount))
c.String(http.StatusOK, "user %s added successfully", userAccount)
})
// /list: return the group's member CNs as a JSON array.
r.GET("/list", func(c *gin.Context) {
clientIP := c.ClientIP()
members, err := listGroupMembers()
if err != nil {
.....
})
return
}
logAction(clientIP, "LIST", fmt.Sprintf("total members: %d", len(members)))
c.JSON(http.StatusOK, members)
})
// /delete?user=NAME: remove NAME from the security group.
r.GET("/delete", func(c *gin.Context) {
clientIP := c.ClientIP()
userAccount := c.Query("user")
if userAccount == "" {
c.JSON(http.StatusBadRequest, ErrorResponse{
.....
})
return
}
// NOTE(review): same unescaped-DN concern as in /add.
userDN := fmt.Sprintf("CN=%s,dc=example,dc=com", userAccount)
if err := removeUserFromGroup(userDN); err != nil {
....
Details: err.Error(),
})
return
}
logAction(clientIP, "DELETE_SUCCESS", fmt.Sprintf("user: %s", userAccount))
c.String(http.StatusOK, "user %s removed successfully", userAccount)
})
// Get certificate paths
exePath, err := os.Executable()
....
// Check certificate files
....
// Log startup information
log.Printf("Starting HTTPS server on port %s", port)
log.Printf("Using certificate: %s", certPath)
log.Printf("Using private key: %s", keyPath)
// Start HTTPS server; RunTLS blocks until the server fails.
if err := r.RunTLS(port, certPath, keyPath); err != nil {
log.Fatalf("Failed to start HTTPS server: %v", err)
}
}