import { dot, modulus, norm, phase, rotate } from "./complex"
import { forwardFft } from "./fft"

// Expected pilot values before polarity is applied, index-aligned with the
// default `pilots` config (-21, -7, 7, 21): +1, +1, +1, -1 on the real axis
// (appears to follow the IEEE 802.11a pilot pattern — TODO confirm).
// apply() multiplies all four by the current symbol's PILOT_POLARITY sign.
const PILOT_BASE_SEQUENCE = [1, 1, 1, -1].map(sign => ({ i: sign, q: 0 }))
// Per-symbol pilot polarity signs (127 entries; presumably the IEEE 802.11a
// p_0..p_126 sequence — TODO confirm against the standard).
// apply() multiplies every pilot of OFDM symbol n by
// PILOT_POLARITY[n % PILOT_POLARITY.length].
const PILOT_POLARITY = [
    1, 1, 1, 1,
    -1, -1, -1, 1,
    -1, -1, -1, -1,
    1, 1, -1, 1,
    -1, -1, 1, 1,
    -1, 1, 1, -1,
    1, 1, 1, 1,
    1, 1, -1, 1,
    1, 1, -1, 1,
    1, -1, -1, 1,
    1, 1, -1, 1,
    -1, -1, -1, 1,
    -1, 1, -1, -1,
    1, -1, -1, 1,
    1, 1, 1, 1,
    -1, -1, 1, 1,
    -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, 1, -1, 1, -1, 1,
    -1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1
]

/**
 * Interpolate between two angles, handling the 2π wrap-around.
 *
 * Both angles are turned into unit phasors and blended linearly. The phase of
 * the blend is then returned. Because the blend moves along the chord between
 * the phasors, the result sweeps the shorter arc, so 350° and 10° meet at 0°
 * rather than 180°. The sweep is not uniform in angle.
 *
 * @param {number} a1 - start angle in radians (alpha = 0 gives a1 modulo 2π)
 * @param {number} a2 - end angle in radians (alpha = 1 gives a2 modulo 2π)
 * @param {number} alpha - blend factor, normally in [0, 1]
 * @returns {number} the interpolated angle, in whatever range `phase` returns
 *   (assumes `phase(z)` is the argument of z — TODO confirm in ./complex)
 */
function interpolateAngle(a1, a2, alpha) {
    const z1 = { i: Math.cos(a1), q: Math.sin(a1) }
    const z2 = { i: Math.cos(a2), q: Math.sin(a2) }
    // NOTE: for exactly opposite angles at alpha = 0.5 the blend is the zero
    // vector, and the result is whatever `phase` returns for (0, 0).
    const z = {
        i: alpha * z2.i + (1 - alpha) * z1.i,
        q: alpha * z2.q + (1 - alpha) * z1.q,
    }
    return phase(z)
}

// BPSK decision points: bit "1" at +1 and bit "0" at -1 on the real axis.
// This is the default constellation of decodeConstellationPoint().
const BPSK_CONSTELLATION = [
    { name: "1", i: 1, q: 0 },
    { name: "0", i: -1, q: 0 },
]
/**
 * Hard-decide a received sample by picking the nearest constellation entry.
 *
 * @param {{i: number, q: number}} point - received sample; any extra fields
 *   are copied into the result
 * @param {Array<{name: string, i: number, q: number}>} [constellation]
 *   candidate points (defaults to BPSK_CONSTELLATION)
 * @returns {object} a shallow copy of `point` plus:
 *   - `point`: the nearest entry. On equal distances the earliest entry wins.
 *     It is null when no distance compares below Number.MAX_VALUE, e.g. NaN input.
 *   - `score`: the `modulus` distance to that entry.
 */
function decodeConstellationPoint(point, constellation = BPSK_CONSTELLATION) {
    const nearest = constellation.reduce((best, candidate) => {
        const distance = modulus({
            i: point.i - candidate.i,
            q: point.q - candidate.q
        })
        // Strict comparison: a later candidate must be strictly closer to win.
        return distance < best.score ? { match: candidate, score: distance } : best
    }, { match: null, score: Number.MAX_VALUE })

    return { ...point, point: nearest.match, score: nearest.score }
}

/**
 * Build a phase-correction estimator from per-pilot measurements.
 *
 * @param {number[]} inputs - pilot subcarrier indices
 * @param {number[]} outputs - correction angle (radians) measured at each
 *   pilot, index-aligned with `inputs`
 * @returns {(carrier: number) => number} an estimator that maps a subcarrier
 *   index to a correction angle:
 *   - between two pilots, it blends (via interpolateAngle) the nearest pilot at
 *     or below `carrier` with the nearest pilot strictly above it;
 *   - outside the pilot range, it returns the nearest edge pilot's angle;
 *   - with no pilots, it returns 0.
 */
