package agent

import (
	"context"
	"crypto/tls"
	"fmt"
	"log"
	"math"
	"math/rand"
	"sync/atomic"
	"time"
	"virtual-nezha-agent/pkg/geoip"
	"virtual-nezha-agent/pkg/proto"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
)

// AuthHandler carries the credentials that identify this agent to the server.
// It is handed to grpc.WithPerRPCCredentials, so its method set provides the
// per-RPC credential hooks gRPC calls for every request.
type AuthHandler struct {
	// ClientSecret is sent as the "client_secret" metadata entry.
	ClientSecret string
	// ClientUUID is sent as the "client_uuid" metadata entry.
	ClientUUID string
}

// GetRequestMetadata returns the metadata attached to each outgoing RPC.
// The context and URI arguments are not consulted and the call never fails.
func (h *AuthHandler) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	md := make(map[string]string, 2)
	md["client_secret"] = h.ClientSecret
	md["client_uuid"] = h.ClientUUID
	return md, nil
}

// RequireTransportSecurity reports false, which lets these credentials be
// used on connections without transport security as well as over TLS.
func (h *AuthHandler) RequireTransportSecurity() bool {
	const requireTLS = false
	return requireTLS
}

// createConnection dials a.ServerAddress and returns the resulting client
// connection. Transport credentials are TLS (minimum version 1.2) when
// a.UseSSL is set and plaintext otherwise; the agent's secret and UUID are
// attached to every RPC through an AuthHandler.
//
// NOTE(review): the TLS config sets InsecureSkipVerify, so the server
// certificate is not verified — confirm this is intended for production use.
// NOTE(review): no blocking dial option is passed, so the 10s timeout
// presumably does not bound connection establishment — verify against the
// gRPC version in use.
func (a *Instance) createConnection() (*grpc.ClientConn, error) {
	transport := insecure.NewCredentials()
	if a.UseSSL {
		transport = credentials.NewTLS(&tls.Config{
			MinVersion:         tls.VersionTLS12,
			InsecureSkipVerify: true,
		})
	}

	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(transport),
		grpc.WithPerRPCCredentials(&AuthHandler{
			ClientSecret: a.ClientSecret,
			ClientUUID:   a.ClientUUID,
		}),
	}

	dialCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	return grpc.DialContext(dialCtx, a.ServerAddress, dialOpts...)
}

// connect dials the server, stores the new connection on the instance and
// performs the initial handshake: system info is mandatory, GeoIP is
// best-effort (a failure is only logged).
//
// On success the connection status becomes ConnectionStatusConnected and nil
// is returned. On failure the status becomes ConnectionStatusError and the
// wrapped cause is returned. A connection whose handshake fails is closed
// before returning so that repeated failed attempts do not accumulate open
// client connections.
//
// NOTE(review): the connection that was previously stored in a.conn is not
// closed here; a reporting goroutine may still hold a stream on it — confirm
// ownership before adding a Close for it.
func (a *Instance) connect() error {
	a.updateConnectionStatus(ConnectionStatusConnecting, nil)
	log.Printf("Agent %s 正在连接服务器...", a.ID)

	conn, err := a.createConnection()
	if err != nil {
		a.updateConnectionStatus(ConnectionStatusError, err)
		return fmt.Errorf("创建连接失败: %w", err)
	}

	a.mutex.Lock()
	a.conn = conn
	a.mutex.Unlock()

	client := proto.NewNezhaServiceClient(conn)

	if err := a.reportSystemInfo(client); err != nil {
		// Release the connection we just created; nothing else uses it yet.
		// The Close error is ignored deliberately: the handshake error below
		// is the one the caller needs to see.
		_ = conn.Close()
		a.updateConnectionStatus(ConnectionStatusError, err)
		return fmt.Errorf("上报系统信息失败: %w", err)
	}

	// GeoIP is optional: log the failure and carry on.
	if err := a.reportGeoIP(client); err != nil {
		log.Printf("Agent %s 上报地理位置信息失败: %v", a.ID, err)
	}

	a.updateConnectionStatus(ConnectionStatusConnected, nil)
	log.Printf("Agent %s 连接成功", a.ID)

	return nil
}

