#include "ezs_dma.h"

#include <assert.h>

#include <cyg/hal/hal_io.h>
#include <cyg/hal/hal_if.h>

/*
 * Forward declarations for helpers defined later in this file.
 * ezs_dma_set_channel/ezs_dma_set_ports are called from ezs_dma_set_mode_byte;
 * the ezs_dma_is_*_valid predicates return 1/0 and back the assertions that
 * guard every port access below.
 */
void ezs_dma_set_channel(EZS_DMA *dma, uint8_t channel);
void ezs_dma_set_ports(EZS_DMA *dma);
int ezs_dma_is_port_address_valid(EZS_DMA *dma);
int ezs_dma_is_count_port_valid(EZS_DMA *dma);
int ezs_dma_is_page_port_valid(EZS_DMA *dma);
int ezs_dma_is_mask_register_valid(EZS_DMA *dma);
int ezs_dma_is_clear_register_valid(EZS_DMA *dma);
int ezs_dma_is_mode_register_valid(EZS_DMA *dma);

/*
 * Put @dma into a known, unconfigured state.
 *
 * channel is set to 100, which fails every "channel <= 7" assertion in this
 * file. Zero is not in the valid sets for count_port, page_port,
 * mask_register, clear_register or mode_register. Together these make
 * premature use of the object trip an assertion in debug builds.
 * Call ezs_dma_set_mode_byte() to select a channel and its ports.
 */
void ezs_dma_init(EZS_DMA *dma)
{
  dma->channel = 100;   /* deliberately out of range: "not configured yet" */

  dma->mode_byte = 0;

  dma->port_address = 0;
  dma->count_port = 0;
  dma->page_port = 0;

  dma->mask_register = 0;
  dma->clear_register = 0;
  dma->mode_register = 0;

  /* Initialise every field so none is ever read indeterminate;
   * ezs_dma_set_ports() assigns the real value. */
  dma->eight_bit = 0;
}

/*
 * Compose the DMA mode byte for @channel into dma->mode_byte.
 * Also record the channel (ezs_dma_set_channel) and fill in its I/O ports
 * (ezs_dma_set_ports). Nothing is sent to the hardware here; call
 * ezs_dma_write_mode_byte() for that.
 *
 * Bit layout produced:
 *   bits 7-6  transfer mode: 00 demand, 01 single, 10 block, 11 cascade
 *   bit  5    1 = address decrement, 0 = increment
 *   bit  4    1 = auto-initialization, 0 = single cycle
 *   bits 3-2  transfer type: 00 verify, 01 write, 10 read
 *   bits 1-0  channel number within its controller (channel % 4)
 *
 * The byte is built from zero on every call, so bits left over from a
 * previous configuration can never leak into the new one. An out-of-range
 * enum argument trips an assertion instead of being silently ignored.
 *
 * @channel  0-7; channels 4-7 are the 16 bit channels.
 */
void ezs_dma_set_mode_byte(EZS_DMA *dma,
                              uint8_t channel,
                              EZS_DMA_Mode mode,
                              EZS_DMA_Direction direction,
                              EZS_DMA_Initialization auto_initialize,
                              EZS_DMA_Transfer_Type transfer_type)
{
  uint8_t mode_byte = 0;

  assert(channel <= 7 && "DMA channel number must be <= 7");

  switch (mode)
  {
    case EZS_DMA_Demand:                       /* 00 */
      break;
    case EZS_DMA_Single:                       /* 01 */
      mode_byte |= (0x1 << 6);
      break;
    case EZS_DMA_Block:                        /* 10 */
      mode_byte |= (0x1 << 7);
      break;
    case EZS_DMA_Cascade:                      /* 11 */
      mode_byte |= (0x1 << 6) | (0x1 << 7);
      break;
    default:
      assert(0 && "invalid DMA mode");
      break;
  }

  switch (direction)
  {
    case EZS_DMA_Decrement:
      mode_byte |= (0x1 << 5);
      break;
    case EZS_DMA_Increment:
      break;
    default:
      assert(0 && "invalid DMA direction");
      break;
  }

  switch (auto_initialize)
  {
    case EZS_DMA_Single_Cycle:
      break;
    case EZS_DMA_Auto_Initialization:
      mode_byte |= (0x1 << 4);
      break;
    default:
      assert(0 && "invalid DMA initialization mode");
      break;
  }

  switch (transfer_type)
  {
    case EZS_DMA_Verify:                       /* 00 */
      break;
    case EZS_DMA_Write:                        /* 01 */
      mode_byte |= (0x1 << 2);
      break;
    case EZS_DMA_Read:                         /* 10 */
      mode_byte |= (0x1 << 3);
      break;
    default:
      assert(0 && "invalid DMA transfer type");
      break;
  }

  ezs_dma_set_channel(dma, channel);

  /* Bits 1-0 select the channel within its controller (0-3). */
  mode_byte |= (uint8_t) (dma->channel % 4);
  dma->mode_byte = mode_byte;

  ezs_dma_set_ports(dma);
}
/*
 * Send the previously composed dma->mode_byte to the controller's mode
 * register (dma->mode_register). Asserts that the register was configured.
 */
