import * as fs from "fs";
import { Buffer } from "buffer";
import * as process from "process";

/**
 * Format information taken from a WAV `fmt ` subchunk.
 */
type Fmt = {
  /** Format tag from the fmt subchunk; values other than 1 (integer PCM) only produce a warning. */
  audio_format: number;
  /** Number of interleaved channels per sample frame (validated to be > 0). */
  num_channels: number;
  /** Sample frames per second. */
  sample_rate: number;
  /** Bytes per single-channel sample, i.e. bits_per_sample / 8. */
  bytes_per_sample: number;
};

/**
 * Walks the RIFF/WAVE structure of a file and reports where its sample data lives.
 *
 * The file may hold several consecutive RIFF chunks. For each one, subchunks are scanned
 * until the first `data` subchunk, at which point `callback` is awaited with the most
 * recently seen `fmt ` information of that RIFF chunk and the byte range of the samples.
 * All other subchunks (LIST, fact, ...) are skipped.
 *
 * @param filename - path of the WAV file
 * @param callback - awaited once per RIFF chunk that contains a `data` subchunk; `start` is
 *   the absolute file offset of the first sample byte and `length` the data size in bytes
 * @param opts - `treat_maxsize_chunks_as_infinite`: when true, a `data` subchunk whose
 *   declared size is 0xFFFFFFFF is taken to extend to the end of the file
 * @throws Error when a header is malformed or truncated, when `data` precedes `fmt `, or
 *   when the fmt fields are unsupported or inconsistent
 */
async function iter_wav_chunks(
  filename: string,
  callback: (fmt: Fmt, data_chunk: { start: number; length: number }) => Promise<void>,
  opts?: { treat_maxsize_chunks_as_infinite?: boolean },
): Promise<void> {
  const treat_maxsize_chunks_as_infinite = opts?.treat_maxsize_chunks_as_infinite ?? false;
  const RIFF_ID = 0x52494646; // "RIFF"
  const WAVE_ID = 0x57415645; // "WAVE"
  const FMT_SUBCHUNK_ID = 0x666d7420; // "fmt "
  const DATA_SUBCHUNK_ID = 0x64617461; // "data"
  const MIN_FMT_SIZE = 16; // bytes consumed below: 2 + 2 + 4 + 4 + 2 + 2
  const MAX_CHUNK_SIZE = 0xffffffff;

  // A single descriptor serves all header reads instead of re-opening the file per field.
  const fd = fs.openSync(filename, "r");
  try {
    const file_length = fs.fstatSync(fd).size;
    let offset = 0;

    // Reads an unsigned integer of `num_bytes` at `offset` and advances `offset` past it.
    const get_uint = (num_bytes: number, endianness: "big" | "little"): number => {
      const o = offset;
      const buf = Buffer.alloc(num_bytes);
      if (fs.readSync(fd, buf, 0, num_bytes, o) !== num_bytes)
        throw new Error(`Unable to read ${num_bytes} from '${filename}' at offset '${o}'`);
      offset += num_bytes;
      return endianness === "big" ? buf.readUIntBE(0, num_bytes) : buf.readUIntLE(0, num_bytes);
    };
    // Like get_uint, but throws `msg(v)` unless `condition(v)` holds.
    const get_uint_cond = (
      num_bytes: number,
      endianness: "big" | "little",
      condition: (v: number) => boolean,
      msg: (v: number) => string,
    ): number => {
      const v = get_uint(num_bytes, endianness);
      if (!condition(v)) throw new Error(msg(v));
      return v;
    };

    for (let riff_start = 0; riff_start < file_length; ) {
      offset = riff_start;
      get_uint_cond(4, "big", (v) => v === RIFF_ID, (v) => `Got unexpected chunk_id ${v}`);
      const chunk_size = get_uint(4, "little");
      // The size counts everything after the size field, so the next RIFF chunk (if any) starts here.
      riff_start = offset + chunk_size;
      get_uint_cond(4, "big", (v) => v === WAVE_ID, (v) => `Got unexpected chunk format ${v}`);

      let fmt: Fmt | null = null;
      let have_data = false;
      // Bounded by file_length rather than riff_start, which tolerates a wrong RIFF size field.
      while (offset < file_length && !have_data) {
        const subchunk_id = get_uint(4, "big");
        const subchunk_size = get_uint(4, "little");
        // RIFF subchunks are word aligned: an odd-sized body is followed by one pad byte
        // that subchunk_size does not count. Skipping it keeps the next header aligned.
        const end_offset = offset + subchunk_size + (subchunk_size % 2);
        switch (subchunk_id) {
          case FMT_SUBCHUNK_ID: {
            if (subchunk_size < MIN_FMT_SIZE)
              throw new Error(
                `Got fmt subchunk of size ${subchunk_size} (expected at least ${MIN_FMT_SIZE})`,
              );
            const audio_format = get_uint(2, "little");
            if (audio_format !== 1)
              process.stderr.write(`Warning: got unexpected audio format ${audio_format}\n`);
            const num_channels = get_uint_cond(
              2,
              "little",
              (v) => v > 0,
              (v) => `Got unsupported num_channels ${v}`,
            );
            const sample_rate = get_uint(4, "little");
            const byte_rate = get_uint(4, "little");
            const block_align = get_uint(2, "little");
            const bits_per_sample = get_uint_cond(
              2,
              "little",
              (v) => v > 0 && v % 8 === 0,
              (v) => `Got unsupported bits_per_sample ${v} (should be a positive multiple of 8)`,
            );
            const bytes_per_sample = bits_per_sample / 8;
            if (block_align !== num_channels * bytes_per_sample)
              throw new Error(`Got unexpected block_align value (${block_align})`);
            if (byte_rate !== sample_rate * num_channels * bytes_per_sample)
              throw new Error(`Got unexpected byte_rate ${byte_rate}`);
            fmt = { audio_format, bytes_per_sample, num_channels, sample_rate };
            break;
          }
          case DATA_SUBCHUNK_ID: {
            if (!fmt) throw new Error(`Encountered data subchunk without preceding fmt subchunk`);
            const length =
              subchunk_size === MAX_CHUNK_SIZE && treat_maxsize_chunks_as_infinite
                ? file_length - offset
                : subchunk_size;
            await callback(fmt, { start: offset, length });
            have_data = true;
            break;
          }
          default:
            // Unknown subchunk (LIST, fact, ...): skipped via end_offset below.
            break;
        }
        offset = end_offset;
      }
    }
  } finally {
    fs.closeSync(fd);
  }
}

