Restart

话不多说，直接上代码。

1、版本检查与进程自重启之间通过一个 signal channel 协调：检查协程下载到新版本后向 channel 发送 "new"，主协程收到后重启进程。

package main

import (
    "GoRadar/core"
    "GoRadar/lib"
    "fmt"
    "os/exec"
    "github.com/benmanns/goworker"
    "os"
    "path/filepath"
    "github.com/levigross/grequests"
    "log"
    "runtime"
    "net/http"
    "io"
    "time"
)

// 扫描
// ScanActivityTask is the goworker handler for activity scan jobs.
//
// Expected job arguments:
//
//	args[0] string - the IP range handed to the scanner
//	args[1] bool   - whether the scan rate should be limited
//
// Missing or mistyped arguments are reported as an error rather than
// panicking on a failed type assertion, so the failed job carries a useful
// message. Extra arguments are ignored.
func ScanActivityTask(queue string, args ...interface{}) error {
	fmt.Println("调用队列:" + queue)
	if len(args) < 2 {
		return fmt.Errorf("scan activity task: expected 2 arguments (ip range, limit flag), got %d", len(args))
	}
	ipRange, ok := args[0].(string)
	if !ok {
		return fmt.Errorf("scan activity task: argument 0 (ip range) must be a string, got %T", args[0])
	}
	limitRate, ok := args[1].(bool)
	if !ok {
		return fmt.Errorf("scan activity task: argument 1 (limit scan rate) must be a bool, got %T", args[1])
	}
	sa := core.NewScanActivity()
	sa.Scanner(ipRange, limitRate)
	return nil
}
// ScanPortTask is the goworker handler for port scan jobs.
//
// Expected job arguments:
//
//	args[0] string - the IP range handed to the scanner
//	args[1] string - the scan mode, passed through to core's Scan
//	args[2] bool   - whether the scan rate should be limited
//
// Missing or mistyped arguments are reported as an error rather than
// panicking on a failed type assertion. Extra arguments are ignored.
func ScanPortTask(queue string, args ...interface{}) error {
	fmt.Println("调用队列:" + queue)
	if len(args) < 3 {
		return fmt.Errorf("scan port task: expected 3 arguments (ip range, scan mode, limit flag), got %d", len(args))
	}
	ipRange, ok := args[0].(string)
	if !ok {
		return fmt.Errorf("scan port task: argument 0 (ip range) must be a string, got %T", args[0])
	}
	scanMode, ok := args[1].(string)
	if !ok {
		return fmt.Errorf("scan port task: argument 1 (scan mode) must be a string, got %T", args[1])
	}
	limitRate, ok := args[2].(bool)
	if !ok {
		return fmt.Errorf("scan port task: argument 2 (limit scan rate) must be a bool, got %T", args[2])
	}
	sp := core.NewScanPort()
	sp.Scan(ipRange, scanMode, limitRate)
	return nil
}

// version is the release string of this agent build; Version_validate
// compares it with the version published on the update server.
var version = "1.0.1"

// download_url holds the URL of the most recently selected agent update;
// Version_validate assigns it before calling Download_new_agent.
var download_url string

// Version_validate asks the update server for the latest published agent
// version. When it is newer than the running version, it downloads the new
// binary and notifies the caller over c.
//
// Messages sent on c (unbuffered sends, so this call blocks until received):
//   - "new": a newer binary was downloaded successfully; the receiver is
//     expected to restart the process.
//   - "old": a newer version exists but could not be installed.
//
// Nothing is sent in these cases:
//   - the server is unreachable or replies with a non-2xx status;
//   - the version file is not a MAJOR.MINOR.PATCH version;
//   - the running version is already up to date.
//
// It returns true only when a newer binary was installed.
func Version_validate(c chan string) bool {
	const updateServer = "http://192.168.0.7"

	resp, err := grequests.Get(updateServer+"/version.txt", nil)
	if err != nil {
		fmt.Println("Validate version error: Unable to make request ")
		return false
	}
	defer resp.Close()
	// An error page must never be mistaken for a version string.
	if !resp.Ok {
		fmt.Println("Validate version error: unexpected HTTP status", resp.StatusCode)
		return false
	}

	remote, err := parseAgentVersion(resp.String())
	if err != nil {
		fmt.Println("Validate version error:", err)
		return false
	}
	local, err := parseAgentVersion(version)
	if err != nil {
		fmt.Println("Validate version error:", err)
		return false
	}
	// Rebuild the canonical form so stray whitespace in version.txt never
	// leaks into the download URL.
	new_version := fmt.Sprintf("%d.%d.%d", remote[0], remote[1], remote[2])
	fmt.Println("new_version:" + new_version)
	fmt.Println("version:" + version)
	if !agentVersionLess(local, remote) {
		return false
	}

	os_name := runtime.GOOS
	switch os_name {
	case "linux", "windows":
		download_url = updateServer + "/" + os_name + "/" + new_version
	default:
		fmt.Println("Validate version error: no agent build is published for " + os_name)
		c <- "old"
		return false
	}

	installed, err := Download_new_agent(download_url, os_name)
	if !installed {
		fmt.Println("Download new agent error:", err)
		c <- "old"
		return false
	}
	// Print before signalling: once main receives "new" it restarts and
	// returns, which may end the process before this line would run.
	fmt.Println("New agent version found !")
	c <- "new"
	return true
}

