#!/usr/bin/env /usr/bin/python
# Why don't people use shebangs anymore. It's muppetry I say
# -*- coding: utf-8 -*-

"""
  FileResolver - example resolver which responds with fixed response
          to all requests
"""

from __future__ import print_function

import copy
import os
import os.path
import base64
import time

from dnslib import RR, QTYPE, TXT, CLASS, RCODE
from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger
from dnslib.label import DNSLabel

class FileResolver(BaseResolver):
  """
    Serve files from a directory as base64 chunks in DNS TXT records.

    TXT queries in class CH for version.bind / authors.bind get fixed
    answers. Every other TXT query is expected to look like

        <filename>.count.<domain>    -> index of the last chunk
        <filename>.<number>.<domain> -> that chunk of the base64 data

    where <filename> may itself contain dots. Queries of any other type
    receive an empty reply.
  """

  # Number of base64 characters sent per TXT answer
  CHUNK_SIZE=254
  # Cache entries untouched for longer than this many seconds are dropped
  STALE_TIME=3600

  def __init__(self,directory,domain):
    """
      directory - directory the served files live in
      domain    - suffix every file query must end with

      Exits the process (status 1) when a directory is given but does
      not exist.
    """
    self.filelist=[]
    self.domain=domain
    self.ttl=60
    self.directory=directory
    # path -> {"base64": encoded file content, "time": last access time}
    self.cache={}
    if directory and not os.path.isdir(directory):
      print("Directory " + directory + " doesn't exist")
      # The builtin exit() comes from the site module and reported
      # success; raise SystemExit directly with a failure status instead
      raise SystemExit(1)

  def getcache(self,path):
    """
      Return the base64 encoding of the file at path, reading it from
      disk on first use and from the in-memory cache afterwards.

      Raises IOError/OSError if the file cannot be read.
    """
    # Keyed by the full path so that equally named files in different
    # sub-directories do not overwrite each other
    entry=self.cache.get(path)
    if entry is not None:
      print("Taking from cache")
      cooked=entry["base64"]
      # Touch the entry so it is not treated as stale
      entry["time"]=time.time()
    else:
      print("Adding to cache")
      # 'with' closes the handle even when read() raises
      with open(path,"rb") as fin:
        raw=fin.read()
      cooked=base64.b64encode(raw)
      self.cache[path]={"base64":cooked,"time":time.time()}

    self.cachecheck()
    return cooked

  def cachecheck(self):
    """
      Remove every cache entry not accessed within STALE_TIME seconds.
    """
    curtime=time.time()
    # Collect the keys from a snapshot first: deleting from a dict while
    # iterating over it raises RuntimeError
    stale=[key for key,entry in list(self.cache.items())
           if curtime-entry["time"] > self.STALE_TIME]
    for key in stale:
      # pop() with a default tolerates the key having gone already
      self.cache.pop(key,None)

  def _txt(self,reply,name,text):
    """
      Add a single TXT answer for name to reply and return reply.
    """
    reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT(text)))
    return reply

  def _chaos(self,reply,name):
    """
      Answer the CH class version.bind / authors.bind queries; any other
      name gets NXDOMAIN.
    """
    if name == "version.bind":
      reply.add_answer(*RR.fromZone('version.bind 60 CH TXT "Uninvited Guest 0.3"'))
    elif name == "authors.bind":
      reply.add_answer(*RR.fromZone('authors.bind 60 CH TXT "David Lodge"'))
      reply.add_answer(*RR.fromZone('authors.bind 60 CH TXT "Ian Williams"'))
    if not reply.rr:
      reply.header.rcode = RCODE.NXDOMAIN
    return reply

  def resolve(self,request,handler):
    """
      Build the reply for request (see the class docstring for the query
      format). handler is not used.

      Returns the reply record; problems are reported to the client as
      TXT text ("Domain not found", "File not found",
      "Chunk out of range") rather than raised.
    """
    qtype=QTYPE[request.q.qtype]
    name=request.q.qname
    reply=request.reply()

    # Only TXT queries are served, everything else gets an empty reply
    if qtype != "TXT":
      return reply

    # Check for Chaos stuff first
    if request.q.qclass == CLASS.CH:
      return self._chaos(reply,name)

    # The domain has to be at the end of the name
    if not name.matchSuffix(self.domain):
      return self._txt(reply,name,"Domain not found")

    # What is left is filename.selector. followed by the empty string
    # produced by the trailing dot, so the selector is second from last
    parts=str(name.stripSuffix(self.domain)).split(".")
    pname='.'.join(parts[:-2])
    selector=''.join(parts[-2:-1])

    # Without a directory there is nothing that could be served
    if not self.directory:
      return self._txt(reply,name,"File not found")

    # The filename comes straight from the query, so normalise the path
    # and refuse anything that climbs out of the served directory.
    # abspath() does not resolve symlinks, so links placed inside the
    # directory keep working.
    root=os.path.abspath(self.directory)
    path=os.path.abspath(self.directory + "/" + pname)
    inside=path.startswith(root.rstrip(os.sep) + os.sep)
    if not inside or not os.path.isfile(path):
      return self._txt(reply,name,"File not found")

    # Work out the chunk index range without reading the file: base64
    # turns every started group of 3 bytes into 4 characters. Integer
    # division keeps the answer an int on Python 3 as well.
    b64len=4*((os.path.getsize(path)+2)//3)
    # Index of the last chunk a client may ask for; chunks are numbered
    # from 0 and the "count" answer is this index
    chunks=b64len//self.CHUNK_SIZE

    if selector == "count":
      return self._txt(reply,name,str(chunks))

    if selector.isdigit():
      try:
        chunk=int(selector)
      except ValueError:
        # isdigit() accepts characters such as superscripts which int()
        # cannot convert; no valid chunk can be named that way
        return self._txt(reply,name,"Chunk out of range")
      if chunk > chunks or chunk < 0:
        return self._txt(reply,name,"Chunk out of range")

      # It's actually trying to read the file, so cache it in memory
      cooked=self.getcache(path)

      # Now just grab the chunk
      start=chunk*self.CHUNK_SIZE
      return self._txt(reply,name,cooked[start:start+self.CHUNK_SIZE])

    # Neither "count" nor a chunk number: empty reply
    return reply

if __name__ == '__main__':
  # Script entry point: parse the command line, start the UDP server
  # (plus a TCP one on request) and idle while the UDP server is alive.

  import argparse
  import sys
  import time

  parser = argparse.ArgumentParser(description="Fixed DNS Resolver")

  # What to serve
  parser.add_argument("--directory","-r",
          help="Directory to pull files from",
          metavar="<directory>")
  parser.add_argument("--domain","-d",
          help="Domain we're working in",
          metavar="<domain>")

  # Where and how to listen
  parser.add_argument("--port","-p",
          help="Server port (default:53)",
          metavar="<port>",
          type=int,
          default=53)
  parser.add_argument("--address","-a",
          help="Listen address (default:all)",
          metavar="<address>",
          default="")
  parser.add_argument("--udplen","-u",
          help="Max UDP packet length (default:0)",
          metavar="<udplen>",
          type=int,
          default=0)
  parser.add_argument("--tcp",
          help="TCP server (default: UDP only)",
          action='store_true',
          default=False)

  # Logging
  parser.add_argument("--log",
          help="Log hooks to enable (default: +request,+reply,+truncated,+error,-recv,-send,-data)",
          default="request,reply,truncated,error")
  parser.add_argument("--log-prefix",
          help="Log prefix (timestamp/handler/resolver) (default: False)",
          action='store_true',
          default=False)

  opts = parser.parse_args()

  resolver = FileResolver(opts.directory,opts.domain)
  logger = DNSLogger(opts.log,opts.log_prefix)

  listen_on = opts.address or "*"
  protocols = "UDP/TCP" if opts.tcp else "UDP"
  print("Starting File Resolver (%s:%d) [%s]" % (listen_on, opts.port, protocols))

  # A udplen of 0 leaves the handler's own limit untouched
  if opts.udplen:
    DNSHandler.udplen = opts.udplen

  # Settings shared by the UDP and the TCP server
  common = dict(port=opts.port, address=opts.address, logger=logger)

  udp_server = DNSServer(resolver, **common)
  udp_server.start_thread()

  if opts.tcp:
    tcp_server = DNSServer(resolver, tcp=True, **common)
    tcp_server.start_thread()

  # The servers run in their own threads; keep the main thread around
  # for as long as the UDP server is alive
  while udp_server.isAlive():
    time.sleep(1)