/*
 * Module: rl_fw.cc
 *
 * **** License ****
 * Version: VPL 1.0
 *
 * The contents of this file are subject to the Vyatta Public License
 * Version 1.0 ("License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 * http://www.vyatta.com/vpl
 *
 * Software distributed under the License is distributed on an "AS IS"
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
 * the License for the specific language governing rights and limitations
 * under the License.
 *
 * This code was originally developed by Vyatta, Inc.
 * Portions created by Vyatta are Copyright (C) 2005, 2006, 2007 Vyatta, Inc.
 * All Rights Reserved.
 *
 * Author: Michael Larson
 * Date: 2005
 * Description:
 *
 * **** End License ****
 *
 */

#include "../rl_firewall/rl_firewall_module.h"
#include "libxorp/xlog.h"
#include "rl_command.hh"
#include "rl_fw.hh"


/**
 * Builds an empty rule: every state-match flag is off, the rate and port
 * numbers are zero, and neither the action rule nor the log rule is marked
 * as installed.  Members not listed here are left to their own default
 * constructors.
 **/
FWData::FWData() :
  _established_state(false),
  _new_state(false),
  _related_state(false),
  _invalid_state(false),
  _rate_limit(0),
  _rate_limit_burst(0),
  _source_port_number(0),
  _source_port_start(0),
  _source_port_stop(0),
  _dest_port_number(0),
  _dest_port_start(0),
  _dest_port_stop(0),
  _active_rule(false),
  _active_log_rule(false)
{
  // _changed was previously left uninitialised, so a freshly constructed
  // FWData (e.g. the one FWChain::get() returns) carried an indeterminate
  // flag.  It is assigned here rather than in the initialiser list so the
  // member declaration order, which is not visible from this file, cannot
  // produce a reorder warning.
  _changed = false;
}

/**
 * The destructor body is intentionally empty; it performs no explicit
 * cleanup.
 **/
FWData::~FWData()
{
}

/**
 *
 **/