// parseAgentVersion reads a MAJOR.MINOR.PATCH version from the start of s.
// Numeric components are compared as numbers by agentVersionLess, so
// "1.0.10" is newer than "1.0.9". Input after the patch number (such as a
// trailing newline) is ignored.
func parseAgentVersion(s string) ([3]int, error) {
	var v [3]int
	if _, err := fmt.Sscanf(s, "%d.%d.%d", &v[0], &v[1], &v[2]); err != nil {
		return v, fmt.Errorf("invalid version %q: %v", s, err)
	}
	return v, nil
}

// agentVersionLess reports whether a precedes b, comparing major, minor and
// patch numerically in that order.
func agentVersionLess(a, b [3]int) bool {
	for i := range a {
		if a[i] != b[i] {
			return a[i] < b[i]
		}
	}
	return false
}

// Download_new_agent fetches the agent binary at url and installs it in the
// current working directory:
//   - as "Agent.exe" when os_name is "windows";
//   - as "Agent" for every other OS.
//
// The payload is streamed into "<name>.download" first. That file is renamed
// over the existing binary only after the whole body arrived with HTTP 200,
// so a failed or partial download never destroys the currently installed
// agent. On non-Windows systems the new file is made executable.
//
// It returns (true, nil) on success and (false, err) on any failure,
// including a non-200 status.
//
// NOTE(review): the target path is relative to the working directory, while
// Restart_process re-executes filepath.Abs(os.Args[0]). The two only agree
// when the binary carries this name and is started from its own directory.
// Confirm the deployment layout.
func Download_new_agent(url string, os_name string) (bool, error) {
	file_name := "Agent"
	if os_name == "windows" {
		file_name = "Agent.exe"
	}

	// Bound the whole transfer so a stalled server cannot block the
	// version-check loop forever.
	client := &http.Client{Timeout: 10 * time.Minute}
	res, err := client.Get(url)
	if err != nil {
		return false, err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return false, fmt.Errorf("downloading %s: unexpected status %s", url, res.Status)
	}

	tmp_name := file_name + ".download"
	f, err := os.OpenFile(tmp_name, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0755)
	if err != nil {
		return false, err
	}
	if _, err := io.Copy(f, res.Body); err != nil {
		f.Close()
		os.Remove(tmp_name)
		return false, err
	}
	if err := f.Close(); err != nil {
		os.Remove(tmp_name)
		return false, err
	}
	// Force the exec bits regardless of the process umask.
	if os_name != "windows" {
		if err := os.Chmod(tmp_name, 0755); err != nil {
			os.Remove(tmp_name)
			return false, err
		}
	}
	// Rename replaces the old binary in one step. On Linux this also works
	// while the old binary is still running.
	if err := os.Rename(tmp_name, file_name); err != nil {
		os.Remove(tmp_name)
		return false, err
	}
	return true, nil
}

// Restart_process launches a new copy of the current executable, resolved
// from os.Args[0]. The copy gets the same command-line arguments and standard
// streams, and the call does not wait for it to finish. The caller is
// expected to exit afterwards so the new copy takes over.
//
// Failing to resolve the path or to start the process is fatal: log.Fatalf
// exits the current process.
func Restart_process() {
	filePath, err := filepath.Abs(os.Args[0])
	if err != nil {
		log.Fatalf("GracefulRestart: Failed to resolve executable path, error: %v", err)
	}
	cmd := exec.Command(filePath, os.Args[1:]...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		log.Fatalf("GracefulRestart: Failed to launch, error: %v", err)
	}
}


