// Copyright (c) 2017-2022 Tigera, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testutils

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"os/exec"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/onsi/ginkgo"

	//nolint:staticcheck // Ignore ST1001: should not use dot imports
	. "github.com/onsi/gomega"
	log "github.com/sirupsen/logrus"

	"github.com/projectcalico/calico/felix/iptables/cmdshim"
	"github.com/projectcalico/calico/libcalico-go/lib/set"
)

// This file contains shared test infrastructure for testing the iptables package.

// NewMockDataplane builds a MockDataplane that simulates the given iptables table.
//
// The chains map is used directly (not copied), so simulated restores mutate the caller's
// map. dataplaneMode "nft" turns on NftablesMode; any other value leaves it off. The version
// strings are fixed canned values.
func NewMockDataplane(table string, chains map[string][]string, dataplaneMode string) *MockDataplane {
	const fakeKernelVersion = "Linux version 4.15.0-34-generic (buildd@lgw01-amd64-037) (gcc version 5.4.0 20160609 " +
		"(Ubuntu 5.4.0-6ubuntu1~16.04.10)) #37~16.04.1-Ubuntu SMP Tue Aug 28 10:44:06 UTC 2018"

	isNft := dataplaneMode == "nft"

	dp := &MockDataplane{
		Prologue:      "# generated by dummy iptables-save\n",
		Table:         table,
		Chains:        chains,
		FlushedChains: set.New[string](),
		ChainMods:     set.New[chainMod](),
		DeletedChains: set.New[string](),
		Version:       "iptables v1.5.9\n",
		KernelVersion: fakeKernelVersion,
		NftablesMode:  isNft,
	}
	return dp
}

// chainMod identifies one rule position in one chain that a simulated iptables-restore
// modified. Values are stored as keys in MockDataplane.ChainMods.
type chainMod struct {
	// name is the chain that was modified.
	name string
	// ruleNum is the 1-indexed position of the affected rule.
	ruleNum int
}

// MockDataplane is an in-memory stand-in for the iptables binaries. It hands out fake
// save/restore/version commands through NewCmd, keeps the simulated chain contents, records
// what each simulated restore changed, and exposes switches for injecting failures.
type MockDataplane struct {
	// Prologue is written at the start of every simulated iptables-save output.
	Prologue string
	// Table is the table name expected in command arguments and "*table" lines.
	Table string
	// Chains maps chain name to its rules, each stored without the "-A <chain>" prefix.
	Chains map[string][]string

	// FlushedChains holds chains created/flushed by ":chain - -" lines during restore.
	FlushedChains set.Set[string]
	// ChainMods holds the (chain, rule position) pairs modified during restore.
	ChainMods set.Set[chainMod]
	// DeletedChains holds chains removed by -X/--delete-chain during restore.
	DeletedChains set.Set[string]

	// Cmds and CmdNames record, in order, the commands handed out by NewCmd.
	Cmds     []cmdshim.CmdIface
	CmdNames []string

	// FailNextRestore makes the next restore Run fail (one-shot); FailAllRestores makes
	// every restore Run fail.
	FailNextRestore bool
	FailAllRestores bool
	// OnPreSave / OnPreRestore, if set, are called once (then cleared) from saveCmd.Start
	// and restoreCmd.Run respectively.
	OnPreSave    func()
	OnPreRestore func()
	// FailNextSaveRead makes the next save Output fail, or the next StdoutPipe's Read fail.
	FailNextSaveRead bool
	// FailNextSaveStdoutPipe makes the next StdoutPipe call return an error.
	FailNextSaveStdoutPipe bool
	// FailNextKill makes the next saveCmd.Kill return an error.
	FailNextKill bool
	// FailAllSaves makes every save Output fail.
	FailAllSaves bool
	// FailNextPipeClose makes the next pipe created by StdoutPipe return an error on Close.
	FailNextPipeClose bool
	// FailNextStart makes the next save/version Start fail.
	FailNextStart bool
	// FailNextGetKernelVersionReader makes the next GetKernelVersionReader call fail.
	FailNextGetKernelVersionReader bool

	// PipeBuffers records every pipe handed out by StdoutPipe.
	PipeBuffers []*closableBuffer
	// CumulativeSleep is the total duration passed to Sleep.
	CumulativeSleep time.Duration
	// Time is the fake clock: advanced by Sleep and AdvanceTimeBy, returned by Now.
	Time time.Time

	// FailNextVersion makes the next version Output fail (one-shot).
	FailNextVersion bool
	// Version is returned by the version command's Output.
	Version string
	// KernelVersion is the content served by GetKernelVersionReader.
	KernelVersion string
	// NftablesMode enables the nft-specific assertions on commands and restore input.
	NftablesMode bool
}

