package parser

import (
	"strings"

	"github.com/a-h/parse"
)

// StripType adapts p to a parse.Parser[any]. The returned parser runs p as-is
// and passes back p's match flag and error untouched; the parsed value is
// returned too, but typed as any rather than T.
func StripType[T any](p parse.Parser[T]) parse.Parser[any] {
	return parse.Func(func(input *parse.Input) (any, bool, error) {
		value, ok, err := p.Parse(input)
		return value, ok, err
	})
}

// ExpressionOf turns a string parser into an Expression parser. It records the
// input position before and after running p and, when p matches, builds the
// result with NewExpression from the matched text and those two positions.
// If p fails or does not match, the zero Expression is returned together with
// p's own match flag and error.
func ExpressionOf(p parse.Parser[string]) parse.Parser[Expression] {
	return parse.Func(func(in *parse.Input) (Expression, bool, error) {
		start := in.Position()

		text, ok, err := p.Parse(in)
		if err != nil || !ok {
			var none Expression
			return none, ok, err
		}

		end := in.Position()
		return NewExpression(text, start, end), true, nil
	})
}

// Angle brackets, whitespace and the single opening brace.
var (
	// lt is the parser for a '<' rune.
	lt = parse.Rune('<')
	// gt is the parser for a '>' rune.
	gt = parse.Rune('>')

	// spaceOrTab accepts either a space or a tab.
	spaceOrTab = parse.Any(parse.Rune(' '), parse.Rune('\t'))
	// spaceOrTabOrNewLine accepts what spaceOrTab accepts, or a '\n'.
	spaceOrTabOrNewLine = parse.Any(spaceOrTab, parse.Rune('\n'))

	// openBrace is the parser for the string "{".
	openBrace = parse.String("{")

	// optionalSpaces is an optional run of one or more spaceOrTab matches,
	// returned as a string.
	optionalSpaces = parse.StringFrom(
		parse.Optional(parse.AtLeast(1, spaceOrTab)),
	)
	// optionalSpacesOrNewLine is an optional run of one or more
	// spaceOrTabOrNewLine matches, returned as a string.
	optionalSpacesOrNewLine = parse.StringFrom(
		parse.Optional(parse.AtLeast(1, spaceOrTabOrNewLine)),
	)

	// openBraceWithPadding is optionalSpaces, then "{", then optionalSpaces,
	// joined into one string.
	openBraceWithPadding = parse.StringFrom(optionalSpaces, openBrace, optionalSpaces)
	// openBraceWithOptionalPadding tries openBraceWithPadding first and the
	// bare openBrace second.
	openBraceWithOptionalPadding = parse.Any(openBraceWithPadding, openBrace)
)

// Closing braces, double braces and round brackets.
var (
	// closeBrace is the parser for the string "}".
	closeBrace = parse.String("}")
	// closeBraceWithOptionalPadding is optionalSpaces followed by "}".
	closeBraceWithOptionalPadding = parse.StringFrom(optionalSpaces, closeBrace)

	// dblOpenBrace is the parser for the string "{{".
	dblOpenBrace = parse.String("{{")
	// dblOpenBraceWithOptionalPaddingOrNewLine is "{{" followed by
	// optionalSpacesOrNewLine.
	dblOpenBraceWithOptionalPaddingOrNewLine = parse.StringFrom(
		dblOpenBrace,
		optionalSpacesOrNewLine,
	)

	// dblCloseBrace is the parser for the string "}}".
	dblCloseBrace = parse.String("}}")
	// dblCloseBraceWithOptionalPadding is optionalSpaces followed by "}}".
	dblCloseBraceWithOptionalPadding = parse.StringFrom(optionalSpaces, dblCloseBrace)

	// openBracket is the parser for the string "(".
	openBracket = parse.String("(")
	// closeBracket is the parser for the string ")".
	closeBracket = parse.String(")")
)