// init configures goworker from the "redis_default" config section and
// registers the scan handlers enabled in the "agent_default" section.
//
// Config keys read through lib.NewConfigUtil(""):
//
//	redis_default: host, port, pass, db     - combined into a redis:// URI
//	agent_default: scanactivity, scanport  - "yes" registers the matching
//	                                         handler, "no" skips it, any
//	                                         other value is reported
func init() {
	cfg := lib.NewConfigUtil("")
	// Lookup errors are deliberately ignored; a missing key is used as-is.
	// NOTE(review): confirm GetString yields "" when a key is absent.
	redis_host, _ := cfg.GetString("redis_default", "host")
	redis_port, _ := cfg.GetString("redis_default", "port")
	redis_pass, _ := cfg.GetString("redis_default", "pass")
	redis_db, _ := cfg.GetString("redis_default", "db")

	hostPortDB := fmt.Sprintf("%s:%s/%s", redis_host, redis_port, redis_db)
	dsn_addr := "redis://" + hostPortDB
	if redis_pass != "" {
		// Percent-encode the password. Characters such as '@', '/', ':' or
		// '#' would otherwise be parsed as URI delimiters and corrupt the
		// address. NOTE(review): assumes goworker decodes the URI with
		// net/url, which undoes this encoding.
		dsn_addr = "redis://:" + escapeURIUserinfo(redis_pass) + "@" + hostPortDB
	}

	// NOTE(review): "ScanPortQuene" looks like a typo of "ScanPortQueue".
	// It is kept unchanged because job producers must enqueue to exactly
	// this name.
	settings := goworker.WorkerSettings{
		URI:            dsn_addr,
		Connections:    100,
		Queues:         []string{"ScanActivityQueue", "ScanPortQuene"},
		UseNumber:      true,
		ExitOnComplete: false,
		Concurrency:    50,
		Namespace:      "goradar:",
		Interval:       5.0,
	}
	goworker.SetSettings(settings)

	// Read which scan types this agent should serve.
	activeswitch, _ := cfg.GetString("agent_default", "scanactivity")
	portswitch, _ := cfg.GetString("agent_default", "scanport")

	switch activeswitch {
	case "yes":
		goworker.Register("ScanActivityTask", ScanActivityTask)
		fmt.Println("Start active scan !")
	case "no":
		fmt.Println("Doesn't start active scan !")
	default:
		fmt.Printf("Error: config agent_default->scanactivity param error, only 'yes' or 'no' allowed, got %q\n", activeswitch)
	}

	switch portswitch {
	case "yes":
		goworker.Register("ScanPortTask", ScanPortTask)
		fmt.Println("Start ports scan !")
	case "no":
		fmt.Println("Doesn't start ports scan !")
	default:
		fmt.Printf("Error: config agent_default->scanport param error, only 'yes' or 'no' allowed, got %q\n", portswitch)
	}
}

// escapeURIUserinfo percent-encodes every byte of s outside the RFC 3986
// "unreserved" set (ALPHA / DIGIT / "-" / "." / "_" / "~") so that s can be
// embedded safely in the userinfo part of a URI.
func escapeURIUserinfo(s string) string {
	const hexDigits = "0123456789ABCDEF"
	buf := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z', '0' <= c && c <= '9',
			c == '-', c == '.', c == '_', c == '~':
			buf = append(buf, c)
		default:
			buf = append(buf, '%', hexDigits[c>>4], hexDigits[c&0x0F])
		}
	}
	return string(buf)
}

// main starts the scan agent.
//
// On the first launch it re-executes itself as a background child and
// returns, so the child is orphaned and adopted by PID 1 or a subreaper. That
// child then runs three loops:
//   - a version checker that polls the update server every 10 minutes;
//   - the goworker job loop, restarted whenever goworker.Work returns;
//   - a select loop that restarts the process as soon as the version checker
//     reports that a new binary was installed.
func main() {
	// Marks a process spawned by the daemonizing step below. Checking the
	// parent PID alone is not enough: when the orphan is adopted by
	// something other than PID 1 (a child subreaper, a PID namespace), every
	// generation would spawn yet another copy.
	const daemonEnvKey = "GORADAR_DAEMONIZED"

	if os.Getppid() != 1 && os.Getenv(daemonEnvKey) == "" {
		// Turn argv[0] into an absolute path usable by exec.
		filePath, err := filepath.Abs(os.Args[0])
		if err != nil {
			log.Fatalf("Daemonize: failed to resolve executable path: %v", err)
		}
		cmd := exec.Command(filePath, os.Args[1:]...)
		cmd.Env = append(os.Environ(), daemonEnvKey+"=1")
		// Hand our standard streams to the child; they could be redirected
		// to files here instead.
		cmd.Stdin = os.Stdin
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		// Start without waiting: this process exits right away and the
		// child keeps running on its own.
		if err := cmd.Start(); err != nil {
			log.Fatalf("Daemonize: failed to launch child process: %v", err)
		}
		return
	}

	signals := make(chan string)

	// Version checker: Version_validate sends "new" after installing an
	// update, "old" when an update could not be installed.
	go func() {
		for {
			Version_validate(signals)
			time.Sleep(10 * time.Minute)
		}
	}()

	// Job worker. goworker.Work can return immediately (for example when it
	// reports an error), so pause before restarting it instead of
	// busy-looping. NOTE(review): if Work also returns after handling
	// SIGINT/SIGTERM, this loop keeps the agent from shutting down on those
	// signals. Confirm that is intended.
	go func() {
		for {
			if err := goworker.Work(); err != nil {
				fmt.Println("Error:", err)
			}
			time.Sleep(5 * time.Second)
		}
	}()

	for {
		select {
		case msg := <-signals:
			if msg == "new" {
				Restart_process()
				return
			}
		case <-time.After(10 * time.Second):
			fmt.Println("timeout, check again...")
		}
	}
}