// ResetCmds forgets every command recorded so far by NewCmd.
func (d *MockDataplane) ResetCmds() {
	d.Cmds, d.CmdNames = nil, nil
}

// NewCmd is the command factory used in place of exec.Command. It records the command name,
// asserts that the arguments are the ones the iptables code is expected to pass, and
// returns a fake save, restore or version command bound to this dataplane. An unknown
// command name fails the current Ginkgo test.
func (d *MockDataplane) NewCmd(name string, arg ...string) cmdshim.CmdIface {
	fields := log.Fields{
		"name":                   name,
		"args":                   arg,
		"FailNextRestore":        d.FailNextRestore,
		"FailNextSaveRead":       d.FailNextSaveRead,
		"FailNextStart":          d.FailNextStart,
		"FailNextKill":           d.FailNextKill,
		"FailNextSaveStdoutPipe": d.FailNextSaveStdoutPipe,
		"FailNextPipeClose":      d.FailNextPipeClose,
		"FailAllRestores":        d.FailAllRestores,
		"FailAllSaves":           d.FailAllSaves,
	}
	log.WithFields(fields).Info("Simulating new command.")

	// The name is recorded before validation, so even rejected commands show up.
	d.CmdNames = append(d.CmdNames, name)

	// In nft mode every binary other than the plain "iptables" version probe must be the
	// -nft variant.
	if d.NftablesMode && name != "iptables" {
		Expect(name).To(ContainSubstring("-nft"))
	}

	var fake cmdshim.CmdIface
	switch name {
	case "iptables-restore", "ip6tables-restore",
		"iptables-legacy-restore", "ip6tables-legacy-restore",
		"iptables-nft-restore", "ip6tables-nft-restore":
		Expect(arg).To(Equal([]string{"--noflush", "--verbose"}))
		fake = &restoreCmd{Dataplane: d}
	case "iptables-save", "ip6tables-save",
		"iptables-legacy-save", "ip6tables-legacy-save",
		"iptables-nft-save", "ip6tables-nft-save":
		Expect(arg).To(Equal([]string{"-t", d.Table}))
		fake = &saveCmd{Dataplane: d}
	case "iptables":
		Expect(arg).To(Equal([]string{"--version"}))
		fake = &versionCmd{Dataplane: d}
	default:
		ginkgo.Fail(fmt.Sprintf("Unexpected command %v", name))
	}

	d.Cmds = append(d.Cmds, fake)
	return fake
}

// GetKernelVersionReader returns a reader over KernelVersion. If
// FailNextGetKernelVersionReader is set, the flag is cleared and a dummy error is returned.
func (d *MockDataplane) GetKernelVersionReader() (io.Reader, error) {
	if !d.FailNextGetKernelVersionReader {
		return bytes.NewBufferString(d.KernelVersion), nil
	}
	d.FailNextGetKernelVersionReader = false
	return nil, errors.New("dummy error")
}

// Sleep does not block. It advances the fake clock by duration and adds duration to
// CumulativeSleep so tests can check how long the code under test waited.
func (d *MockDataplane) Sleep(duration time.Duration) {
	d.Time = d.Time.Add(duration)
	d.CumulativeSleep += duration
}

// Now returns the current value of the fake clock.
func (d *MockDataplane) Now() time.Time {
	now := d.Time
	return now
}

// AdvanceTimeBy moves the fake clock forward by delta without affecting CumulativeSleep.
func (d *MockDataplane) AdvanceTimeBy(delta time.Duration) {
	d.Time = d.Time.Add(delta)
}

// ChainFlushed reports whether chainName is in FlushedChains. The simulated restore adds a
// chain there for every ":chain - -" forward reference.
func (d *MockDataplane) ChainFlushed(chainName string) bool {
	flushed := d.FlushedChains.Contains(chainName)
	return flushed
}

