package main

import (
	"bytes"
	"errors"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/checkpoint-restore/go-criu/v7/crit"
	"github.com/checkpoint-restore/go-criu/v7/crit/cli"
	"github.com/checkpoint-restore/go-criu/v7/crit/images/pstree"
)

// loopTestImgDir is the directory holding the image files of the "loop" test
// process. It is globbed for images by getImgs, read by readMemoryPages, and
// used as the first test case of testGetShmemSize.
const loopTestImgDir = "test-imgs/loop"

// pageSize is the memory page size reported by the operating system. It is
// passed to crit.NewMemoryReader and used to compute the expected shared
// memory size for mappings smaller than one page.
var pageSize = os.Getpagesize()

// main runs the CRIT checks one after another and terminates the process
// through log.Fatal as soon as one of them reports an error. When every check
// succeeds it logs a final PASS line.
func main() {
	// Collect the image files the recode check operates on.
	imgs, err := getImgs()
	if err != nil {
		log.Fatal(err)
	}

	checks := []func() error{
		// Recode (decode followed by encode) every collected image.
		func() error { return recodeImgs(imgs) },
		// Memory pages reading features.
		readMemoryPages,
		// Process shared memory size.
		testGetShmemSize,
	}
	for _, check := range checks {
		if failure := check(); failure != nil {
			log.Fatal(failure)
		}
	}

	log.Println("=== PASS")
}

// recodeImgs runs the decode/encode round trip on every image path in imgs,
// logging each path before it is processed. It returns the first error
// encountered, or nil when every re-encoded image matches its original.
func recodeImgs(imgs []string) error {
	for _, img := range imgs {
		log.Println("===", img)
		if err := recodeImg(img); err != nil {
			return err
		}
	}

	return nil
}

// recodeImg decodes the image file img, encodes the decoded data into
// img + ".test.img" and checks that both files have identical contents.
//
// The work lives in its own function so that the deferred Close calls run
// after each image instead of accumulating open files until the whole list
// has been processed.
func recodeImg(img string) error {
	imgFile, err := os.Open(img)
	if err != nil {
		return err
	}
	defer imgFile.Close()

	testImg := img + ".test.img"
	testImgFile, err := os.Create(testImg)
	if err != nil {
		return err
	}
	defer testImgFile.Close()

	c := crit.New(imgFile, testImgFile, "", false, false)
	entryType, err := cli.GetEntryTypeFromImg(imgFile)
	if err != nil {
		return err
	}
	// Decode the binary image file
	decodedImg, err := c.Decode(entryType)
	if err != nil {
		// %w keeps the message text unchanged while preserving the cause
		// for errors.Is / errors.As.
		return fmt.Errorf("[DECODE]: %w", err)
	}
	// Encode it into test binary image file
	if err = c.Encode(decodedImg); err != nil {
		return fmt.Errorf("[ENCODE]: %w", err)
	}
	// Read back and compare original and test files
	imgBytes, err := os.ReadFile(img)
	if err != nil {
		return err
	}
	testImgBytes, err := os.ReadFile(testImg)
	if err != nil {
		return err
	}
	if !bytes.Equal(imgBytes, testImgBytes) {
		return errors.New("[RECODE]: Files do not match")
	}

	return nil
}

// getImgs returns the paths of the "*.img" files in loopTestImgDir that the
// recode test should process. Files generated by the tests themselves and
// images matching one of the skip prefixes are left out. An error is returned
// only if the glob pattern is rejected by filepath.Glob.
func getImgs() ([]string, error) {
	// Certain image files generated by CRIU do not
	// use the protobuf format and contain raw binary
	// data. Some image files are also generated using
	// external tools (ifaddr, route, tmpfs). As these
	// images cannot be processed by CRIT, they are
	// excluded from the tests.
	skipImgs := []string{
		"pages-",
		"pages-shmem-",
		"iptables-",
		"ip6tables-",
		"nftables-",
		"route-",
		"route6-",
		"ifaddr-",
		"tmpfs-",
		"tmpfs-dev-",
		"autofs-",
		"netns-ct-",
		"netns-exp-",
		"rule-",
	}
	// "*.test.img", "*.json.img" or "tmp.*.img" files
	// must be skipped as they are generated by tests
	criuImg := regexp.MustCompile(`^[^\.]*\.img$`)
	files, err := filepath.Glob(filepath.Join(loopTestImgDir, "*.img"))
	if err != nil {
		return nil, err
	}
	var imgs []string

nextFile:
	for _, file := range files {
		// Match on the file name only: the pattern rejects any dot before
		// ".img", so applying it to the whole path would discard every image
		// as soon as a directory component contained a dot.
		name := filepath.Base(file)
		if !criuImg.MatchString(name) {
			continue
		}
		for _, skip := range skipImgs {
			if strings.HasPrefix(name, skip) {
				continue nextFile
			}
		}
		imgs = append(imgs, file)
	}

	return imgs, nil
}