void ezs_dma_write_mode_byte(EZS_DMA *dma)
{
  assert(ezs_dma_is_mode_register_valid(dma));

  const uint8_t mode_value = dma->mode_byte;
  HAL_WRITE_UINT8(dma->mode_register, mode_value);
}

/*
 * Record @channel (0-7) as the DMA channel of @dma.
 * Pure bookkeeping: no port fields are touched and no I/O is performed.
 */
void ezs_dma_set_channel(EZS_DMA *dma, uint8_t channel)
{
  assert(channel <= 7 && "DMA channel number must be <= 7");

  dma->channel = channel;   /* ports are derived later by ezs_dma_set_ports */
}

/*
 * Return 1 if dma->port_address is one of the address-register ports that
 * ezs_dma_set_ports assigns, 0 otherwise.
 */
int ezs_dma_is_port_address_valid(EZS_DMA *dma)
{
  switch (dma->port_address)
  {
    /* 8 bit channels 0-3 */
    case 0x00: case 0x02: case 0x04: case 0x06:
    /* 16 bit channels 4-7 */
    case 0xC0: case 0xC4: case 0xC8: case 0xCC:
      return 1;
    default:
      return 0;
  }
}

/*
 * Return 1 if dma->count_port is one of the count-register ports that
 * ezs_dma_set_ports assigns, 0 otherwise.
 */
int ezs_dma_is_count_port_valid(EZS_DMA *dma)
{
  switch (dma->count_port)
  {
    /* 8 bit channels 0-3 */
    case 0x01: case 0x03: case 0x05: case 0x07:
    /* 16 bit channels 4-7 */
    case 0xC2: case 0xC6: case 0xCA: case 0xCE:
      return 1;
    default:
      return 0;
  }
}

/*
 * Return 1 if dma->page_port is one of the page-register ports that
 * ezs_dma_set_ports assigns, 0 otherwise. Unlike the address and count
 * ports, the page ports are not in channel order.
 */
int ezs_dma_is_page_port_valid(EZS_DMA *dma)
{
  switch (dma->page_port)
  {
    /* 8 bit channels: 2, 3, 1, 0 */
    case 0x81: case 0x82: case 0x83: case 0x87:
    /* 16 bit channels: 6, 7, 5, 4 */
    case 0x89: case 0x8A: case 0x8B: case 0x8F:
      return 1;
    default:
      return 0;
  }
}

/*
 * Return 1 if dma->mask_register is the mask port of either controller
 * (0x0A for 8 bit channels, 0xD4 for 16 bit channels), 0 otherwise.
 */
int ezs_dma_is_mask_register_valid(EZS_DMA *dma)
{
  return dma->mask_register == 0xA || dma->mask_register == 0xD4;
}

/*
 * Return 1 if dma->clear_register is the clear port of either controller
 * (0x0C for 8 bit channels, 0xD8 for 16 bit channels), 0 otherwise.
 */
int ezs_dma_is_clear_register_valid(EZS_DMA *dma)
{
  return dma->clear_register == 0xC || dma->clear_register == 0xD8;
}

/*
 * Return 1 if dma->mode_register is the mode port of either controller
 * (0x0B for 8 bit channels, 0xD6 for 16 bit channels), 0 otherwise.
 */
int ezs_dma_is_mode_register_valid(EZS_DMA *dma)
{
  return dma->mode_register == 0xB || dma->mode_register == 0xD6;
}

/*
 * Fill in the I/O ports of @dma for the channel number in dma->channel (0-7).
 * Sets dma->eight_bit to 1 for channels 0-3 and 0 for channels 4-7.
 *
 * NOTE: for channels 4-7 this rewrites dma->channel to its index within the
 * 16 bit controller (channel - 4). Anything reading dma->channel afterwards
 * sees that local index. Calling this function a second time on the same
 * object would therefore select the 8 bit controller's ports.
 */