// RuleTouched reports whether the rule at 1-indexed position ruleNum in chainName was
// modified. A flushed chain counts as every rule having been touched.
func (d *MockDataplane) RuleTouched(chainName string, ruleNum int) bool {
	return d.ChainFlushed(chainName) ||
		d.ChainMods.Contains(chainMod{name: chainName, ruleNum: ruleNum})
}

// restoreCmd is the fake iptables-restore command handed out by MockDataplane.NewCmd.
type restoreCmd struct {
	// Dataplane is the mock whose chains Run mutates.
	Dataplane *MockDataplane
	// Stdin is a buffered copy of the input passed to SetStdin; Run consumes it.
	Stdin io.Reader
	// CapturedStdin is the same input as a string, kept for String().
	CapturedStdin string
	// Stdout and Stderr are recorded by SetStdout/SetStderr; Run does not write to them.
	Stdout io.Writer
	Stderr io.Writer
}

// SetStdin reads r to completion straight away. The data is kept as a reader for Run and
// as a string for String(). A read error panics.
func (d *restoreCmd) SetStdin(r io.Reader) {
	captured := new(bytes.Buffer)
	if _, err := io.Copy(captured, r); err != nil {
		panic(err)
	}
	d.CapturedStdin = captured.String()
	d.Stdin = captured
}

// SetStdout records out; the simulated restore never writes to it.
func (d *restoreCmd) SetStdout(out io.Writer) {
	d.Stdout = out
}

// SetStderr records errOut; the simulated restore never writes to it.
func (d *restoreCmd) SetStderr(errOut io.Writer) {
	d.Stderr = errOut
}

// Output is not supported for restore; calling it fails the current test.
func (d *restoreCmd) Output() ([]byte, error) {
	const msg = "Not implemented"
	ginkgo.Fail(msg)
	return nil, errors.New(msg)
}

// StdoutPipe is not supported for restore; calling it fails the current test.
func (d *restoreCmd) StdoutPipe() (io.ReadCloser, error) {
	const msg = "Not implemented"
	ginkgo.Fail(msg)
	return nil, errors.New(msg)
}

// Start is not supported for restore (use Run); calling it fails the current test.
func (d *restoreCmd) Start() error {
	const msg = "Not implemented"
	ginkgo.Fail(msg)
	return errors.New(msg)
}

// Wait is not supported for restore (use Run); calling it fails the current test.
func (d *restoreCmd) Wait() error {
	const msg = "Not implemented"
	ginkgo.Fail(msg)
	return errors.New(msg)
}

// Kill is a no-op for the simulated restore and always succeeds.
func (d *restoreCmd) Kill() error {
	var noErr error
	return noErr
}

// String describes the command together with the quoted input it was given.
func (d *restoreCmd) String() string {
	return "restoreCmd " + strconv.Quote(d.CapturedStdin)
}