// reconnect repeatedly tries to re-establish the connection until it
// succeeds, ctx is cancelled, the configured retry limit is reached
// (MaxRetries > 0), or connect returns an error that isTemporaryError
// classifies as permanent.
//
// Only one reconnect loop runs at a time: the isReconnecting flag is claimed
// with a compare-and-swap and released on return, so concurrent callers
// return immediately. After a successful reconnect a new reportSystemState
// goroutine is started on the fresh connection.
func (a *Instance) reconnect(ctx context.Context) {
	if !atomic.CompareAndSwapInt32(&a.isReconnecting, 0, 1) {
		return
	}
	defer atomic.StoreInt32(&a.isReconnecting, 0)

	for {
		if ctx.Err() != nil {
			return
		}

		// RetryCount is written under a.mutex by updateConnectionStatus, so
		// read it under the same lock.
		a.mutex.Lock()
		retries := a.ConnectionStatus.RetryCount
		a.mutex.Unlock()

		if maxRetries := a.reconnectConfig.MaxRetries; maxRetries > 0 && retries >= maxRetries {
			log.Printf("Agent %s 达到最大重试次数，停止重连", a.ID)
			a.updateConnectionStatus(ConnectionStatusError, fmt.Errorf("达到最大重试次数"))
			return
		}

		// Switching to the reconnecting status increments RetryCount, so the
		// value read afterwards is already the 1-based number of this attempt.
		a.updateConnectionStatus(ConnectionStatusReconnecting, nil)

		a.mutex.Lock()
		attempt := a.ConnectionStatus.RetryCount
		a.mutex.Unlock()

		delay := a.calculateRetryDelay(attempt)
		log.Printf("Agent %s 将在 %v 后进行第 %d 次重连...", a.ID, delay, attempt)

		// Wait out the backoff, but give up promptly on cancellation.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}

		if err := a.connect(); err != nil {
			log.Printf("Agent %s 重连失败: %v", a.ID, err)
			if !isTemporaryError(err) {
				log.Printf("Agent %s 遇到永久性错误，停止重连: %v", a.ID, err)
				return
			}
			continue
		}

		// Snapshot the connection under the lock; Stop may have cleared it.
		a.mutex.Lock()
		conn := a.conn
		a.mutex.Unlock()
		if conn == nil {
			return
		}

		go a.reportSystemState(ctx, proto.NewNezhaServiceClient(conn))
		return
	}
}

// healthCheck polls the gRPC connection state every
// reconnectConfig.HealthCheckFreq until ctx is cancelled or a value arrives
// on a.healthCheckStop.
//
// When the connection is in TransientFailure or Shutdown and the agent still
// believes it is connected, the status is switched to error and a reconnect
// is started. When the connection is Ready, lastHealthCheck is refreshed.
//
// Shared fields are snapshotted under a.mutex: Stop clears a.conn under that
// lock, so checking it for nil and dereferencing it without the lock could
// dereference a nil pointer.
func (a *Instance) healthCheck(ctx context.Context) {
	ticker := time.NewTicker(a.reconnectConfig.HealthCheckFreq)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-a.healthCheckStop:
			return
		case <-ticker.C:
			a.mutex.Lock()
			conn := a.conn
			connected := a.ConnectionStatus.Status == ConnectionStatusConnected
			a.mutex.Unlock()

			if conn == nil {
				continue
			}

			switch state := conn.GetState(); state {
			case connectivity.TransientFailure, connectivity.Shutdown:
				log.Printf("Agent %s 检测到连接异常，状态: %s", a.ID, state)
				// Only react while the agent thinks it is connected; in any
				// other status no reconnect is triggered from here.
				if connected {
					a.updateConnectionStatus(ConnectionStatusError, fmt.Errorf("连接状态异常: %s", state))
					go a.reconnect(ctx)
				}
			case connectivity.Ready:
				a.mutex.Lock()
				a.lastHealthCheck = time.Now()
				a.mutex.Unlock()
			}
		}
	}
}