// Line-oriented helpers and comment parsers.
var (
	// stringUntilNewLine collects input until parse.NewLine is found.
	stringUntilNewLine = parse.StringUntil(parse.NewLine)
	// newLineOrEOF combines parse.NewLine and end of input with parse.Or.
	newLineOrEOF = parse.Or(parse.NewLine, parse.EOF[string]())
	// stringUntilNewLineOrEOF collects input until newLineOrEOF is found.
	stringUntilNewLineOrEOF = parse.StringUntil(newLineOrEOF)

	// jsOrGoSingleLineComment is "//" followed by the text leading up to a
	// new line or the end of the input.
	jsOrGoSingleLineComment = parse.StringFrom(
		parse.String("//"),
		parse.StringUntil(parse.Any(parse.NewLine, parse.EOF[string]())),
	)
	// jsOrGoMultiLineComment is "/*" followed by the text leading up to "*/".
	jsOrGoMultiLineComment = parse.StringFrom(
		parse.String("/*"),
		parse.StringUntil(parse.String("*/")),
	)
)

// expressionParser reads expression text while keeping a running count of
// unmatched braces.
type expressionParser struct {
	// startBraceCount is the brace count Parse starts from.
	startBraceCount int
}

// exp is the expressionParser that starts with a brace count of 1.
var exp = expressionParser{startBraceCount: 1}

// Parse reads expression text from pi while tracking brace depth, starting
// from p.startBraceCount.
//
// On every pass it tries, in order: a single-line comment, a multi-line
// comment, a string literal and a rune literal. Whatever matches is appended
// to the output as-is, so braces inside those tokens are never counted. Next
// it tries an opening brace (depth + 1) and a closing brace with optional
// leading padding (depth - 1). Anything else is copied through one Take(1)
// at a time.
//
// When a closing brace brings the depth to zero, the input is rewound to the
// position just before that closer (and its padding) and the loop ends; the
// closer is not part of the result. The loop also ends when the input is
// exhausted.
//
// It returns the collected text as an Expression spanning the start position
// to the final input position. An error is returned when a sub-parser fails,
// when the depth drops below zero, or when the loop ends with a non-zero depth.
func (p expressionParser) Parse(pi *parse.Input) (s Expression, matched bool, err error) {
	from := pi.Position()
	braceCount := p.startBraceCount

	// Tokens copied through verbatim; the order matters and is tried afresh
	// on every pass of the main loop.
	verbatim := [...]parse.Parser[string]{
		jsOrGoSingleLineComment,
		jsOrGoMultiLineComment,
		string_lit,
		rune_lit,
	}

	var sb strings.Builder
loop:
	for {
		var result string

		// Comments and literals first, so braces inside them are skipped.
		for _, tok := range verbatim {
			if result, matched, err = tok.Parse(pi); err != nil {
				return
			}
			if matched {
				sb.WriteString(result)
				continue loop
			}
		}

		// Try opener.
		if result, matched, err = openBrace.Parse(pi); err != nil {
			return
		}
		if matched {
			braceCount++
			sb.WriteString(result)
			continue
		}

		// Try closer. Remember where it starts so the terminating closer can
		// be handed back to the caller unread.
		startOfCloseBrace := pi.Index()
		if result, matched, err = closeBraceWithOptionalPadding.Parse(pi); err != nil {
			return
		}
		if matched {
			braceCount--
			if braceCount < 0 {
				err = parse.Error("expression: too many closing braces", pi.Position())
				return
			}
			if braceCount == 0 {
				pi.Seek(startOfCloseBrace)
				break loop
			}
			sb.WriteString(result)
			continue
		}

		// Read anything else.
		//
		// NOTE(review): an earlier guard compared rune(c[0]) with 65533
		// (U+FFFD) to stop on invalid Unicode. c[0] is a single byte, so that
		// comparison could never be true and the guard has been removed as
		// dead code. If malformed input should end the expression, it needs
		// an explicit UTF-8 decode instead.
		var c string
		if c, matched = pi.Take(1); !matched {
			break loop
		}
		sb.WriteString(c)
	}
	if braceCount != 0 {
		err = parse.Error("expression: unexpected brace count", pi.Position())
		return
	}

	return NewExpression(sb.String(), from, pi.Position()), true, nil
}

// Letters and digits

var (
	// octal_digit accepts one of the runes 0-7.
	octal_digit = parse.RuneIn("01234567")
	// hex_digit accepts one of the runes 0-9, A-F or a-f.
	hex_digit = parse.RuneIn("0123456789ABCDEFabcdef")
)

// https://go.dev/ref/spec#Rune_literals