// Run simulates iptables-restore. It reads the input captured by SetStdin and applies it to
// the owning MockDataplane's in-memory chains. It records which chains were flushed, which
// rule positions were modified (ChainMods) and which chains were deleted. Malformed or
// unexpected input fails the running Ginkgo test through a gomega assertion.
//
// Before parsing, the one-shot OnPreRestore hook (if set) is called and cleared. If
// FailNextRestore (one-shot) or FailAllRestores is set, Run returns a simulated error
// without applying anything.
//
// Accepted input:
//   - blank lines and "#" comments;
//   - "*<table>" headers. The table must equal MockDataplane.Table. More than one
//     transaction is only allowed in nft mode, and each must be COMMITted first.
//   - ":<chain> - -" forward references;
//   - the -A/--append, -I/--insert, -R/--replace, -D/--delete and -X/--delete-chain actions.
func (d *restoreCmd) Run() error {
	log.Info("Running simulated iptables-restore")
	// Get the input.
	var buf bytes.Buffer
	_, err := buf.ReadFrom(d.Stdin)
	Expect(err).NotTo(HaveOccurred())
	input := buf.String()

	if d.Dataplane.OnPreRestore != nil {
		log.Warn("OnPreRestore set, calling it")
		d.Dataplane.OnPreRestore()
		d.Dataplane.OnPreRestore = nil
	}
	if d.Dataplane.FailNextRestore {
		log.Warn("Simulating an iptables-restore failure")
		d.Dataplane.FailNextRestore = false
		return errors.New("simulated failure")
	}
	if d.Dataplane.FailAllRestores {
		log.Warn("Simulating an iptables-restore failure")
		return errors.New("simulated failure")
	}

	// Process it line by line.
	lines := strings.Split(input, "\n")
	commitSeen := false
	tableSeen := false

	for i, line := range lines {
		log.WithFields(log.Fields{"line": line, "lineNum": i + 1}).Info("Parsing line")
		if strings.Trim(line, " \n") == "" {
			// Ignore empty lines (including final trailing return).
			continue
		}
		if strings.HasPrefix(line, "#") {
			// Ignore comments.
			continue
		}
		if strings.HasPrefix(line, "*") {
			// Start of a table.
			if tableSeen {
				Expect(d.Dataplane.NftablesMode).To(BeTrue(), "Only nft mode should use more than one transaction")
				// We've already had one transaction, check that it was committed
				Expect(commitSeen).To(BeTrue())
				commitSeen = false
			}
			Expect(line[1:]).To(Equal(d.Dataplane.Table))
			tableSeen = true
			continue
		}
		Expect(tableSeen).To(BeTrue(), "No *table stanza before starting input")
		Expect(commitSeen).To(BeFalse(), "Unexpected line after COMMIT")
		if line == "COMMIT" {
			commitSeen = true
			continue
		}

		chains := d.Dataplane.Chains

		if strings.HasPrefix(line, ":") {
			// Chain forward-ref, creates and flushes the chain as needed.
			parts := strings.Split(line[1:], " ")
			chainName := parts[0]
			Expect(parts[1:]).To(Equal([]string{"-", "-"}))
			chains[chainName] = []string{}
			d.Dataplane.FlushedChains.Add(chainName)
			continue
		}

		parts := strings.Split(line, " ")
		action := parts[0]
		var chainName string
		switch action {
		case "-A", "--append":
			chainName = parts[1]
			if strings.HasPrefix(chainName, "cali") && d.Dataplane.NftablesMode {
				Expect(d.Dataplane.FlushedChains.Contains(chainName)).To(BeTrue(),
					"In nft mode, it's not safe to modify chain without flushing")
			}
			rest := strings.Join(parts[2:], " ")
			Expect(chains[chainName]).NotTo(BeNil(), "Append to unknown chain: "+chainName)
			chains[chainName] = append(chains[chainName], rest)
			d.Dataplane.ChainMods.Add(chainMod{name: chainName, ruleNum: len(chains[chainName])})
		case "-I", "--insert":
			chainName = parts[1]
			Expect(chains[chainName]).NotTo(BeNil(), "Insert to unknown chain: "+chainName)
			chains[chainName] = append(chains[chainName], "") // Make room
			chain := chains[chainName]

			// If the first arg after the chain name is a line number, then insert by line number.
			if lineNum, err := strconv.Atoi(parts[2]); err == nil {
				// len(chain) already counts the slot made above, so valid positions are
				// 1..len(chain). The upper bound is the same as appending.
				Expect(lineNum).To(BeNumerically(">=", 1), "Insert at invalid rule position")
				Expect(lineNum).To(BeNumerically("<=", len(chain)), "Insert beyond end of chain")
				ruleIdx := lineNum - 1 // 0-indexed
				copy(chain[ruleIdx+1:], chain[ruleIdx:])
				chain[ruleIdx] = strings.Join(parts[3:], " ")
				chains[chainName] = chain
				d.Dataplane.ChainMods.Add(chainMod{name: chainName, ruleNum: lineNum})
			} else {
				// Otherwise insert at the top.
				copy(chain[1:], chain[:len(chain)-1])
				chain[0] = strings.Join(parts[2:], " ")
				d.Dataplane.ChainMods.Add(chainMod{name: chainName, ruleNum: 1})
			}
		case "-R", "--replace":
			Expect(d.Dataplane.NftablesMode).To(BeFalse(), "Replace shouldn't be used in nft mode")
			chainName = parts[1]
			ruleNum, err := strconv.Atoi(parts[2]) // 1-indexed position of rule.
			Expect(err).NotTo(HaveOccurred())
			Expect(ruleNum).To(BeNumerically(">=", 1), "Replace of invalid rule position")
			rest := strings.Join(parts[3:], " ")
			ruleIdx := ruleNum - 1 // 0-indexed array index of rule.
			chain := chains[chainName]
			Expect(len(chain)).To(BeNumerically(">", ruleIdx), "Replace of nonexistent rule")
			chain[ruleIdx] = rest
			d.Dataplane.ChainMods.Add(chainMod{name: chainName, ruleNum: ruleNum})
		case "-D", "--delete":
			chainName = parts[1]

			// If second arg is numeric, this is a delete by line number.
			if ruleNum, err := strconv.Atoi(parts[2]); err == nil {
				Expect(parts).To(HaveLen(3), "Unexpected argument after rule position in --delete")
				Expect(chainName).To(HavePrefix("cali"), "Deleting rule from non-calico chain by number can cause races")
				Expect(ruleNum).To(BeNumerically(">=", 1), "Delete of invalid rule position")

				ruleIdx := ruleNum - 1 // 0-indexed array index of rule.
				chain := chains[chainName]
				Expect(len(chain)).To(BeNumerically(">", ruleIdx), "Delete of nonexistent rule")

				// Shift the later rules down over the deleted one, in place.
				chains[chainName] = append(chain[:ruleIdx], chain[ruleIdx+1:]...)
				d.Dataplane.ChainMods.Add(chainMod{name: chainName, ruleNum: ruleNum})
			} else {
				// Otherwise, treat this as a delete by full rule (given without the chain
				// name). Every matching rule is removed, and the original 1-indexed
				// position of each removed rule is recorded as modified.
				rule := strings.Join(parts[2:], " ")
				chain := chains[chainName]

				// Non-nil so that a chain emptied here still Equal()s []string{} in -X.
				newChain := []string{}
				var deletedRuleNums []int
				for idx, existing := range chain {
					if existing == rule {
						deletedRuleNums = append(deletedRuleNums, idx+1)
						continue
					}
					newChain = append(newChain, existing)
				}

				Expect(deletedRuleNums).NotTo(BeEmpty(), "Delete of nonexistent rule")
				chains[chainName] = newChain
				for _, n := range deletedRuleNums {
					d.Dataplane.ChainMods.Add(chainMod{name: chainName, ruleNum: n})
				}
			}
		case "-X", "--delete-chain":
			chainName = parts[1]
			Expect(parts).To(HaveLen(2), "--delete-chain only has one argument")
			Expect(chains[chainName]).To(Equal([]string{}), "Only empty chains can be deleted")
			delete(chains, chainName)
			d.Dataplane.DeletedChains.Add(chainName)
		default:
			ginkgo.Fail("Unknown action: " + action)
		}
		log.Debugf("Updated chain '%s' (len=%v); new contents:\n\t%v",
			chainName, len(chains[chainName]), strings.Join(chains[chainName], "\n\t"))
	}
	Expect(commitSeen).To(BeTrue(), "didn't see a COMMIT line")
	return nil
}