XrlCmdError
FWData::get_iptables_cmds(const string &name, uint32_t rule_ct, vector<string> &coll)
{
  bool ip_range_flag = false;
  bool port_range_flag = false;
  string match_rule;

  if (_action.empty() == true) {
    return XrlCmdError::COMMAND_FAILED("Action must be specified");
  }

  if (_protocol.empty() == false) {
    match_rule += "--proto " + _protocol + " ";
  }
  if (_icmp_type.empty() == false) {
    if (strcasecmp(_protocol.c_str(), "icmp") != 0) {
      return XrlCmdError::COMMAND_FAILED("Protocol must be type ICMP");
    }
    match_rule += "--icmp-type " + _icmp_type;
    if (_icmp_code.empty() == false) {
      match_rule += "/" + _icmp_code;
    }
    match_rule += " ";
  }

  string tmp;
  if (_established_state == true) {
    tmp += "established,";
  }
  if (_new_state == true) {
    tmp += "new,";
  }
  if (_related_state == true) {
    tmp += "related,";
  }
  if (_invalid_state == true) {
    tmp += "invalid,";
  }
  if (tmp.empty() == false) {
    //truncate the last ','
    match_rule += "--match state --state " + tmp.substr(0, tmp.length()-1) + " ";
  }

  int ct = 0;
  if (_source_address.is_zero() == false) {
    match_rule += "--source " + _source_address.str() + " ";
    ++ct;
  }
  if (_source_network.is_valid() == true) {
    match_rule += "--source " + _source_network.str() + " ";
    ++ct;
  }
  if (_source_address_start.is_zero() == false) {
    match_rule += "-m iprange --src-range " + _source_address_start.str() + "-" + _source_address_stop.str() + " ";
    ip_range_flag = true;
    ++ct;
  }
  if (ct > 1) {
    match_rule = "";
    return XrlCmdError::COMMAND_FAILED("Only source address, network or start/stop range may be specified.");
  }

  //now time for source port
  ct = 0;
  char buf[80];
  if (_source_port_number > 0) {
    sprintf(buf, "%d", _source_port_number);
    match_rule += "--source-port " + string(buf) + " ";
    ++ct;
  }
  if (_source_port_name.empty() == false) {
    match_rule += "--source-port " + _source_port_name + " ";
    ++ct;
  }
  if (_source_port_start > 0) {
    if (_source_port_start > _source_port_stop) {
      return XrlCmdError::COMMAND_FAILED("Stop port is less than start port");
    }
    port_range_flag = true;
    //    match_rule += "-m multiport --source-port ";
    match_rule += "--source-port ";
    char buf[20];
    sprintf(buf, "%d", _source_port_start);
    match_rule += string(buf) + ":";
    sprintf(buf, "%d", _source_port_stop);
    match_rule += string(buf) + " ";
    ++ct;
  }
  if (ct > 1) {
    return XrlCmdError::COMMAND_FAILED("Only source port number, name or start/stop range may be specified.");
  }
  if (ct > 0 && (strcasecmp(_protocol.c_str(), "tcp") != 0 && strcasecmp(_protocol.c_str(), "udp") != 0)) {
    return XrlCmdError::COMMAND_FAILED("Protocol must be of type tcp or udp when specifying port");
  }

  ct = 0;
  if (_dest_address.is_zero() == false) {
    match_rule += "--destination " + _dest_address.str() + " ";
    ++ct;
  }
  if (_dest_network.is_valid() == true) {
    match_rule += "--destination " + _dest_network.str() + " ";
    ++ct;
  }
  if (_dest_address_start.is_zero() == false) {
    if (ip_range_flag == false) {
      match_rule += "-m iprange --dst-range " + _dest_address_start.str() + "-" + _dest_address_stop.str() + " ";
    }
    else {
      //don't specify to load the plugin for iprange then it was already specified...
      match_rule += "--dst-range " + _dest_address_start.str() + "-" + _dest_address_stop.str() + " ";
    }
    ++ct;
  }
  if (ct > 1) {
    match_rule = "";
    return XrlCmdError::COMMAND_FAILED("Only destination address,network or start/stop range may be specified.");
  }

  //now time for destination port
  ct = 0;
  if (_dest_port_number > 0) {
    sprintf(buf, "%d", _dest_port_number);
    match_rule += "--destination-port " + string(buf) + " ";
    ++ct;
  }
  if (_dest_port_name.empty() == false) {
    match_rule += "--destination-port " + _dest_port_name + " ";
    ++ct;
  }
  if (_dest_port_start > 0) {
    if (_dest_port_start > _dest_port_stop) {
      return XrlCmdError::COMMAND_FAILED("Stop port is less than start port");
    }
    //    match_rule += "-m multiport --destination-port ";
    match_rule += "--destination-port ";
    char buf[20];
    sprintf(buf, "%d", _dest_port_start);
    match_rule += string(buf) + ":";
    sprintf(buf, "%d", _dest_port_stop);
    match_rule += string(buf) + " ";
    ++ct;
  }

  if (ct > 1) {
    return XrlCmdError::COMMAND_FAILED("Only destination port number, name or start/stop range may be specified.");
  }

  if (ct > 0 && (strcasecmp(_protocol.c_str(), "tcp") != 0 && strcasecmp(_protocol.c_str(), "udp") != 0)) {
    return XrlCmdError::COMMAND_FAILED("Protocol must be of type tcp or udp when specifying port");
  }
  
  //now handle the log enable update to the command
  /*
    The easiest solution I can think of is to do away with the LOGxxx 
    sub-chains.  The LOGxxx subchains don't allow us to log individual rule 
    names and they don't support jumping to the main chains.  So any rule 
    that has "log enable" will have the following line in addition to any 
    action, if one is defined.
    
    iptables --append USERCHAIN  "match rules" --jump LOG --log-prefix 
    "userchain rule-number action "
    
    Note that the "--log-prefix" option must come after the "--jump LOG" 
    target on the command line.
    
    If a chain then defines an action, iptables would have the following, as 
    normal.
    
    iptables --append USERCHAIN "match rules" --jump ACTION
    
    where ACTION = (DROP, REJECT, RETURN)
    
    This means that there will be two iptables rules for any rule number 
    that wants an action and logging.  
  */

  //remove previous logging rule
  if (_log != "ENABLE" && _active_log_rule == true) {
    sprintf(buf, "%d", rule_ct);
    string tmp = "iptables --delete " + name + " " + string(buf);
    coll.push_back(tmp);
    _active_log_rule = false;
  }
  
  string rule;
  if (_log == "ENABLE") {
    sprintf(buf, "%d ", rule_ct); //insert after...
    rule = match_rule + "--jump LOG --log-prefix '" + name + " " + string(buf) + " " + _action + " '";
    rule = "iptables --insert " + name + " " + string(buf) + " " + rule;
    coll.push_back(rule);
    ++rule_ct; // now that we've added a rule increment...

    if (_active_log_rule == true) {
      //clean up previous rule
      sprintf(buf, "%d", rule_ct+1); //now the rule after the newly inserted rule
      string tmp = "iptables --delete " + name + " " + string(buf);
      coll.push_back(tmp);
    }
    _active_log_rule = true;
  }

  //now attach the actual rule
  rule = match_rule;
  if (_action == "ACCEPT") {
    rule += "--jump RETURN "; //yet another special case when accept really is return
  }
  else {
    rule += "--jump " + _action + " ";
  }

  sprintf(buf, "%d", rule_ct);
  rule = "iptables --insert " + name + " " + string(buf) + " " + rule;
  coll.push_back(rule);

  if (_active_rule == true) {
    sprintf(buf, "%d", rule_ct+1); //now the rule after the newly inserted rule
    string tmp = "iptables --delete " + name + " " + string(buf);
    coll.push_back(tmp);
  }

  _active_rule = true;

  _changed = false;
  return XrlCmdError::OKAY();
}