// Start launches the agent asynchronously. It creates a cancellable context
// (its cancel function is stored for Stop) and spawns a goroutine that
// performs the initial connect.
//
// If the initial connect succeeds, the health checker and the state reporter
// are started. If it fails with an error that isTemporaryError accepts, the
// reconnect loop is started together with the health checker, so that a
// connection established later by reconnect is monitored too. A permanent
// error leaves the agent idle.
func (a *Instance) Start() {
	a.mutex.Lock()
	defer a.mutex.Unlock()

	log.Printf("Agent %s 开始启动...", a.ID)

	ctx, cancel := context.WithCancel(context.Background())
	a.cancelFunc = cancel

	// The goroutine takes a.mutex inside connect, so it only makes progress
	// once Start has returned and released the lock.
	go func() {
		if err := a.connect(); err != nil {
			log.Printf("Agent %s 初始连接失败: %v", a.ID, err)
			if isTemporaryError(err) {
				go a.healthCheck(ctx)
				go a.reconnect(ctx)
			}
			return
		}

		// Snapshot the connection under the lock; Stop may have cleared it
		// between connect returning and this point.
		a.mutex.Lock()
		conn := a.conn
		a.mutex.Unlock()
		if conn == nil {
			return
		}

		go a.healthCheck(ctx)
		go a.reportSystemState(ctx, proto.NewNezhaServiceClient(conn))
	}()
}

// Stop shuts the agent down: it marks the instance as stopped, cancels the
// context created by Start, signals the health checker (without blocking if
// nobody is listening) and closes the current connection.
//
// All instance fields are detached under a.mutex — including cancelFunc,
// which Start writes under the same lock — and the potentially slow cancel
// and Close calls run after the lock is released.
func (a *Instance) Stop() {
	a.mutex.Lock()
	cancel := a.cancelFunc
	a.cancelFunc = nil
	conn := a.conn
	a.conn = nil
	a.ConnectionStatus.Status = ConnectionStatusStopped
	a.Status = "stopped"
	a.mutex.Unlock()

	if cancel != nil {
		cancel()
	}

	// Non-blocking send: skipped when no health checker is receiving.
	select {
	case a.healthCheckStop <- struct{}{}:
	default:
	}

	if conn != nil {
		if err := conn.Close(); err != nil {
			log.Printf("Agent %s 关闭连接失败: %v", a.ID, err)
		}
	}

	log.Printf("Agent %s 已停止", a.ID)
}

// reportSystemInfo sends the static host description taken from a.config to
// the server via ReportSystemInfo2, bounded by a 5 second timeout, and
// returns the RPC error (nil on success). BootTime is filled with the current
// Unix time at the moment of the call.
func (a *Instance) reportSystemInfo(client proto.NezhaServiceClient) error {
	rpcCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cfg := a.config
	host := &proto.Host{
		Platform:        cfg.System.Platform,
		PlatformVersion: cfg.System.PlatformVersion,
		Cpu:             cfg.GetCPUInfo(),
		MemTotal:        cfg.Memory.Total,
		DiskTotal:       cfg.Disk.Total,
		SwapTotal:       cfg.Swap.Total,
		Arch:            cfg.System.Arch,
		Virtualization:  cfg.System.Virtualization,
		BootTime:        uint64(time.Now().Unix()),
		Version:         cfg.System.Version,
		Gpu:             cfg.GPU.Models,
	}

	if _, err := client.ReportSystemInfo2(rpcCtx, host); err != nil {
		return err
	}
	return nil
}

