from http import server, HTTPStatus
from hashlib import sha256
from configparser import ConfigParser
from datetime import datetime
from re import search
from urllib import request
from urllib.error import HTTPError, URLError
from io import StringIO
from contextlib import contextmanager
from heapq import heappush, heappop
from itertools import count
from time import sleep
import threading
import portalocker
import os
import sys

# Hop-by-hop headers (RFC 7230, section 6.1) only describe a single
# connection and must not be relayed by a proxy. In particular urllib
# de-chunks response bodies, so relaying "Transfer-Encoding: chunked"
# would corrupt the answer sent to the client.
HOP_BY_HOP_HEADERS = frozenset((
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailer", "trailers", "transfer-encoding", "upgrade",
))


class CacheHandler(server.BaseHTTPRequestHandler):
    """
    This class implements caching of Range-Requests.

    It subclasses the generic BaseHTTPRequestHandler and implements the
    HTTP GET, POST and HEAD methods. Single byte-range GET requests are
    served from an on-disk cache that is filled from the target server on
    demand; all other requests are forwarded to the target unchanged.
    Other methods can be added easily if needed.

    The handler relies on the module globals ``target``, ``block_size`` and
    ``threadLock``, which are set up in the ``__main__`` section.
    """

    def do_GET(self):
        """
        This method is called when handling a GET request.

        Requests without a Range header, multi-range requests and range
        forms that are not ``<unit>=<start>-[<end>]`` are forwarded to the
        target unchanged.

        For a single range, the cache for ``self.path`` consists of two
        files in the working directory:

        * ``cache_<sha256(path)>``: a bitmap with one byte per
          ``block_size`` bytes of the resource; ``b"1"`` marks a block as
          cached.
        * ``cache_<sha256(path)>.data``: a sparse file, truncated to the
          resource size, that holds the cached bytes at their original
          offsets.

        Missing blocks inside the requested range are fetched from the
        target with one block-aligned range request, stored, and marked as
        cached. The requested bytes are then answered from the data file
        with 206 Partial Content.
        """
        rrange = self.headers.get("Range")
        match = None
        if rrange is not None and "," not in rrange:
            match = search(r'([A-Za-z]+)=(\d+)\s*-\s*(\d*)', rrange)
        if match is None:
            self._forward('GET')
            return

        start = int(match.group(2))
        # An empty end ("bytes=100-") means "up to the end of the resource".
        requested_end = int(match.group(3)) if match.group(3) else None

        filename = "cache_" + sha256(self.path.encode('utf-8')).hexdigest()
        data_filename = filename + ".data"

        # Hold the lock for the whole cache update so QuotaCheck cannot
        # delete the cache files while they are inspected or written. The
        # "with" statement releases it on every return and exception path.
        with threadLock:
            try:
                size = self._cache_size(filename, data_filename)
            except URLError as e:
                self._send_upstream_error(e)
                return
            if size is None:
                # Without a known size the sparse cache files cannot be laid
                # out, so fall back to plain proxying.
                self._forward('GET')
                return

            # Per RFC 7233, an end beyond the resource is clamped to its last byte.
            if requested_end is None:
                end = size - 1
            else:
                end = min(requested_end, size - 1)
            if start > end:
                self.send_error(HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE)
                return

            missing = self._missing_blocks(
                filename, start // block_size, end // block_size)
            if missing:
                # Fetch one block-aligned span covering every missing block,
                # so each block it touches is complete and can be marked
                # cached. The final block is clipped to the resource size.
                part_start = missing[0] * block_size
                part_end = min((missing[-1] + 1) * block_size, size) - 1
                try:
                    fetched = self._fetch_into_cache(
                        filename, data_filename, part_start, part_end)
                except URLError as e:
                    self._send_upstream_error(e)
                    return
                if not fetched:
                    self.send_error(HTTPStatus.BAD_GATEWAY)
                    return

            with self.filelock(data_filename, 'rb') as f:
                f.seek(start)
                data = f.read(end - start + 1)

        self.send_response(HTTPStatus.PARTIAL_CONTENT)
        self.send_header("Content-Length", len(data))
        self.send_header("Content-Range",
                         "bytes "+str(start)+"-"+str(end)+"/"+str(size))
        self.end_headers()
        self.wfile.write(data)

    def _cache_size(self, filename, data_filename):
        """
        Return the resource size, creating the cache files when needed.

        If both cache files exist, the size is taken from the data file.
        Otherwise the size is requested from the target with HEAD, and both
        files are (re)created: a zero-filled bitmap with one byte per block,
        rounded up, and a data file truncated to the resource size.

        Returns:
            int or None: The resource size in bytes, or None when the target
            does not report a usable Content-Length.

        Raises:
            URLError: (including HTTPError) if the HEAD request fails.
        """
        if os.path.isfile(filename) and os.path.isfile(data_filename):
            return os.stat(data_filename).st_size
        req = request.Request(target + self.path, None, {}, None, False, 'HEAD')
        with request.urlopen(req) as answer:
            raw_length = answer.getheader("Content-Length")
        try:
            size = int(raw_length)
        except (TypeError, ValueError):
            return None
        if size < 0:
            return None
        with self.filelock(filename, 'wb') as f:
            f.truncate((size + block_size - 1) // block_size)
        with self.filelock(data_filename, 'wb') as df:
            df.truncate(size)
        return size

    def _missing_blocks(self, filename, first_block, last_block):
        """
        Return the block indices in [first_block, last_block] that are not cached.

        Bitmap bytes past the end of the file count as missing. This covers
        bitmaps created by older versions, which rounded the block count down.
        """
        wanted = last_block - first_block + 1
        with self.filelock(filename, 'rb') as f:
            f.seek(first_block)
            flags = f.read(wanted).ljust(wanted, b"\0")
        cached = ord("1")
        return [block for block, flag in enumerate(flags, first_block)
                if flag != cached]

    def _fetch_into_cache(self, filename, data_filename, part_start, part_end):
        """
        Fetch bytes [part_start, part_end] from the target and cache them.

        part_start must be block-aligned, and part_end must be a block end
        or the last byte of the resource. This guarantees that every block
        marked as cached was received completely.

        Returns:
            bool: False if the target did not answer with exactly the
            requested partial content (nothing is cached in that case).

        Raises:
            URLError: (including HTTPError) if the request fails.
        """
        part_length = part_end - part_start + 1
        rheaders = self._forwarded_headers(("range", "accept-encoding"))
        rheaders["Range"] = "bytes=" + str(part_start) + "-" + str(part_end)
        # Ask for the raw bytes; a content-coded body could not be stored
        # at its byte offsets.
        rheaders["Accept-Encoding"] = "identity"
        req = request.Request(target + self.path, None, rheaders)
        with request.urlopen(req) as answer:
            data = answer.read()
            status = answer.getcode()
        if status != HTTPStatus.PARTIAL_CONTENT or len(data) != part_length:
            return False
        with self.filelock(data_filename, 'r+b') as f:
            f.seek(part_start)
            f.write(data)
        first = part_start // block_size
        with self.filelock(filename, 'r+b') as f:
            f.seek(first)
            f.write(b"1" * (part_end // block_size - first + 1))
        return True

    def _forwarded_headers(self, exclude=()):
        """
        Return the client's request headers that should be sent to the target.

        Hop-by-hop headers and Host are dropped, so urllib sets the target's
        host itself. Header names listed in ``exclude`` (lower case) are
        dropped as well.
        """
        skip = HOP_BY_HOP_HEADERS.union(("host",), exclude)
        return {name: value for name, value in self.headers.items()
                if name.lower() not in skip}

    def _send_upstream_error(self, err):
        """
        Report a failed request to the target back to the client.

        HTTP errors are passed through with their status and reason. Other
        failures (e.g. connection refused) become 502 Bad Gateway, with the
        reason placed in the escaped HTML body.
        """
        if isinstance(err, HTTPError):
            self.send_error(err.code, err.msg)
        else:
            self.send_error(HTTPStatus.BAD_GATEWAY, None, str(err.reason))

    def _forward(self, method, body=None):
        """
        Forward the current request to the target and relay its answer.

        Args:
            method (str): The HTTP method to use towards the target.
            body (bytes, optional): The request body. Defaults to None.
        """
        req = request.Request(target + self.path, body,
                              self._forwarded_headers(), None, False, method)
        try:
            answer = request.urlopen(req)
        except URLError as e:
            self._send_upstream_error(e)
            return
        with answer:
            payload = answer.read()
            status = answer.getcode()
            headers = answer.getheaders()
        self.send_response(status)
        for name, value in headers:
            if name.lower() not in HOP_BY_HOP_HEADERS:
                self.send_header(name, value)
        self.end_headers()
        self.wfile.write(payload)

    @contextmanager
    def filelock(self, fp, mode = "r+"):
        """
        This method opens a file and locks it according to the mode.

        The portalocker module is used to lock the requested file. If a
        read-only mode ('r' or 'rb') is chosen, a shared lock is used;
        otherwise the lock is exclusive. The context manager closes the file
        when the "with" block is left, and closing the file also releases
        the lock.

        Args:
            fp (str): The file path to open.
            mode (str, optional): The mode used by open. Defaults to 'r+'.
        """
        file = open(fp, mode)
        try:
            # Locking happens inside the try block so that the file is still
            # closed if acquiring the lock fails.
            if mode in ('r', 'rb'):
                flag = portalocker.LOCK_SH
            else:
                flag = portalocker.LOCK_EX
            portalocker.lock(file, flag)
            yield file
        finally:
            file.close()

    def do_POST(self):
        """
        This method is called when handling a POST request.

        It forwards the headers and data of the original request to the
        target and returns the answer to the client. Exactly Content-Length
        bytes of the body are read, so the handler does not block on a
        client that keeps its connection open.
        """
        try:
            length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            length = -1
        if length < 0:
            self.send_error(HTTPStatus.BAD_REQUEST)
            return
        self._forward('POST', self.rfile.read(length))

    def do_HEAD(self):
        """
        This method is called when handling a HEAD request.

        It forwards the headers of the original request to the target and
        returns the answer to the client.
        """
        self._forward('HEAD')


class QuotaCheck (threading.Thread):
    """
    This class checks whether the allowed disk space is exceeded and cleans
    up the cache.

    It subclasses threading.Thread so that it runs independently of the
    proxy thread. Disk usage is measured from allocated 512-byte blocks
    (st_blocks / 2048), so the quota is expressed in MiB.
    """

    def __init__(self, quota):
        """
        This method is called to initialize the class.

        It calls the super class's __init__ and sets the selected quota.

        Args:
            quota (int): The disk usage limit in MiB.
        """
        threading.Thread.__init__(self)
        self.quota = quota

    def run(self):
        """
        This method runs the thread's loop; it is invoked by start().

        Once per second it checks whether the quota is exceeded. If so, it
        takes the proxy's lock (released even if cleaning fails) and cleans
        the cache.
        """
        while True:
            sleep(1)
            if self.cache_full():
                with threadLock:
                    self.clean_cache()

    @staticmethod
    def _usage(entry):
        """Return the disk space allocated to a directory entry, in MiB."""
        # st_blocks counts 512-byte units, so 2048 of them make one MiB.
        # Allocated blocks reflect the real usage of the sparse data files.
        return entry.stat().st_blocks/2048

    @staticmethod
    def _cache_groups():
        """
        Group the cache files in the working directory by resource.

        Returns:
            dict: Maps "cache_<hash>" to the DirEntry objects of its bitmap
            and/or data file.
        """
        groups = {}
        with os.scandir() as entries:
            for entry in entries:
                if entry.name.startswith("cache_") and entry.is_file():
                    key = entry.name
                    if key.endswith(".data"):
                        key = key[:-len(".data")]
                    groups.setdefault(key, []).append(entry)
        return groups

    def cache_full(self):
        """Return True if the cache files use more disk space than the quota."""
        total_size = sum(self._usage(entry)
                         for entries in self._cache_groups().values()
                         for entry in entries)
        return total_size > self.quota

    def clean_cache(self):
        """
        This method reduces the disk usage until the quota is met.

        It builds a priority queue of cached resources, ordered by their
        most recent file access time. It then deletes the least recently
        accessed resources until the quota is met or nothing is left. A
        resource's bitmap and data file are always removed together:
        otherwise a surviving bitmap would claim data that no longer exists.
        """
        total_size = 0
        h = []
        tiebreak = count()
        for entries in self._cache_groups().values():
            size = sum(self._usage(entry) for entry in entries)
            last_access = max(entry.stat().st_atime_ns for entry in entries)
            total_size += size
            heappush(h, (last_access, next(tiebreak), size, entries))
        while h and total_size > self.quota:
            _, _, size, entries = heappop(h)
            for entry in entries:
                try:
                    os.remove(entry.path)
                except FileNotFoundError:
                    pass
            total_size -= size


class ServerThread (threading.Thread):
    """
    This class runs the proxy as a separate thread.

    It subclasses threading.Thread so that it runs independently of the
    quota thread.
    """

    def __init__(self):
        """
        This method is called to initialize the class.

        It calls the super class's __init__ and creates the HTTP server on
        the global server_address, using CacheHandler for requests.
        """
        threading.Thread.__init__(self)
        self.httpd = server.HTTPServer(server_address, CacheHandler)

    def run(self):
        """
        This method runs the thread's body; it is invoked by start().

        It lets the proxy serve incoming requests forever.
        """
        self.httpd.serve_forever()


if __name__ == '__main__':
    # This section starts the proxy and quota threads. It reads config.ini,
    # sets the module globals used by the handler, creates the threads and
    # waits for the proxy thread, ending the program on KeyboardInterrupt.
    config = ConfigParser()
    config.read('config.ini')
    server_address = ('', int(config['Config']['port']))
    block_size = int(config['Config']['block_size'])
    target = config['Config']['target']
    threadLock = threading.Lock()
    t1 = QuotaCheck(int(config['Config']['quota']))
    t2 = ServerThread()
    # Daemon threads let the interpreter exit after sys.exit(); non-daemon
    # threads would keep the process alive forever.
    t1.daemon = True
    t2.daemon = True
    t1.start()
    t2.start()
    try:
        while t2.is_alive():
            t2.join(1)
    except KeyboardInterrupt:
        print("Interrupt recieved, stopping...")
        sys.exit()