diff --git a/server b/server index a105ff6..3c6e3d6 100755 --- a/server +++ b/server @@ -1,5 +1,4 @@ #!/usr/bin/env /usr/bin/python -# Why don't people use shebangs anymore. It's muppetry I say # -*- coding: utf-8 -*- """ @@ -7,187 +6,155 @@ to all requests """ -from __future__ import print_function - -import copy import os import os.path import base64 -import time -from dnslib import RR, QTYPE, TXT, CLASS, RCODE -from dnslib.server import DNSServer,DNSHandler,BaseResolver,DNSLogger -from dnslib.label import DNSLabel +from dnslib import RR, QTYPE, TXT +from dnslib.server import DNSServer, DNSHandler, BaseResolver, DNSLogger -class FileResolver(BaseResolver): - """ - Respond with fixed response to all requests - """ - def __init__(self,directory,domain): - self.filelist=[] - self.domain=domain - self.ttl=60 - self.directory=directory - self.cache={} - if directory: - if not os.path.isdir(directory): - print("Directory " + directory + " doesn't exist") - exit() - - def getcache(self,path): - cooked="" - pname=os.path.basename(path); - if pname in self.cache: - # it's in the cache - so lets just read it from the cache - print("Taking from cache") - cooked=self.cache[pname]["base64"] - # And update time - self.cache[pname]["time"]=time.time() - else: - # It's not in cache - read the file and cache it - print("Adding to cache") - fin=open(path,"rb") - raw=fin.read() - fin.close() - cooked=base64.b64encode(raw) - self.cache[pname]={} - self.cache[pname]["base64"]=cooked - self.cache[pname]["time"]=time.time() - - self.cachecheck() - return cooked - - def cachecheck(self): - # Remove any entries in the cache over an hour old - curtime = time.time() - staletime = 3600 - for key in self.cache: - if curtime-self.cache[key]["time"] > staletime: - # it's stale - del self.cache[key] - - def resolve(self,request,handler): - type=QTYPE[request.q.qtype] - name=request.q.qname - reply = request.reply() - - if type == "TXT": - # Check for Chaos stuff first - if request.q.qclass == CLASS.CH: - if name == "version.bind": - reply.add_answer(*RR.fromZone('version.bind 60 CH TXT "Uninvited Guest 0.3"')) - if name == "authors.bind": - reply.add_answer(*RR.fromZone('authors.bind 60 CH TXT "David Lodge"')) - reply.add_answer(*RR.fromZone('authors.bind 60 CH TXT "Ian Williams"')) - if not reply.rr: - reply.header.rcode = RCODE.NXDOMAIN - return reply - - # Format is filename.count.domain for count - # First check domain is at the end of the name - if not name.matchSuffix(self.domain): - reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, - rdata=TXT("Domain not found"))) - return reply - # Now look whether we're looking for a count or a part - parts=str(name.stripSuffix(self.domain)).split(".") - # parts[0] should be filename, parts[1] should be segment or count - pname='.'.join(parts[:-2]) - path=self.directory + "/" + pname - if not os.path.isfile(path): - reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, - rdata=TXT("File not found"))) +class FileResolver(BaseResolver): + """ + Respond with fixed response to all requests + """ + + def __init__(self, directory, domain): + self.filelist = [] + self.domain = domain + self.ttl = 60 + self.directory = directory + self.cache = {} + if directory: + if not os.path.isdir(directory): + print(f"Directory {directory} doesn't exist") + exit() + + def getcache(self, path): + pname = os.path.basename(path) + if pname in self.cache: + # it's in the cache - so lets just read it from the cache + print("Taking from cache") + cooked = self.cache[pname]["base64"] + # And update time + self.cache[pname]["time"] = time.time() + else: + # It's not in cache - read the file and cache it + print("Adding to cache") + fin = open(path, "rb") + raw = fin.read() + fin.close() + cooked = base64.b64encode(raw) + self.cache[pname] = {} + self.cache[pname]["base64"] = cooked + self.cache[pname]["time"] = time.time() + + self.cachecheck() + return cooked + + def cachecheck(self): + # Remove any entries in the cache over an hour old + curtime = time.time() + staletime = 3600 + for key in self.cache: + if curtime - self.cache[key]["time"] > staletime: + # it's stale + del self.cache[key] + + def resolve(self, request, handler): + qry_type = QTYPE[request.q.qtype] + name = request.q.qname + reply = request.reply() + + # First check domain is at the end of the name + if not name.matchSuffix(self.domain): + reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT("Domain not found"))) + return reply + + # Request method for uploads + if qry_type == "SRV": + with open('/tmp/dns-srv.log', 'w') as f: + f.write(name.stripSuffix(self.domain)) + reply.add_answer(RR(name, QTYPE.SRV, ttl=self.ttl, rdata=TXT("accepted"))) + # Request method for downloads + if qry_type == "TXT": + # Format is filename.count.domain for count + if not name.matchSuffix(self.domain): + reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT("Domain not found"))) + return reply + + # Now look whether we're looking for a count or a part + parts = str(name.stripSuffix(self.domain)).split(".") + # parts[0] should be filename, parts[1] should be segment or count + pname = '.'.join(parts[:-2]) + path = self.directory + "/" + pname + if not os.path.isfile(path): + reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT("File not found"))) + return reply + + # work out the count mathematically + # First work out the base64 size + length = os.path.getsize(path) + length = (4 * length) / 3 + # And padding + length += (length % 4) + # Finally divide into number of 254 byte chunks + chunks = (length / 254) + + if ''.join(parts[-2:-1]) == "count": + reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT(str(chunks)))) + return reply + + if ''.join(parts[-2:-1]).isdigit(): + # Woo it's a number + # lets base64 the file + chunk = int(''.join(parts[-2:-1])) + if chunk > chunks or chunk < 0: + reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT("Chunk out of range"))) + return reply + # It's actually trying to read the file, so let's cache it in memory + cooked = self.getcache(path) + + # Now lets just grab the chunk + start = chunk * 254 + txtr = cooked[start:start + 254] + reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, rdata=TXT(txtr))) + return reply + + # Replace labels with request label return reply - # work out the count mathematically - # First work out the base64 size - l=os.path.getsize(path) - l=(4*l)/3 - # And padding - l+=(l%4) - # Finally divide into number of 254 byte chunks - chunks=(l/254) - - if ''.join(parts[-2:-1]) == "count": - reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, - rdata=TXT(str(chunks)))) - return reply - - if ''.join(parts[-2:-1]).isdigit(): - # Woo it's a number - # lets base64 the file - chunk=int(''.join(parts[-2:-1])) - if chunk > chunks or chunk < 0: - reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, - rdata=TXT("Chunk out of range"))) - return reply - # It's actually trying to read the file, so let's cache it in memory - cooked=self.getcache(path) - - # Now lets just grab the chunk - start=chunk*254 - txtr=cooked[start:start+254] - reply.add_answer(RR(name, QTYPE.TXT, ttl=self.ttl, - rdata=TXT(txtr))) - return reply - - # Replace labels with request label - return reply if __name__ == '__main__': - - import argparse,sys,time - - p = argparse.ArgumentParser(description="Fixed DNS Resolver") - p.add_argument("--directory","-r", - metavar="", - help="Directory to pull files from") - p.add_argument("--domain","-d", - metavar="", - help="Domain we're working in") - p.add_argument("--port","-p",type=int,default=53, - metavar="", - help="Server port (default:53)") - p.add_argument("--address","-a",default="", - metavar="
", - help="Listen address (default:all)") - p.add_argument("--udplen","-u",type=int,default=0, - metavar="", - help="Max UDP packet length (default:0)") - p.add_argument("--tcp",action='store_true',default=False, - help="TCP server (default: UDP only)") - p.add_argument("--log",default="request,reply,truncated,error", - help="Log hooks to enable (default: +request,+reply,+truncated,+error,-recv,-send,-data)") - p.add_argument("--log-prefix",action='store_true',default=False, - help="Log prefix (timestamp/handler/resolver) (default: False)") - args = p.parse_args() - - resolver = FileResolver(args.directory,args.domain) - logger = DNSLogger(args.log,args.log_prefix) - - print("Starting File Resolver (%s:%d) [%s]" % ( - args.address or "*", - args.port, - "UDP/TCP" if args.tcp else "UDP")) - - if args.udplen: - DNSHandler.udplen = args.udplen - - udp_server = DNSServer(resolver, - port=args.port, - address=args.address, - logger=logger) - udp_server.start_thread() - - if args.tcp: - tcp_server = DNSServer(resolver, - port=args.port, - address=args.address, - tcp=True, - logger=logger) - tcp_server.start_thread() - - while udp_server.isAlive(): - time.sleep(1) - + import argparse + import time + + p = argparse.ArgumentParser(description="Fixed DNS Resolver") + p.add_argument("--directory", "-r", metavar="", help="Directory to pull files from") + p.add_argument("--domain", "-d", metavar="", help="Domain we're working in") + p.add_argument("--port", "-p", type=int, default=53, metavar="", help="Server port (default:53)") + p.add_argument("--address", "-a", default="", metavar="
", help="Listen address (default:all)") + p.add_argument("--udplen", "-u", type=int, default=0, metavar="", help="Max UDP packet length (default:0)") + p.add_argument("--tcp", action='store_true', default=False, help="TCP server (default: UDP only)") + p.add_argument("--log", default="request,reply,truncated,error", + help="Log hooks (default: +request,+reply,+truncated,+error,-recv,-send,-data)") + p.add_argument("--log-prefix", action='store_true', default=False, help="Log prefix (default: False)") + args = p.parse_args() + + resolver = FileResolver(args.directory, args.domain) + logger = DNSLogger(args.log, args.log_prefix) + + print(f"Starting File Resolver ({args.address or '*'}:{args.port}) [{'UDP/TCP' if args.tcp else 'UDP'}]") + + if args.udplen: + DNSHandler.udplen = args.udplen + + udp_server = DNSServer(resolver, port=args.port, address=args.address, logger=logger) + udp_server.start_thread() + + if args.tcp: + tcp_server = DNSServer(resolver, port=args.port, address=args.address, tcp=True, logger=logger) + tcp_server.start_thread() + + while udp_server.isAlive(): + time.sleep(1)