// reportGeoIP looks up an address for a.CountryCode through geoip.GetIP and
// reports it to the server as the agent's IPv4 address (Use6 is false),
// bounded by a 5 second timeout. The report is logged whether or not the RPC
// succeeded, and the RPC error is returned.
func (a *Instance) reportGeoIP(client proto.NezhaServiceClient) error {
	rpcCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	addr := geoip.GetIP(a.CountryCode)
	payload := &proto.GeoIP{
		Use6: false,
		Ip:   &proto.IP{Ipv4: addr},
	}

	_, rpcErr := client.ReportGeoIP(rpcCtx, payload)
	log.Printf("Agent %s 地理位置上报: %s -> %s", a.ID, a.CountryCode, addr)
	return rpcErr
}

// reportSystemState opens the ReportSystemState stream and, every
// a.config.Server.Interval milliseconds, sends a snapshot produced by
// a.simulator. The same snapshot's CPU, memory, disk and network totals are
// mirrored onto the instance under a.mutex.
//
// The function returns when ctx is cancelled, when the stream cannot be
// created, or after three consecutive send failures. In the two failure cases
// the connection status is set to error and, if isTemporaryError accepts the
// error, a reconnect is started with the caller's ctx.
//
// The stream runs on a child context that is cancelled on return, so the
// stream's resources are released even when the function exits while ctx is
// still alive (for example after repeated send failures).
func (a *Instance) reportSystemState(ctx context.Context, client proto.NezhaServiceClient) {
	defer func() {
		log.Printf("Agent %s 状态上报已停止", a.ID)
	}()

	streamCtx, cancelStream := context.WithCancel(ctx)
	defer cancelStream()

	state, err := client.ReportSystemState(streamCtx)
	if err != nil {
		log.Printf("Agent %s 创建状态上报流失败: %v", a.ID, err)
		a.updateConnectionStatus(ConnectionStatusError, err)
		if isTemporaryError(err) {
			// Use the caller's ctx: streamCtx is cancelled when we return.
			go a.reconnect(ctx)
		}
		return
	}

	ticker := time.NewTicker(time.Duration(a.config.Server.Interval) * time.Millisecond)
	defer ticker.Stop()

	const maxConsecutiveErrors = 3
	consecutiveErrors := 0

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			cpuUsage := a.simulator.GetCPUUsage()
			memUsage := a.simulator.GetMemoryUsage()
			diskUsage := a.simulator.GetDiskUsage()
			networkInTotal := a.simulator.GetNetworkInTotal()
			networkOutTotal := a.simulator.GetNetworkOutTotal()

			stateData := &proto.State{
				Cpu:            cpuUsage,
				MemUsed:        memUsage,
				SwapUsed:       a.simulator.GetSwapUsage(),
				DiskUsed:       diskUsage,
				NetInTransfer:  networkInTotal,
				NetOutTransfer: networkOutTotal,
				NetInSpeed:     a.simulator.GetNetworkInSpeed(),
				NetOutSpeed:    a.simulator.GetNetworkOutSpeed(),
				Uptime:         a.simulator.GetUptime(),
				Load1:          a.simulator.GetLoad1(),
				Load5:          a.simulator.GetLoad5(),
				Load15:         a.simulator.GetLoad15(),
				TcpConnCount:   uint64(a.simulator.GetTcpConnCount()),
				UdpConnCount:   uint64(a.simulator.GetUdpConnCount()),
				ProcessCount:   uint64(a.simulator.GetProcessCount()),
				Gpu:            a.simulator.GetGPUUsage(),
			}

			// Mirror the values being reported onto the instance.
			a.mutex.Lock()
			a.CPUUsage = cpuUsage
			a.MemUsage = memUsage
			a.DiskUsage = diskUsage
			a.NetworkIn = networkInTotal
			a.NetworkOut = networkOutTotal
			a.UpdatedAt = time.Now()
			a.mutex.Unlock()

			if err := state.SendMsg(stateData); err != nil {
				consecutiveErrors++
				log.Printf("Agent %s 发送状态数据失败 (%d/%d): %v", a.ID, consecutiveErrors, maxConsecutiveErrors, err)

				if consecutiveErrors < maxConsecutiveErrors {
					continue
				}

				log.Printf("Agent %s 连续发送失败次数过多，触发重连", a.ID)
				a.updateConnectionStatus(ConnectionStatusError, err)
				if isTemporaryError(err) {
					go a.reconnect(ctx)
				}
				return
			}

			consecutiveErrors = 0
		}
	}
}