// PrependLine returns a new slice holding line followed by every element of src.
//
// The result is always freshly allocated, and src (and any slice sharing its backing array)
// is left untouched. The earlier in-place version appended into src's spare capacity and
// shifted it, which silently overwrote the caller's data whenever cap(src) > len(src).
func PrependLine(src []string, line string) []string {
	out := make([]string, 0, len(src)+1)
	out = append(out, line)
	return append(out, src...)
}

// saveCmd is the fake iptables-save command handed out by MockDataplane.NewCmd.
type saveCmd struct {
	// Dataplane is the mock whose chains are rendered as save output.
	Dataplane *MockDataplane
	// stdoutPipe is the pipe returned by StdoutPipe (at most one); Wait closes it.
	stdoutPipe *closableBuffer
}

// String returns a fixed description of the fake save command.
func (d *saveCmd) String() string {
	const desc = "saveCmd"
	return desc
}

// SetStdin is not supported for save; calling it fails the current test.
func (d *saveCmd) SetStdin(_ io.Reader) {
	ginkgo.Fail("Not implemented")
}

// SetStdout is not supported for save (use StdoutPipe); calling it fails the current test.
func (d *saveCmd) SetStdout(_ io.Writer) {
	ginkgo.Fail("Not implemented")
}

// SetStderr is not supported for save; calling it fails the current test.
func (d *saveCmd) SetStderr(_ io.Writer) {
	ginkgo.Fail("Not implemented")
}