/**
 * Reports how many iptables rules this entry accounts for: 2 when _log is
 * "ENABLE" (a LOG rule plus the action rule), otherwise 1.
 **/
uint32_t
FWData::get_chain_ct()
{
  return (_log == "ENABLE") ? 2 : 1;
}



/**
 *
 **/
/*
FirewallRule::FirewallRule() :
  _new(true)
{
}
*/
/**
 *
 **/
/*
void
FirewallRule::dump(ostream &os) const
{
  os << "[protocol] " <<  _protocol << endl;
  os << "[established_state] " << _established_state << endl;
  os << "[new_state] " <<  _new_state << endl;
  os << "[related_state] " <<  _related_state << endl;
  os << "[invalid_state] " << _invalid_state << endl;
  os << "[action] " << _action << endl;
  os << "[rate_limit] " << _rate_limit << endl;
  os << "[rate_interval] " << _rate_interval << endl;
  os << "[rate_limit_burst] " << _rate_limit_burst << endl;
  //and more to go here
}
*/

/*
  Need default rule:
  IPTables -A foo -s 127.0.0.1 -d 127.0.0.1 -j ACCEPT
 */


/**
 * Seeds the chain with one entry, keyed 1025, whose action is DROP and which
 * is flagged as changed.
 **/
FWChain::FWChain()
{
  FWData catch_all;
  catch_all._action = "DROP";
  catch_all._changed = true;

  _fw_chain.insert(pair<uint32_t, FWData>(1025, catch_all));
}

/**
 * Hands back a copy of the rule collection held by this chain.
 **/
FWChain::FirewallColl
FWChain::expose()
{
  FirewallColl snapshot(_fw_chain);
  return snapshot;
}


/**
 * Logs the current number of entries and reports whether the chain holds
 * exactly one entry.  The constructor seeds one entry of its own, so a
 * count of one means no further rules have been added.
 *
 * @return true when the collection size is exactly 1.
 **/
bool
FWChain::is_empty()
{
  const size_t entries = _fw_chain.size();
  // size() does not return an int; convert explicitly so the argument
  // matches the "%d" conversion in the format string.
  XLOG_ERROR("RLFirewallNode::is_empty: chain size is: %d", static_cast<int>(entries));
  return (entries == 1);
}


/**
 * Placeholder accessor: always sets rule_number to 0 and returns a
 * default-constructed FWData, regardless of the chain's contents.
 **/
FWData
FWChain::get(int &rule_number)
{
  FWData blank;
  rule_number = 0;
  return blank;
}

/**
 * Drops rule rule_number from the chain and appends to coll the iptables
 * commands that delete its installed rules from chain `name`.
 *
 * The iptables position of the rule is found by summing get_chain_ct() of
 * every entry that precedes it.  One delete command is queued per iptables
 * rule the entry occupies; they all name the same position, since each
 * delete shifts the following rule up into it.
 *
 * When the entry being removed is the only one besides one other entry
 * (collection size 2; the constructor seeds one entry), a further delete of
 * position 1 is queued for the rule that remains.
 *
 * If rule_number is not present nothing is queued.  Always returns OKAY.
 **/