var (
	// unicode_value_rune tries, in order, a \u escape, a \U escape, a
	// single-character escape, and finally any rune other than a single quote.
	unicode_value_rune = parse.Any(
		little_u_value,
		big_u_value,
		escaped_char,
		parse.RuneNotIn("'"),
	)

	// rune_lit is a single-quoted literal: an opening quote, a run of
	// unicode_value_rune or byte_value items read until a single quote, and
	// the closing quote, all joined into one string.
	rune_lit = parse.StringFrom(
		parse.Rune('\''),
		parse.StringFrom(
			parse.Until(
				parse.Any(unicode_value_rune, byte_value),
				parse.Rune('\''),
			),
		),
		parse.Rune('\''),
	)
)

// Escape sequences, following the grammar of the Go specification.
var (
	// byte_value       = octal_byte_value | hex_byte_value .
	byte_value = parse.Any(octal_byte_value, hex_byte_value)

	// octal_byte_value = `\` octal_digit octal_digit octal_digit .
	octal_byte_value = parse.StringFrom(
		parse.String(`\`),
		octal_digit,
		octal_digit,
		octal_digit,
	)

	// hex_byte_value   = `\` "x" hex_digit hex_digit .
	hex_byte_value = parse.StringFrom(
		parse.String(`\x`),
		hex_digit,
		hex_digit,
	)

	// little_u_value   = `\` "u" hex_digit hex_digit hex_digit hex_digit .
	little_u_value = parse.StringFrom(
		parse.String(`\u`),
		hex_digit,
		hex_digit,
		hex_digit,
		hex_digit,
	)

	// big_u_value      = `\` "U" hex_digit hex_digit hex_digit hex_digit
	//                            hex_digit hex_digit hex_digit hex_digit .
	big_u_value = parse.StringFrom(
		parse.String(`\U`),
		hex_digit,
		hex_digit,
		hex_digit,
		hex_digit,
		hex_digit,
		hex_digit,
		hex_digit,
		hex_digit,
	)

	// escaped_char     = `\` ( "a" | "b" | "f" | "n" | "r" | "t" | "v" | `\` | "'" | `"` ) .
	escaped_char = parse.StringFrom(
		parse.Rune('\\'),
		parse.Any(
			parse.Rune('a'), parse.Rune('b'), parse.Rune('f'),
			parse.Rune('n'), parse.Rune('r'), parse.Rune('t'),
			parse.Rune('v'),
			parse.Rune('\\'), parse.Rune('\''), parse.Rune('"'),
		),
	)
)

// https://go.dev/ref/spec#String_literals

var (
	// string_lit tries, in order: the empty double-quoted string, the empty
	// single-quoted string, an interpreted string literal and a raw string
	// literal.
	string_lit = parse.Any(
		parse.String(`""`),
		parse.String(`''`),
		interpreted_string_lit,
		raw_string_lit,
	)

	// unicode_value_interpreted tries, in order, a \u escape, a \U escape, a
	// single-character escape, and finally any rune other than a new line or
	// a double quote.
	unicode_value_interpreted = parse.Any(
		little_u_value,
		big_u_value,
		escaped_char,
		parse.RuneNotIn("\n\""),
	)

	// interpreted_string_lit is a double-quoted literal: an opening quote, a
	// run of unicode_value_interpreted or byte_value items read until a double
	// quote, and the closing quote, all joined into one string.
	interpreted_string_lit = parse.StringFrom(
		parse.Rune('"'),
		parse.StringFrom(
			parse.Until(
				parse.Any(unicode_value_interpreted, byte_value),
				parse.Rune('"'),
			),
		),
		parse.Rune('"'),
	)

	// unicode_value_raw tries, in order, a \u escape, a \U escape, a
	// single-character escape, and finally any rune other than a backtick.
	unicode_value_raw = parse.Any(
		little_u_value,
		big_u_value,
		escaped_char,
		parse.RuneNotIn("`"),
	)

	// raw_string_lit is a backtick-quoted literal: an opening backtick, a run
	// of unicode_value_raw items read until a backtick, and the closing
	// backtick, all joined into one string.
	raw_string_lit = parse.StringFrom(
		parse.Rune('`'),
		parse.StringFrom(
			parse.Until(unicode_value_raw, parse.Rune('`')),
		),
		parse.Rune('`'),
	)
)