// readMemoryPages checks that the process arguments and environment variables
// read from the memory pages of the loop test image are byte-for-byte equal to
// the contents of the "cmdline" and "environ" test files in loopTestImgDir.
// It returns the first error encountered, or a mismatch error.
func readMemoryPages() error {
	pid, err := getTestImgPID(loopTestImgDir)
	if err != nil {
		return err
	}

	reader, err := crit.NewMemoryReader(loopTestImgDir, pid, pageSize)
	if err != nil {
		return err
	}

	// Process arguments as found in the memory pages
	gotArgs, err := reader.GetPsArgs()
	if err != nil {
		return err
	}

	// Reference process arguments from the cmdline test file
	cmdlinePath := filepath.Join(loopTestImgDir, "cmdline")
	wantArgs, err := os.ReadFile(cmdlinePath)
	if err != nil {
		return err
	}

	if argsMatch := bytes.Equal(wantArgs, gotArgs.Bytes()); !argsMatch {
		return errors.New("process arguments do not match")
	}

	// Process environment variables as found in the memory pages
	gotEnv, err := reader.GetPsEnvVars()
	if err != nil {
		return err
	}

	// Reference environment variables from the environ test file
	environPath := filepath.Join(loopTestImgDir, "environ")
	wantEnv, err := os.ReadFile(environPath)
	if err != nil {
		return err
	}

	if envMatch := bytes.Equal(wantEnv, gotEnv.Bytes()); !envMatch {
		return errors.New("process environment variables do not match")
	}

	return nil
}

// testGetShmemSize tests the GetShmemSize method of the MemoryReader struct.
func testGetShmemSize() error {
	testCases := []struct {
		imgsDir           string
		expectedShmemSize int64
	}{
		{
			imgsDir:           loopTestImgDir,
			expectedShmemSize: 0,
		},
		{
			imgsDir:           "test-imgs/mm_p_257b",
			expectedShmemSize: 0,
		},
		{
			imgsDir:           "test-imgs/mm_pa_257b",
			expectedShmemSize: 0,
		},
		{
			imgsDir: "test-imgs/mm_s_257b",
			// Here and in the next test case, the expected shared memory size is set
			// to the value of pageSize because the Linux kernel uses page-aligned
			// addresses for memory mapping, so when memory size is less than a page size,
			// it aligns the mapping to a single page.
			expectedShmemSize: int64(pageSize),
		},
		{
			imgsDir:           "test-imgs/mm_sa_257b",
			expectedShmemSize: int64(pageSize),
		},
		{
			imgsDir:           "test-imgs/mm_p_4kb",
			expectedShmemSize: 0,
		},
		{
			imgsDir:           "test-imgs/mm_pa_4kb",
			expectedShmemSize: 0,
		},
		{
			imgsDir:           "test-imgs/mm_s_4kb",
			expectedShmemSize: 4096,
		},
		{
			imgsDir:           "test-imgs/mm_sa_4kb",
			expectedShmemSize: 4096,
		},
		{
			imgsDir:           "test-imgs/mm_p_32kb",
			expectedShmemSize: 0,
		},
		{
			imgsDir:           "test-imgs/mm_pa_32kb",
			expectedShmemSize: 0,
		},
		{
			imgsDir:           "test-imgs/mm_s_32kb",
			expectedShmemSize: 32768,
		},
		{
			imgsDir:           "test-imgs/mm_sa_32kb",
			expectedShmemSize: 32768,
		},
	}

	for _, test := range testCases {
		pid, err := getTestImgPID(test.imgsDir)
		if err != nil {
			return err
		}

		mr, err := crit.NewMemoryReader(test.imgsDir, pid, pageSize)
		if err != nil {
			return err
		}

		shmemSize, err := mr.GetShmemSize()
		if err != nil {
			return err
		}

		if shmemSize != test.expectedShmemSize {
			return fmt.Errorf(
				"[%s]: expected shared memory size: %d bytes, actual size: %d bytes",
				test.imgsDir, test.expectedShmemSize, shmemSize,
			)
		}
	}

	return nil
}

// getTestImgPID returns the PID stored in the first entry of the
// "pstree.img" image located in dir.
//
// An error is returned if the image cannot be opened or decoded, if it
// contains no entries, or if its first entry is not a *pstree.PstreeEntry.
func getTestImgPID(dir string) (uint32, error) {
	psTreePath := filepath.Join(dir, "pstree.img")
	psTreeFile, err := os.Open(psTreePath)
	if err != nil {
		return 0, err
	}
	defer psTreeFile.Close()

	// Hand crit the directory the pstree image was opened from, so the
	// directory argument always agrees with the dir parameter.
	c := crit.New(psTreeFile, nil, dir, false, true)

	psTreeImg, err := c.Decode(&pstree.PstreeEntry{})
	if err != nil {
		return 0, err
	}

	// Guard the index and the type assertion: an empty or unexpected image
	// must surface as an error rather than a panic.
	if len(psTreeImg.Entries) == 0 {
		return 0, fmt.Errorf("%s: no pstree entries found", psTreePath)
	}
	entry, ok := psTreeImg.Entries[0].Message.(*pstree.PstreeEntry)
	if !ok {
		return 0, fmt.Errorf(
			"%s: unexpected entry type %T",
			psTreePath, psTreeImg.Entries[0].Message,
		)
	}

	return entry.GetPid(), nil
}