/**
 * Decodes the integer PCM samples of a WAV file and passes them to `callback` one sample
 * frame (one value per channel) at a time.
 *
 * @param filename - path of the WAV file
 * @param format - "native" passes samples as signed integers. 8-bit data, which WAV stores
 *   unsigned with 128 as the midpoint, is re-centred to [-128, 127]. "normalized" scales
 *   samples to floats in [-1, 1]: positive values by 1 / (2^(bits-1) - 1), negative values
 *   by 1 / 2^(bits-1). Sample widths above 6 bytes are rejected by `Buffer.readIntLE`.
 * @param callback - invoked synchronously for every frame of every data chunk with:
 *   - `c`: the current frame, one entry per channel;
 *   - `prev`: the previous frame of the same data chunk, or `null` for its first frame;
 *   - `n_sample`: 1-based index of the frame within its data chunk;
 *   - `metadata`: the chunk's format.
 *   `c` and `prev` are buffers reused across calls; copy them to keep their contents.
 * @param opts - forwarded to the chunk iterator (`treat_maxsize_chunks_as_infinite`)
 */
export async function process_wav_file(
  filename: string,
  format: "native" | "normalized",
  callback: (
    c: Array<number>,
    prev: null | Array<number>,
    n_sample: number,
    metadata: {
      num_channels: number;
      sample_rate: number;
      bytes_per_sample: number;
    },
  ) => void,
  opts?: { treat_maxsize_chunks_as_infinite?: boolean },
): Promise<void> {
  await iter_wav_chunks(
    filename,
    async (fmt, data_chunk) => {
      const fd = fs.openSync(filename, "r");
      try {
        const { bytes_per_sample, num_channels } = fmt;
        const chunksize_samples = 1 << 15;
        const bytes_per_instant = bytes_per_sample * num_channels;
        const chunksize_bytes = chunksize_samples * bytes_per_instant;
        const buf = Buffer.alloc(chunksize_bytes);
        const end = data_chunk.start + data_chunk.length;
        let read_offset = data_chunk.start;

        // Reads the next slice of the data chunk (at most chunksize_bytes) into buf.
        const read_chunk = (): Promise<number> => {
          process.stderr.write(`Reading at offset ${read_offset}...\n`);
          return new Promise((resolve, reject) =>
            fs.read(
              fd,
              buf,
              0,
              Math.max(1, Math.min(chunksize_bytes, end - read_offset)),
              read_offset,
              (err, bytes_read) => {
                if (err) reject(err);
                else resolve(bytes_read);
              },
            ),
          );
        };

        // WAV stores 8-bit PCM unsigned (midpoint 128); wider samples are signed little-endian.
        const read_sample =
          bytes_per_sample === 1
            ? (pos: number) => buf.readUInt8(pos) - 128
            : (pos: number) => buf.readIntLE(pos, bytes_per_sample);
        // `**` instead of `<<`: JS shifts are 32-bit, so 1 << 31 would be negative.
        const full_scale = 2 ** (bytes_per_sample * 8 - 1);
        const pos_factor = 1 / (full_scale - 1);
        const neg_factor = 1 / full_scale;
        const normalize = format === "normalized";

        // Two frame buffers whose roles swap after every callback.
        let prev: number[] = new Array<number>(num_channels).fill(0);
        let c: number[] = new Array<number>(num_channels).fill(0);
        let i_total = 0;
        while (read_offset < end) {
          const bytes_read = await read_chunk();
          if (bytes_read === 0) break; // EOF before the declared end of the data chunk
          if (bytes_read % bytes_per_instant !== 0) {
            process.stderr.write(
              `Warning: number of bytes read (${bytes_read}) is not divisible by ${bytes_per_instant}, maybe this file was truncated?\n`,
            );
          }
          read_offset += bytes_read;
          const n_frames = Math.floor(bytes_read / bytes_per_instant);
          let pos = 0;
          for (let i = 0; i < n_frames; ++i) {
            for (let ch = 0; ch < num_channels; ++ch) {
              const v = read_sample(pos);
              c[ch] = normalize ? v * (v >= 0 ? pos_factor : neg_factor) : v;
              pos += bytes_per_sample;
            }
            // Decide "first frame" before incrementing, so prev is null exactly once per chunk.
            const is_first = i_total === 0;
            i_total += 1;
            callback(c, is_first ? null : prev, i_total, fmt);
            [prev, c] = [c, prev];
          }
        }
      } finally {
        fs.closeSync(fd);
      }
    },
    opts ?? {},
  );
}