function fitPhaseShift(inputs, outputs) {
    console.log('fit phase shift', inputs, outputs)
    return carrier => {
        // lowIdx: nearest pilot at or below carrier; highIdx: nearest pilot above.
        let lowIdx = -1
        let highIdx = -1
        for (let idx = 0; idx < inputs.length; idx++) {
            const delta = carrier - inputs[idx]
            if (delta >= 0) {
                if (lowIdx < 0 || Math.abs(delta) < Math.abs(carrier - inputs[lowIdx]))
                    lowIdx = idx
            } else {
                // Compare against the -1 sentinel explicitly: `!highIdx` is false
                // for -1, so no upper neighbour would ever be recorded.
                if (highIdx < 0 || Math.abs(delta) < Math.abs(carrier - inputs[highIdx]))
                    highIdx = idx
            }
        }

        if (lowIdx < 0 && highIdx < 0)
            return 0
        if (lowIdx < 0)
            return outputs[highIdx]
        if (highIdx < 0)
            return outputs[lowIdx]

        // inputs[lowIdx] <= carrier < inputs[highIdx], so the span is non-zero.
        const alpha = (carrier - inputs[lowIdx]) / (inputs[highIdx] - inputs[lowIdx])
        return interpolateAngle(outputs[lowIdx], outputs[highIdx], alpha)
    }
}

/**
 * OFDM demodulation stage ("ofdm2").
 *
 * Steps performed by `apply()`:
 *   1. Slice the sample stream into OFDM symbols of `gi + size` samples.
 *   2. FFT the `size` samples that follow each guard interval.
 *   3. Rotate every pilot and data subcarrier by a correction angle
 *      interpolated between the pilot subcarriers.
 *   4. Divide each subcarrier by its mean magnitude over the symbols
 *      processed in this call.
 *
 * Trailing samples that do not fill a whole symbol are kept in `this.samples`
 * and prepended to the next call's input.
 */