XrlCmdError
FWChain::remove(const string &name, uint32_t rule_number, vector<string> &coll)
{
  uint32_t position = 1;
  for (FirewallIter iter = _fw_chain.begin(); iter != _fw_chain.end(); ++iter) {
    const uint32_t footprint = iter->second.get_chain_ct();
    if (iter->first != rule_number) {
      position += footprint;
      continue;
    }

    char buf[80];
    snprintf(buf, sizeof(buf), "%u", static_cast<unsigned int>(position));
    const string delete_cmd = "iptables --delete " + name + " " + string(buf);
    for (uint32_t i = 0; i < footprint; ++i) {
      coll.push_back(delete_cmd);
    }

    //end of rule chain, remove
    //re-enable when there is a chance
    // This used to test "_fw_chain.size() - rule_num == 1", which subtracts
    // an iptables rule count (1 or 2) from a map entry count.  For a logged
    // rule that skipped this cleanup when it was due (size 2) and fired it
    // when another rule was still present (size 3).  Compare entry counts
    // only.
    if (_fw_chain.size() == 2) {
      coll.push_back("iptables --delete " + name + " 1");
    }
    break;
  }

  _fw_chain.erase(rule_number);

  return XrlCmdError::OKAY();
}

/**
 * Walks the chain in order and, for every entry flagged as changed, asks it
 * to append its iptables commands to coll.  Each entry is given the running
 * position obtained by summing get_chain_ct() over the entries before it.
 * rule_number is not used.
 *
 * Stops at, and returns, the first error reported by an entry; otherwise
 * returns OKAY.
 **/
XrlCmdError
FWChain::get_iptables_cmds(const string &name, uint32_t rule_number, vector<string> &coll)
{
  UNUSED(rule_number);

  // Running iptables position of the entry currently being visited.
  int position = 1;
  for (FirewallIter entry = _fw_chain.begin(); entry != _fw_chain.end(); ++entry) {
    FWData &rule = entry->second;
    if (rule._changed == true) {
      XrlCmdError status = rule.get_iptables_cmds(name, position, coll);
      if (status != XrlCmdError::OKAY()) {
        return status;
      }
    }
    position += rule.get_chain_ct();
  }
  return XrlCmdError::OKAY();
}

/**
 *
 **/
void 
FWChain::set_protocol(const string protocol, uint32_t rule_number)
{
  //need to grab rule and make the change...
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._protocol = protocol;
  iter->second._changed = true;
}

/**
 *
 **/
void
FWChain::set_icmp_type(const string type, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._icmp_type = type;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_icmp_code(const string code, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._icmp_code = code;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_state_established(bool state, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._established_state = state;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_state_new(bool state, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._new_state = state;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_state_related(bool state, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._related_state = state;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_state_invalid(bool state, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._invalid_state = state;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_action(const string action, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._action = rl_utils::to_upper(action);
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_rule_log(const string log, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._log = rl_utils::to_upper(log);
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_source_address(const IPv4 address, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._source_address = address;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_source_network(const IPv4Net network, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._source_network = rlIPv4Net(network);
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_source_address(const IPv4 start, const IPv4 stop, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._source_address_start = start;
  iter->second._source_address_stop = stop;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_source_port_number(const uint32_t port_number, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._source_port_number = port_number;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_source_port_name(const string port_name, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._source_port_name = port_name;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_source_port(uint32_t start, uint32_t stop, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._source_port_start = start;
  iter->second._source_port_stop = stop;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_dest_address(const IPv4 address, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._dest_address = address;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_dest_network(const IPv4Net network, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._dest_network = network;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_dest_address(const IPv4 start, const IPv4 stop, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._dest_address_start = start;
  iter->second._dest_address_stop = stop;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_dest_port_number(const uint32_t port_number, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._dest_port_number = port_number;
  iter->second._changed = true;
}

/**
 *
 **/
  void
  FWChain::set_dest_port_name(const string port_name, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._dest_port_name = port_name;
  iter->second._changed = true;
}

/**
 *
 **/
void
FWChain::set_dest_port(uint32_t start, uint32_t stop, uint32_t rule_number)
{
  FirewallIter iter = _fw_chain.find(rule_number);
  if (iter == _fw_chain.end()) {
    pair<FirewallIter, bool> val = _fw_chain.insert(pair<uint32_t, FWData> (rule_number, FWData()));
    iter = val.first;
  }
  iter->second._dest_port_start = start;
  iter->second._dest_port_stop = stop;
  iter->second._changed = true;
}