// Start simulates launching iptables-save. If FailNextStart is set, the flag is cleared and
// an error is returned without running OnPreSave. Otherwise the one-shot OnPreSave hook,
// if any, is called and then cleared.
func (d *saveCmd) Start() error {
	if d.Dataplane.FailNextStart {
		d.Dataplane.FailNextStart = false
		return errors.New("dummy start failure")
	}
	if hook := d.Dataplane.OnPreSave; hook != nil {
		log.Warn("OnPreSave set, calling it")
		hook()
		d.Dataplane.OnPreSave = nil
	}
	return nil
}

// Wait closes the pipe handed out by StdoutPipe, if there is one, and returns the close
// error.
func (d *saveCmd) Wait() error {
	if d.stdoutPipe == nil {
		return nil
	}
	return d.stdoutPipe.Close()
}

// Kill succeeds unless FailNextKill is set, in which case the flag is cleared and an error
// is returned.
func (d *saveCmd) Kill() error {
	if !d.Dataplane.FailNextKill {
		return nil
	}
	d.Dataplane.FailNextKill = false
	return errors.New("kill failed")
}

// Output renders the dataplane's chains in iptables-save format. The output is the
// Prologue, a "*<table>" header, one ":<chain> - [123:456]" declaration per chain, the
// "-A <chain> <rule>" lines, then "COMMIT" and a trailing comment.
//
// Chains are emitted in sorted name order so the output (and the debug log of it) is
// deterministic. Go map iteration order is randomised, which made the output differ from
// run to run. Rules within a chain keep their stored order.
//
// If FailNextSaveRead is set, it is cleared and an error is returned. If FailAllSaves is
// set, an error is always returned. StdoutPipe also builds its pipe contents from Output.
func (d *saveCmd) Output() ([]byte, error) {
	if d.Dataplane.FailNextSaveRead {
		d.Dataplane.FailNextSaveRead = false
		return nil, errors.New("simulated failure")
	}
	if d.Dataplane.FailAllSaves {
		return nil, errors.New("simulated failure")
	}

	chainNames := make([]string, 0, len(d.Dataplane.Chains))
	for chainName := range d.Dataplane.Chains {
		chainNames = append(chainNames, chainName)
	}
	sort.Strings(chainNames)

	var buf bytes.Buffer
	buf.WriteString(d.Dataplane.Prologue)
	fmt.Fprintf(&buf, "*%s\n", d.Dataplane.Table)
	for _, chainName := range chainNames {
		fmt.Fprintf(&buf, ":%s - [123:456]\n", chainName)
	}
	for _, chainName := range chainNames {
		for _, rule := range d.Dataplane.Chains[chainName] {
			fmt.Fprintf(&buf, "-A %s %s\n", chainName, rule)
		}
	}
	buf.WriteString("COMMIT\n")
	buf.WriteString("# completed\n")

	log.Debugf("Calculated save output:\n%v", buf.String())

	return buf.Bytes(), nil
}

// StdoutPipe returns a pipe over the simulated save output and records it in PipeBuffers.
//
// Fault injection, in evaluation order:
//   - FailNextSaveRead is consumed first. It makes the returned pipe's Read fail instead of
//     making Output fail.
//   - FailNextSaveStdoutPipe makes this call itself return an error.
//   - Any Output error is returned as-is.
//   - FailNextPipeClose makes the pipe's Close return an error.
//
// Calling StdoutPipe twice on the same command fails the current test.
func (d *saveCmd) StdoutPipe() (io.ReadCloser, error) {
	dp := d.Dataplane

	var readErr, closeErr error
	if dp.FailNextSaveRead {
		dp.FailNextSaveRead = false
		readErr = errors.New("simulated Read() failure")
	}

	if dp.FailNextSaveStdoutPipe {
		dp.FailNextSaveStdoutPipe = false
		return nil, errors.New("simulated StdoutPipe() failure")
	}

	contents, err := d.Output()
	if err != nil {
		return nil, err
	}

	if dp.FailNextPipeClose {
		dp.FailNextPipeClose = false
		closeErr = errors.New("dummy deferred flush error")
	}

	pipe := &closableBuffer{b: bytes.NewBuffer(contents), ReadErr: readErr, CloseErr: closeErr}
	dp.PipeBuffers = append(dp.PipeBuffers, pipe)
	if d.stdoutPipe != nil {
		ginkgo.Fail("StdoutPipe() called more than once")
	}
	d.stdoutPipe = pipe
	return pipe, nil
}