export default {
    name: 'ofdm2',
    /**
     * @param samples - complex samples `{ i, q }` for this chunk
     * @param tr - configuration overrides merged over the defaults (see `schema`)
     * @param datasets - not used by this stage
     * @returns for each complete symbol, its pilot carriers followed by its data
     *   carriers. Each one is the result of `rotate` with the estimated correction
     *   and is annotated with `angle` (radians) and `alpha` (always null).
     */
    apply(samples, tr, datasets) {
        const defaultConfig = {
            header: 0,
            size: 64,
            gi: 16,
            pilots: [-21, -7, 7, 21],

            // Every subcarrier in -26..26 except DC (0) and the four pilots.
            data: [
                -26, -25, -24, -23, -22, // -21 : pilot
                -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, // -7 : pilot
                -6, -5, -4, -3, -2, -1,
                // 0,
                1, 2, 3, 4, 5, 6, // 7 : pilot
                8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, // 21 : pilot
                22, 23, 24, 25, 26
            ]
        }
        const { size, gi, pilots, data } = { ...defaultConfig, ...tr } as any
        const symbolLength = gi + size
        if (!this.samples) {
            this.samples = []
        }

        // Guard-interval correlation. Compares the `gi` samples at `offset` with
        // the `gi` samples starting `size` later, normalised by both vector norms.
        // Presumably close to 1 when `offset` is a symbol start — TODO confirm
        // against the dot/norm definitions in ./complex.
        function symbolDetector(samples, offset) {
            const v1 = samples.slice(offset, offset + gi)
            const v2 = samples.slice(offset + size, offset + gi + size)
            const n1n2 = norm(v1) * norm(v2)
            let { i, q } = dot(v1, v2)
            i /= n1n2
            q /= n1n2
            return Math.sqrt(i * i + q * q)
        }

        // Samples left over from the previous call come first.
        const buffered = [...this.samples, ...samples]
        let result = []
        let o = 0
        // NOTE(review): this restarts at 0 on every call, although partial symbols
        // are carried over in this.samples, so the PILOT_POLARITY position is not
        // continuous across calls. Confirm whether apply() is fed one continuous stream.
        let symbolIndex = 0
        while (o + symbolLength <= buffered.length) {
            // `true ||` bypasses symbol detection: every symbolLength window is
            // treated as a symbol, and the `o++` resync branch never runs.
            if (true || symbolDetector(buffered, o) > 0.8) {
                const symbol = buffered.slice(o, o + symbolLength)
                // Skip the guard interval; FFT only the symbol body.
                const dataSamples = symbol.slice(gi, gi + size)

                const spectrum = forwardFft(dataSamples)
                console.log(symbolIndex, 'ofdm symbol at', o, symbolDetector(buffered, o))

                // Negative carrier indices wrap to the upper half of the FFT output.
                const pilotSpectrum = pilots.map(idx => ({ carrier: idx, carrierType: 'pilot', ...spectrum[(idx + size) % size] }))
                // All pilots of one symbol share a single polarity sign.
                const polarity = PILOT_POLARITY[symbolIndex % PILOT_POLARITY.length]
                const expectedPilotSpectrum = PILOT_BASE_SEQUENCE.map(({ i, q }) => ({
                    i: i * polarity,
                    q: q * polarity
                }))
                // Per pilot: phase difference expected - received, wrapped into
                // [0, 2π), plus the magnitude ratio expected / received.
                const pilotCorrectors = pilotSpectrum.map((pilot, index) => {
                    const expected = expectedPilotSpectrum[index]
                    let angle = phase(expected) - phase(pilot)
                    while (angle < 0)
                        angle += 2 * Math.PI
                    const n = modulus(expected) / modulus(pilot)
                    return { angle, norm: n, carrier: pilot.carrier }
                })
                console.log("pilot correctors", pilotCorrectors)
                const estimateAngle = fitPhaseShift(pilotCorrectors.map(s => s.carrier), pilotCorrectors.map(s => s.angle))

                const correctSample = s => {
                    const angle = estimateAngle(s.carrier)
                    const corrected = rotate(s, angle)
                    // The pilot magnitude ratio (`norm` above) is not applied here.
                    corrected.alpha = null
                    corrected.angle = angle
                    return corrected
                }

                const dataSpectrum = data.map(idx => ({ carrier: idx, carrierType: 'data', ...spectrum[(idx + size) % size] }))

                // Append in place; re-spreading `result` per symbol was O(symbols²).
                result.push(
                    ...pilotSpectrum.map(correctSample),
                    ...dataSpectrum.map(correctSample)
                )

                o += symbolLength
                symbolIndex++
            } else {
                o++
            }
        }

        // Per-carrier statistics over this call's symbols.
        const carrierAnalysis = {}
        for (const sample of result) {
            const { carrier, i, q } = sample
            if (!carrierAnalysis[carrier]) {
                carrierAnalysis[carrier] = {
                    samples: [],
                    sum: { i: 0, q: 0 },
                    sumn: 0
                }
            }
            const a = carrierAnalysis[carrier]
            a.samples.push(sample)
            a.sum.i += i
            a.sum.q += q
            a.sumn += Math.sqrt(i * i + q * q)
        }
        for (const carrier in carrierAnalysis) {
            const a = carrierAnalysis[carrier]
            const n = a.samples.length
            a.meann = a.sumn / n
            a.mean = {
                i: a.sum.i / n,
                q: a.sum.q / n
            }
        }

        // Normalise each sample by its carrier's mean magnitude. All-zero carriers
        // are left as-is: 0 / 0 would give NaN, and decodeConstellationPoint
        // would then find no match (point: null) for the phase dump below.
        for (const s of result) {
            const analysis = carrierAnalysis[s.carrier]
            if (analysis && analysis.meann > 0) {
                s.i = s.i / analysis.meann
                s.q = s.q / analysis.meann
            }
        }

        // Debug CSV: residual phase offset of each sample versus its nearest
        // constellation point. If phase() returns values in (-π, π], the offsets
        // land in [π, 3π). Offsets near 0 from either side then cluster around 2π
        // instead of straddling the 0 / 2π seam.
        const lines = [
            'carrier,offset'
        ]
        const p = 2 * Math.PI
        for (const point of result) {
            const decoded = decodeConstellationPoint(point)
            let dangle = phase(point) - phase(decoded.point)
            if (dangle < 0)
                dangle += p
            if (dangle > p)
                dangle -= p
            if (dangle < p / 2)
                dangle += p
            lines.push(`${point.carrier},${dangle}`)
        }
        console.log(lines.join("\n"))
        this.samples = buffered.slice(o)
        return result
    },
    schema: {
        header: {
            type: 'number'
        },
        size: {
            type: 'number'
        },
        gi: {
            type: 'number'
        },
        pilots: { type: 'array', items: 'number' },
        data: { type: 'array', items: 'number' }
    }
}