void ezs_dma_set_ports(EZS_DMA *dma)
{
  /* Per-channel ports, indexed by the global channel number 0-7. */
  static const uint8_t address_ports[8] = { 0x00, 0x02, 0x04, 0x06,   /* 8 bit  */
                                            0xC0, 0xC4, 0xC8, 0xCC }; /* 16 bit */
  static const uint8_t count_ports[8]   = { 0x01, 0x03, 0x05, 0x07,
                                            0xC2, 0xC6, 0xCA, 0xCE };
  static const uint8_t page_ports[8]    = { 0x87, 0x83, 0x81, 0x82,
                                            0x8F, 0x8B, 0x89, 0x8A };

  assert(dma->channel <= 7 && "DMA channel number must be <= 7");

  /* The unsigned compare also rejects negative values, so the lookups stay
   * in bounds when assertions are compiled out. An out-of-range channel
   * leaves the three port fields untouched. */
  if (dma->channel < 8u)
  {
    dma->port_address = address_ports[dma->channel];
    dma->count_port   = count_ports[dma->channel];
    dma->page_port    = page_ports[dma->channel];
  }

  if (dma->channel >= 4)
  {
    /* 16 bit channels: shared controller registers, local channel index. */
    dma->eight_bit      = 0;
    dma->mask_register  = 0xD4;
    dma->clear_register = 0xD8;
    dma->mode_register  = 0xD6;
    dma->channel       -= 4;
  }
  else
  {
    /* 8 bit channels */
    dma->eight_bit      = 1;
    dma->mask_register  = 0xA;
    dma->clear_register = 0xC;
    dma->mode_register  = 0xB;
  }

  assert(ezs_dma_is_port_address_valid(dma));
  assert(ezs_dma_is_count_port_valid(dma));
  assert(ezs_dma_is_page_port_valid(dma));
  assert(ezs_dma_is_mask_register_valid(dma));
  assert(ezs_dma_is_clear_register_valid(dma));
  assert(ezs_dma_is_mode_register_valid(dma));
}

/*
 * Unmask (enable) the channel on its controller.
 * Writes the controller-local channel number (0-3) with bit 2 clear to
 * dma->mask_register. ezs_dma_disable_channel writes the same number with
 * bit 2 set.
 */
void ezs_dma_enable_channel(EZS_DMA *dma)
{
  assert(dma->channel <= 7 && "DMA channel number must be <= 7");

  /* Only bits 1-0 may carry the channel. Writing a raw 4-7 would set bit 2
   * and mask the channel instead of enabling it. */
  uint8_t channel = dma->channel % 4;
  HAL_WRITE_UINT8(dma->mask_register, channel);
}

/*
 * Mask (disable) the channel on its controller.
 * Writes the controller-local channel number (0-3) with bit 2 set to
 * dma->mask_register.
 */
void ezs_dma_disable_channel(EZS_DMA *dma)
{
  assert(dma->channel <= 7 && "DMA channel number must be <= 7");

  const uint8_t local_channel = dma->channel % 4;
  const uint8_t mask_command  = (uint8_t) (local_channel | 0x4);

  HAL_WRITE_UINT8(dma->mask_register, mask_command);
}

/*
 * Reset the controller's byte flip-flop by writing 0 to dma->clear_register.
 * Asserts that the register was configured.
 */
void ezs_dma_clear_flip_flop(EZS_DMA *dma)
{
  const uint8_t reset_value = 0;

  assert(ezs_dma_is_clear_register_valid(dma));
  HAL_WRITE_UINT8(dma->clear_register, reset_value);
}

/*
 * Program the transfer buffer of @dma. Three registers are written:
 *   - the 16 bit start offset to dma->port_address (low byte, then high byte);
 *   - (length - 1) to dma->count_port (low byte, then high byte);
 *   - @page to dma->page_port.
 *
 * For 8 bit channels the offset is the low 16 bits of @buffer's address.
 * For 16 bit channels it is the address shifted right by one, also
 * truncated to 16 bits.
 *
 * @page    written unchanged; the caller must compute it to match @buffer.
 * @buffer  start of the transfer buffer. The pointer is converted to a 32 bit
 *          integer. NOTE(review): this assumes the buffer's virtual address
 *          equals its physical address and lies below 4 GiB -- confirm for
 *          the target.
 * @length  written as (length - 1) modulo 2^16, so length == 0 programs
 *          0xFFFF. NOTE(review): for 16 bit channels the count is presumably
 *          meant in words -- confirm against callers.
 *
 * NOTE(review): each port gets two consecutive byte writes. This assumes the
 * controller's byte flip-flop is in its reset state on entry (see
 * ezs_dma_clear_flip_flop).
 */
void ezs_dma_set_buffer_info(EZS_DMA *dma, uint8_t page, void *buffer, uint16_t length)
{
  assert(ezs_dma_is_page_port_valid(dma));
  assert(ezs_dma_is_port_address_valid(dma));
  assert(ezs_dma_is_count_port_valid(dma));

  const uint32_t address = (uint32_t) buffer;

  /* 16 bit channels take a word offset, so bit 0 is dropped. The result is
   * kept unsigned: narrowing through a signed 16 bit type would be
   * implementation-defined for offsets >= 0x8000. */
  const uint16_t offset =
      (uint16_t) ((dma->eight_bit ? address : (address >> 1)) & 0xFFFF);

  /* Explicit 16 bit wrap: length == 0 yields 0xFFFF. */
  const uint16_t count = (uint16_t) (length - 1u);

  HAL_WRITE_UINT8(dma->port_address, ezs_low_byte(offset));
  HAL_WRITE_UINT8(dma->port_address, ezs_high_byte(offset));

  HAL_WRITE_UINT8(dma->count_port, ezs_low_byte(count));
  HAL_WRITE_UINT8(dma->count_port, ezs_high_byte(count));

  HAL_WRITE_UINT8(dma->page_port, page);
}