// Run is not supported for save (use StdoutPipe or Output); it always returns an error.
func (d *saveCmd) Run() error {
	err := errors.New("not implemented")
	return err
}

// versionCmd is the fake `iptables --version` command handed out by MockDataplane.NewCmd.
type versionCmd struct {
	// Dataplane supplies the version string and failure switches.
	Dataplane *MockDataplane
}

// String returns a fixed description of the fake version command.
func (d *versionCmd) String() string {
	const desc = "versionCmd"
	return desc
}

// SetStdin is not supported for the version command; calling it fails the current test.
func (d *versionCmd) SetStdin(_ io.Reader) {
	ginkgo.Fail("Not implemented")
}

// SetStdout is not supported for the version command (use Output); calling it fails the
// current test.
func (d *versionCmd) SetStdout(_ io.Writer) {
	ginkgo.Fail("Not implemented")
}

// SetStderr is not supported for the version command; calling it fails the current test.
func (d *versionCmd) SetStderr(_ io.Writer) {
	ginkgo.Fail("Not implemented")
}

// Start succeeds unless FailNextStart is set, in which case the flag is cleared and an
// error is returned.
func (d *versionCmd) Start() error {
	if !d.Dataplane.FailNextStart {
		return nil
	}
	d.Dataplane.FailNextStart = false
	return errors.New("dummy start failure")
}

// Wait is not supported for the version command; calling it fails the current test.
func (d *versionCmd) Wait() error {
	ginkgo.Fail("Not implemented")
	var noErr error
	return noErr
}

// Kill is not supported for the version command; calling it fails the current test.
func (d *versionCmd) Kill() error {
	ginkgo.Fail("Not implemented")
	var noErr error
	return noErr
}

// Output returns the dataplane's Version string. If FailNextVersion is set, the flag is
// cleared and a simulated error is returned instead.
func (d *versionCmd) Output() ([]byte, error) {
	if !d.Dataplane.FailNextVersion {
		return []byte(d.Dataplane.Version), nil
	}
	d.Dataplane.FailNextVersion = false
	return nil, errors.New("simulated failure")
}

// StdoutPipe is not supported for the version command; it always returns an error.
func (d *versionCmd) StdoutPipe() (io.ReadCloser, error) {
	err := errors.New("not implemented")
	return nil, err
}

// Run is not supported for the version command (use Output); it always returns an error.
func (d *versionCmd) Run() error {
	err := errors.New("not implemented")
	return err
}

// closableBuffer is the io.ReadCloser returned by saveCmd.StdoutPipe. Read and Close
// failures can be injected through ReadErr and CloseErr.
type closableBuffer struct {
	// b holds the data served by Read.
	b *bytes.Buffer
	// Closed is set by the first Close; a second Close fails the current test.
	Closed bool
	// CloseErr is returned by Close; ReadErr, when non-nil, is returned by every Read.
	CloseErr, ReadErr error
}

// Read reads from the underlying buffer, or returns ReadErr (with no data) when it is set.
func (b *closableBuffer) Read(p []byte) (n int, err error) {
	if b.ReadErr == nil {
		return b.b.Read(p)
	}
	return 0, b.ReadErr
}

// Close marks the buffer closed and returns CloseErr. Closing twice fails the current test,
// since it shows the caller's cleanup logic ran more than once.
func (b *closableBuffer) Close() error {
	alreadyClosed := b.Closed
	if alreadyClosed {
		ginkgo.Fail("Already closed")
	}
	b.Closed = true
	return b.CloseErr
}

// LookPathNoLegacy is a fake exec.LookPath. It treats every binary whose name contains
// "legacy" as missing, and every other binary as present, returning its name unchanged as
// the "path".
//
// For a missing binary it returns an *exec.Error carrying the looked-up name and
// exec.ErrNotFound, like exec.LookPath does. A zero-value exec.Error must not be used here:
// its Error() method dereferences the nil Err field and panics as soon as the error is
// logged or formatted.
func LookPathNoLegacy(p string) (string, error) {
	if strings.Contains(p, "legacy") {
		return "", &exec.Error{Name: p, Err: exec.ErrNotFound}
	}
	return p, nil
}

// LookPathAll is a fake exec.LookPath that reports every binary as present, returning the
// name unchanged as its "path".
func LookPathAll(name string) (string, error) {
	var noErr error
	return name, noErr
}