// calculateRetryDelay returns the backoff to wait before the given reconnect
// attempt.
//
// Attempt 0 (or below) yields BaseDelay unchanged. For attempt n >= 1 the
// delay is BaseDelay * Multiplier^(n-1), capped at MaxDelay. When Jitter is
// enabled, a random offset of up to ±10% of the capped delay is added.
//
// The cap is applied to the floating-point product before it is converted to
// a time.Duration. For large attempts the product exceeds the int64 range,
// and converting such a value is implementation-defined in Go; it can come
// out negative, which would bypass the cap and remove the wait entirely.
func (a *Instance) calculateRetryDelay(attempt int) time.Duration {
	if attempt <= 0 {
		return a.reconnectConfig.BaseDelay
	}

	maxDelay := a.reconnectConfig.MaxDelay
	raw := float64(a.reconnectConfig.BaseDelay) * math.Pow(a.reconnectConfig.Multiplier, float64(attempt-1))

	// Default to the cap; this also covers +Inf and NaN, for which the
	// comparison below is false.
	delay := maxDelay
	if raw < float64(maxDelay) {
		delay = time.Duration(raw)
	}

	if a.reconnectConfig.Jitter {
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		// 2*Float64()-1 lies in [-1, 1), giving an offset within ±10%.
		delay += time.Duration(float64(delay) * 0.1 * (2*rng.Float64() - 1))
	}

	return delay
}

// updateConnectionStatus records a new connection status together with the
// error that caused it (LastError is cleared when err is nil), all under
// a.mutex. Depending on the status it also updates derived fields:
//
//   - connected: stamps LastConnected, resets RetryCount, Status = "running"
//   - reconnecting: increments RetryCount, schedules NextRetryAt from the
//     backoff for the new count, Status = "reconnecting"
//   - error: Status = "error"
//
// Any other status leaves a.Status untouched.
func (a *Instance) updateConnectionStatus(status string, err error) {
	a.mutex.Lock()
	defer a.mutex.Unlock()

	a.ConnectionStatus.Status = status

	lastErr := ""
	if err != nil {
		lastErr = err.Error()
	}
	a.ConnectionStatus.LastError = lastErr

	switch status {
	case ConnectionStatusConnected:
		a.ConnectionStatus.LastConnected = time.Now()
		a.ConnectionStatus.RetryCount = 0
		a.Status = "running"
	case ConnectionStatusReconnecting:
		a.ConnectionStatus.RetryCount++
		backoff := a.calculateRetryDelay(a.ConnectionStatus.RetryCount)
		a.ConnectionStatus.NextRetryAt = time.Now().Add(backoff)
		a.Status = "reconnecting"
	case ConnectionStatusError:
		a.Status = "error"
	}
}

// isTemporaryError reports whether err is worth retrying.
//
// A nil error is not temporary. Among errors that status.FromError recognises
// as gRPC statuses, only Unauthenticated, PermissionDenied and
// InvalidArgument are treated as permanent; every other code, and every
// error that is not recognised as a gRPC status, counts as temporary.
func isTemporaryError(err error) bool {
	if err == nil {
		return false
	}

	st, isGRPC := status.FromError(err)
	if !isGRPC {
		return true
	}

	switch st.Code() {
	case codes.Unauthenticated, codes.PermissionDenied, codes.InvalidArgument:
		// Retrying cannot fix bad credentials or a malformed request.
		return false
	}
	return true
}
