#!/usr/bin/env python

#############################################################################
#  findappletv.py - Find AppleTV search requests in network captures
#  Copyright (C) 2009 Matt Sabourin
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <http://www.gnu.org/licenses/>.
#############################################################################

# Script was tested using Python 2.5 on OSX 10.5

# Portions of this program are based on code found in findsmtpinfo.py by
# Jeremy Rossi, provided at 
# http://forensicscontest.com/contest02/Finalists/Jeremy_Rossi/findsmtpinfo.py


# Script tested  using dpkt-read-only trunk r57
# http://code.google.com/p/dpkt/source/checkout
import dpkt
# Script tested  using pypcap-read-only trunk r87 (allow python 2.5 usage)
# http://code.google.com/p/pypcap/source/checkout
import pcap
# Next two are used to decode gzip compressed HTTP responses 
import gzip
import StringIO
import sys
import os.path, os
# Compute MD5 hashes
import hashlib
import re
import struct
# used to convert IP addresses from dpkt
import socket
from optparse import OptionParser
import pprint
import datetime
import math
import xml.etree.ElementTree


def computeMD5(filename):
        """ Function to compute the MD5 hash of a file on disk.

        The file is opened in binary mode so the digest covers the exact bytes
        stored on disk (text mode may translate line endings on some platforms
        and yields str rather than bytes on Python 3).  It is read in 8 KiB
        chunks so large captures need not fit in memory, and the handle is
        always closed, even if reading fails.

        Returns the hex-encoded MD5 digest string. """
        md5 = hashlib.md5()
        fh = open(filename, 'rb')
        try:
                while True:
                        data = fh.read(8192)
                        if not data:
                                break
                        md5.update(data)
        finally:
                fh.close()
        return md5.hexdigest()



def getNumDigits(num):
        """ Function to compute the number of decimal digits in an integer.

        Values less than or equal to zero return 1.  The count is taken from
        the decimal string form instead of int(math.log(num, 10)) + 1: the
        floating-point logarithm is inexact at exact powers of ten
        (math.log(1000, 10) == 2.9999999999999996), which made the former
        implementation report 3 digits for 1000. """
        if num <= 0:
                return 1
        return len('%d' % num)


class pObject(object):
        """ Class that represents a plistObject as defined by the Apple
        Property-List-1.0.dtd.  It is never really instantiated: __new__
        returns the Python equivalent of the given XML element instead. """

        def __new__(cls, xmlObject):
                """ Convert an ElementTree element of a plist document into the
                equivalent Python value:

                        string, date, data -> element text ('' when the element is empty)
                        integer            -> int
                        real               -> float
                        true / false       -> True / False
                        dict               -> dict mapping each key's text to the
                                              converted value that follows it
                        array              -> list of converted children

                A dict whose children do not come in key/value pairs converts to an
                empty dict.  Any other tag converts to None. """
                tag = xmlObject.tag

                if tag in ('string', 'date', 'data'):
                        # 'date' stays as its text and 'data' (Base-64 per the DTD)
                        # stays encoded.  An empty element such as <string/> has text
                        # None; return '' so callers can treat the value as a string.
                        return xmlObject.text or ''
                if tag == 'integer':
                        return int(xmlObject.text)
                if tag == 'real':
                        return float(xmlObject.text)
                if tag == 'true':
                        return True
                if tag == 'false':
                        return False

                # list(element) yields the direct children; getchildren() was
                # deprecated and has been removed in Python 3.9.
                children = list(xmlObject)

                if tag == 'dict':
                        result = {}
                        if len(children) % 2 != 0:
                                return result
                        for i in range(0, len(children), 2):
                                result[children[i].text] = pObject(children[i + 1])
                        return result

                if tag == 'array':
                        return [pObject(child) for child in children]

                return None

class pList(object):
        """ Class that represents an Apple plist XML document.  The root
        dictionary element is converted (with pObject) into a Python dict whose
        entries act as the plist's 'attributes'; use the get/set/has methods to
        manipulate them.

        After construction self.valid tells whether the document was usable.
        When it is False, self.invalidReason holds a short explanation and the
        attribute dictionary is empty. """

        def __init__ (self,filename=None,xmlbuf=None):
                """ Parse the plist from xmlbuf (an XML string) when it is non-empty,
                otherwise from the file named by filename. """
                self.valid = False
                self.invalidReason = ''
                # Set up front so get/set/has report a missing key instead of
                # failing on a missing _dictionary when the document is rejected.
                self._dictionary = {}

                if not xmlbuf and not filename:
                        self.invalidReason = "No filename or XML buffer provided"
                        return

                try:
                        if xmlbuf:
                                _root = xml.etree.ElementTree.fromstring(xmlbuf)
                        else:
                                _root = xml.etree.ElementTree.parse(filename).getroot()
                except Exception:
                        # The exception type depends on the Python / ElementTree
                        # version (ExpatError, ParseError, IOError for an unreadable
                        # file, ...).  Unlike the former bare except, this does not
                        # swallow KeyboardInterrupt or SystemExit.
                        self.invalidReason = "Invalid XML document"
                        return

                if _root.tag != 'plist':
                        self.invalidReason = "Unexpected root element"
                        return

                _pdict = _root.find('./dict')
                # Compare against None explicitly: an Element with no children is
                # falsy, so 'if not _pdict' rejected an empty <dict/> as missing.
                if _pdict is None:
                        self.invalidReason = "Could not find main dictionary"
                        return

                self._dictionary = pObject(_pdict)
                self.valid = True

        def get(self,name):
                """ Return the value stored under name.  Raises AttributeError
                (class name, key) when the plist has no such key. """
                if name in self._dictionary:
                        return self._dictionary[name]
                raise AttributeError(self.__class__.__name__, name)

        def set(self,name,value):
                """ Replace the value of an existing key.  Raises AttributeError
                (class name, key) when the key does not exist; new keys cannot be
                added. """
                if name in self._dictionary:
                        self._dictionary[name] = value
                else:
                        raise AttributeError(self.__class__.__name__, name)

        def has(self,name):
                """ Return True if the plist's root dictionary contains name. """
                return name in self._dictionary




class pcapSearcher(object):
        """ Main class used to read a pcap file and search it for potential
        Apple TV clients.  parse() collects the clients (atvClient objects keyed
        by source IP address) and then writes the overview and per-client
        reports. """

        def __init__(self,pcapfilename,reportdir,debug=0):
                """ pcapfilename is the network capture to search and reportdir the
                directory that receives the report files.  The MD5 of the capture is
                computed here so it can be quoted in the overview report. """
                self.filename = pcapfilename
                self.debug = debug
                self.md5 = computeMD5(self.filename)
                # Client IP address (dotted quad) -> atvClient
                self.atvSources = {}
                self.txtReport = ''
                self.reportDir = reportdir
                self.reportFile = ''
                self.reportMD5 = ''

        def _getClient(self, ipquad, mac):
                """ Return the atvClient for ipquad, creating it (with the packed
                Ethernet source address mac) the first time the address is seen. """
                if ipquad not in self.atvSources:
                        self.atvSources[ipquad] = atvClient(ipquad, mac)
                return self.atvSources[ipquad]

        def parse(self):
                """ Method to parse the pcap file and pull out possible AppleTV
                HTTP requests.  Builds a dictionary of the clients making these
                requests, then writes the reports into self.reportDir. """

                # Binary mode: the capture is binary data and text mode may alter it.
                pcapFH = open(self.filename, 'rb')
                try:
                        pcapRdr = dpkt.pcap.Reader(pcapFH)
                        pktCount = 0

                        for ts, buf in pcapRdr:
                                # Every frame is counted so pktCount is the frame's
                                # position in the capture, even for skipped frames.
                                pktCount += 1
                                eth = dpkt.ethernet.Ethernet(buf)

                                # Only IPv4 frames carrying TCP can hold the HTTP
                                # requests of interest; other frames (ARP, IPv6, UDP,
                                # ...) previously crashed on the attribute accesses.
                                ip = eth.data
                                if not isinstance(ip, dpkt.ip.IP):
                                        continue
                                tcp = ip.data
                                if not isinstance(tcp, dpkt.tcp.TCP):
                                        continue

                                # Assume AppleTV requests are sent to the standard
                                # non-encrypted web port, carry data (the HTTP request)
                                # and are HTTP GET requests.
                                if tcp.dport != 80 or not tcp.data:
                                        continue
                                if 'HTTP' not in tcp.data or 'GET' not in tcp.data:
                                        continue

                                currReq = atvRequest(self.filename, pktCount, ts, ip)
                                if currReq.validRequest or currReq.possibleRequest:
                                        # dpkt stores addresses packed; convert to a dotted quad.
                                        ipquad = socket.inet_ntoa(ip.src)
                                        # The request is recorded only when it matched at
                                        # least one criterion.  It used to be added outside
                                        # this check, crediting non-matching requests to the
                                        # last client seen (or raising NameError before any
                                        # client existed).
                                        self._getClient(ipquad, eth.src).addRequest(currReq)
                finally:
                        pcapFH.close()

                print("[+] Parsing Complete.")
                print("[+] Using report directory: %s" % (self.reportDir))
                print("[+] Writing report files...")
                self.writeReport(self.reportDir)
                print("[+] Reports complete.")

        def log (self,line):
                """ Append line (plus a newline) to the overview report text. """
                self.txtReport += line + '\n'

        def writeReport(self,reportdir):
                """ Build the overview report in self.txtReport using self.log, then
                write it to Overview-Report.txt in reportdir (created if needed).
                Each client's own report is written first so its filename and MD5
                can be listed.  Finally the MD5 of the overview file is computed and
                stored in self.reportMD5. """

                # Start from an empty buffer so calling this again does not
                # duplicate the report text.
                self.txtReport = ''

                self.log('=' * 78)
                self.log('Report for network capture file:   %s' % self.filename)
                self.log('MD5 hash of network capture file:  %s' % self.md5)
                self.log('Report created at:                 %s' %
                         datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"))
                self.log('=' * 78)
                self.log('This report lists possible AppleTV clients based on the presence of HTTP')
                self.log('requests with specific parameters.')
                self.log('')
                self.log('Recognized requests are those that contain a known URI, user-agent, and')
                self.log('addtional HTTP header information.')
                self.log('')
                self.log('Unrecognized requests are those that contain one or more of the known')
                self.log('features, but do not meet all criteria')
                self.log('=' * 78)
                self.log('')
                self.log('Number of possible AppleTV clients:\t%d' % len(self.atvSources))
                self.log('')
                self.log('-' * 65)

                # Numeric address order makes the report deterministic.
                for source in sorted(self.atvSources, key=socket.inet_aton):
                        client = self.atvSources[source]
                        client.writeReport(reportdir)

                        self.log('Client IP address:   %s' % source)
                        self.log('Client MAC address:  %s' % client.MAC)
                        self.log('')
                        self.log('    Number of recognized requests:    %d' % client.numValid())
                        self.log('    Number of unrecognized requests:  %d' % client.numInvalid())
                        self.log('')
                        self.log('    Report filename: %s' % client.reportFile)
                        self.log('    Report MD5:      %s' % client.reportMD5)
                        self.log('-' * 65)
                self.log('\n')
                self.log('=' * 78)
                self.log('REPORT COMPLETE'.center(78))
                self.log('=' * 78)

                self.reportFile = os.path.join(reportdir, "Overview-Report.txt")
                if not os.path.exists(reportdir):
                        os.makedirs(reportdir)

                rptFile = open(self.reportFile, "w")
                try:
                        rptFile.write(self.txtReport)
                finally:
                        rptFile.close()

                self.reportMD5 = computeMD5(self.reportFile)




class atvRequest(object):
        """ Class to hold a potential AppleTV HTTP request.  Includes a method
        to determine if the HTTP request represents a known / valid AppleTV
        request.  Also holds the httpRequest object and, for recognized AppleTV
        requests, an httpResponse object. """

        # These are the criteria used to determine if the request is a known
        # AppleTV request.  They are also listed verbatim in the client reports.
        atvKnownURIs = ['/WebObjects/MZSearch.woa/wa/incrementalSearch',
                        '/WebObjects/MZStore.woa/wa/viewMovie',
                        '/WebObjects/MZStore.woa/wa/search',
                        '/WebObjects/MZStore.woa/wa/viewGrouping']
        atvKnownHosts = ['ax.search.itunes.apple.com',
                         'ax.itunes.apple.com']
        # Regular expressions tried with re.match against the User-Agent header.
        # Raw string: same pattern text as before, without invalid-escape warnings.
        atvKnownAgents = [r'^AppleTV\/\d\.\d$']

        def __init__(self,filename,pktNum,timestamp,ip):
                """ filename is the packet capture (needed to build the httpResponse
                of a valid request), pktNum the number of the request packet within
                the capture, timestamp its capture timestamp, and ip the dpkt IP
                object whose TCP payload holds the HTTP request. """

                # Match flags, filled in by isValidRequest()
                self.possibleRequest = False
                self.validURI = False
                self.validHost = False
                self.validAgent = False
                self.validRequest = False

                self.pktNum = pktNum
                self.timestamp = timestamp
                self.filename = filename

                tcp = ip.data
                self.request = httpRequest(tcp)
                self.srcIP = socket.inet_ntoa(ip.src)
                self.dstIP = socket.inet_ntoa(ip.dst)
                self.srcPort = tcp.sport
                self.dstPort = tcp.dport
                self.uriType = ''
                self.response = ''
                self.reportFile = ''
                self.reportMD5 = ''
                self.txtReport = ''

                self.isValidRequest()

        def isValidRequest(self):
                """ Check the Host: header, the User-Agent: header and the start of
                the request URI against the known AppleTV values.  Each matching
                criterion sets its valid* flag and possibleRequest; a matching URI
                also sets uriType (last path component of the known URI).  When all
                three match, validRequest is set and the httpResponse is built.

                Returns True only when all three criteria match.  The response is
                built at most once, so repeated calls do not redo that work. """
                req = self.request

                if req.hasHeader('host') and req.getHeader('host') in self.atvKnownHosts:
                        self.validHost = True
                        self.possibleRequest = True

                if req.hasHeader('user-agent'):
                        agent = req.getHeader('user-agent')
                        for pattern in self.atvKnownAgents:
                                if re.match(pattern, agent):
                                        self.validAgent = True
                                        self.possibleRequest = True
                                        break

                for uri in self.atvKnownURIs:
                        if req.uri.startswith(uri):
                                self.validURI = True
                                self.possibleRequest = True
                                self.uriType = uri.split('/')[-1]
                                break

                if not (self.validURI and self.validHost and self.validAgent):
                        return False

                self.validRequest = True
                # httpResponse is handed the capture filename (presumably it re-reads
                # the capture), so only build it the first time.
                if not self.response:
                        self.response = httpResponse(self.filename, self.pktNum, self.dstIP,
                                                     self.dstPort, self.srcIP, self.srcPort)
                return True

        def log (self,line):
                """ Append line (plus a newline) to this request's report text. """
                self.txtReport += line + '\n'

        def infoShort(self):
                """ Return a one-line summary: the request packet number, the
                request timestamp, the HTTP method and the HTTP URI. """
                return '%-6s   %-13s   %s %s' % (self.pktNum, self.timestamp, self.request.method, self.request.uri)

        def infoFull(self):
                """ Return the request line (method, URI, HTTP version) and the
                request headers, one per line.  For a valid AppleTV request with a
                response, a blank line pair, the response status line and the
                response headers follow. """
                req = self.request
                info = '%s %s HTTP/%s\n' % (req.method, req.uri, req.version)
                info += ''.join(['%s: %s\n' % (hdr[0], hdr[1]) for hdr in req.headers])

                if self.validRequest and self.response:
                        info += '\n\n'
                        info += '%s\n' % (self.response.status)
                        info += ''.join(['%s: %s\n' % (hdr[0], hdr[1]) for hdr in self.response.headers])

                return info

        @staticmethod
        def _infoLine(label, value):
                """ Format one right-aligned 'label:  value' line of movie info. """
                return "%15s:  %s\n" % (label, value)

        def getMovieInfo(self):
                """ Use the property-list response of a viewMovie request to describe
                the movie item that was viewed: title, copyright, rating, release
                date, iTunes id / URL, and the purchase / rental offers.  Progress
                and problems are printed to stdout.

                Returns the description (one 'label:  value' line per field), or ''
                when this is not a viewMovie request with a response, or when the
                response is not a valid plist describing an item view of a movie.
                Malformed values are skipped rather than raising. """

                print("[+]  Parsing information for movie...")

                if self.uriType != 'viewMovie' or not self.response:
                        return ''

                plist = pList(xmlbuf=self.response.response)

                if not plist.valid:
                        print("[-]    %s" % plist.invalidReason)
                        return ''

                if not plist.has('page-type'):
                        print("[-]    Could not determine type of response (no page-type)")
                        return ''

                pageType = plist.get('page-type')
                if isinstance(pageType, dict) and pageType.get('template-name') == 'item':
                        print("[+]    Found an item view")
                else:
                        print("[-]    Did not find an item view")
                        return ''

                itemList = None
                if plist.has('items'):
                        itemList = plist.get('items')
                if not isinstance(itemList, list) or not itemList:
                        print("[-]    Could not find list of items viewed")
                        return ''

                items = itemList[0]

                if not isinstance(items, dict) or 'type' not in items:
                        print("[-]    Could not find type of item")
                        return ''

                if items['type'] != 'movie':
                        print("[-]    Item is not a movie")
                        return ''

                info = ''
                if 'title' in items:
                        info += self._infoLine('Title', items['title'])

                # The copyright text starts with a copyright symbol that caused
                # problems when the report was written to file, so the first two
                # characters are dropped (the client report notes this).
                if 'copyright' in items:
                        info += self._infoLine('Copyright', (items['copyright'] or '')[2:])

                rating = items.get('rating')
                if isinstance(rating, dict) and 'system' in rating and 'label' in rating:
                        info += self._infoLine('Rating', '%s (%s)' % (rating['label'], rating['system']))

                if 'release-date' in items:
                        info += self._infoLine('Release Date', items['release-date'])

                if 'item-id' in items:
                        info += self._infoLine('iTunes Item ID', items['item-id'])

                if 'url' in items:
                        info += self._infoLine('iTunes Item URL', items['url'])

                offers = items.get('store-offers')
                if not isinstance(offers, dict):
                        print("[-]    Could not find store offers")
                        return info

                for movie in offers:
                        offer = offers[movie]
                        if not isinstance(offer, dict):
                                continue
                        action = offer.get('action-display-name')
                        if action == 'Buy':
                                info += self._infoLine('Can Purchase', movie)
                                priceKey, priceLabel = 'price-display', 'Purchase Price'
                        elif action == 'Rent':
                                info += self._infoLine('Can Rent', movie)
                                priceKey, priceLabel = 'rent-price-display', 'Rental Price'
                        else:
                                continue
                        if 'preview-url' in offer:
                                info += self._infoLine('Trailer URL', offer['preview-url'])
                        if priceKey in offer:
                                info += self._infoLine(priceLabel, offer[priceKey])

                return info

        def writeReport(self,reportdir,maxPkt):
                """ Create a text report describing this request.  The report holds
                the text of infoFull() and, for a valid AppleTV request with a
                response, the response body (self.response.response).  The file is
                written to the 'Traffic-<source IP>' subdirectory of reportdir
                (created if needed) and its name contains the packet number,
                zero-filled to the width of maxPkt so directory listings sort
                nicely.  Finally the MD5 of the report file is stored in
                self.reportMD5. """

                fill = getNumDigits(maxPkt)
                fname = "Report-%0*d-%s" % (fill, self.pktNum, self.uriType)
                # The former test used the bound method self.isValidRequest, which is
                # always true; use the computed flag instead.
                hasResponse = bool(self.validRequest and self.response)

                # Start from an empty buffer so calling this again does not
                # duplicate the report text.
                self.txtReport = ''

                self.log('=' * 78)
                self.log('Report for request from %s:%d to %s:%d' % (self.srcIP, self.srcPort, self.dstIP, self.dstPort))
                self.log('Report created at:        %s' % datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"))
                self.log('AppleTV request type:     %s' % self.uriType)
                self.log('Request starting packet:  %-6s' % self.pktNum)
                self.log('Request timestamp:        %-13s' % self.timestamp)
                if hasResponse:
                        self.log('Response starting packet: %-6s' % self.response.startPkt)
                        self.log('Response timestamp:       %-13s' % self.response.startTime)
                self.log('=' * 78)
                self.log('')

                self.log(self.infoFull())
                if hasResponse:
                        self.log('')
                        self.log(self.response.response)

                self.log('=' * 78)
                self.log('REPORT COMPLETE'.center(78))
                self.log('=' * 78)

                trafficDir = os.path.join(reportdir, 'Traffic-%s' % (self.srcIP))
                if not os.path.exists(trafficDir):
                        os.makedirs(trafficDir)
                self.reportFile = os.path.join(trafficDir, fname)

                rptFile = open(self.reportFile, "w")
                try:
                        rptFile.write(self.txtReport)
                finally:
                        rptFile.close()

                self.reportMD5 = computeMD5(self.reportFile)


class atvClient(object):
        """ Class to represent a potential AppleTV client (one source IP
        address).  Holds two lists of atvRequest objects - one for recognized
        (valid) requests, one for unrecognized ones - together with the
        User-Agent values, search terms and request-type counts seen, and
        writes the per-client report. """

        def __init__(self,ip,mac):
                """ ip is the client's dotted-quad address and mac its packed
                6-byte Ethernet address, stored formatted as XX:XX:XX:XX:XX:XX. """
                self.ip = ip
                # The incoming MAC address is packed; unpack and format it.
                self.MAC = '%02X:%02X:%02X:%02X:%02X:%02X' % struct.unpack('6B', mac)

                self.validRequests = []
                self.invalidRequests = []
                self.txtReport = ''
                self.reportFile = ''
                self.reportMD5 = ''
                # Highest packet number among valid requests; sets the zero-fill
                # width of the per-request report filenames.
                self.maxValidPkt = 0
                self.userAgentsUsed = []
                self.searches = []
                # uriType -> number of requests of that type
                self.reqTypeCount = {}

        def _getURIParams(self,uri):
                """ Return the query-string parameters of uri as a dict mapping each
                name to its raw (not URL-decoded) value.  A URI without a query
                string gives {}; a parameter without '=' gets the value ''; a value
                containing '=' is kept whole; empty segments are skipped.  For a
                repeated name the last value wins. """
                params = {}
                # Split on the first '?' (not on '/'), so a '/' inside the query does
                # not break the parse.
                query = uri.partition('?')[2]
                for param in query.split('&'):
                        if not param:
                                continue
                        name, _, value = param.partition('=')
                        params[name] = value
                return params

        def addRequest(self,request):
                """ Record an atvRequest: count its URI type, remember a new
                User-Agent value, collect the search term of incrementalSearch
                ('q' parameter) and search ('term' parameter) requests, then file
                it as a recognized or unrecognized request. """
                self.reqTypeCount[request.uriType] = self.reqTypeCount.get(request.uriType, 0) + 1

                httpReq = request.request
                if httpReq.hasHeader('user-agent'):
                        agent = httpReq.getHeader('user-agent')
                        if agent not in self.userAgentsUsed:
                                self.userAgentsUsed.append(agent)

                searchParam = {'incrementalSearch': 'q', 'search': 'term'}.get(request.uriType)
                if searchParam:
                        params = self._getURIParams(httpReq.uri)
                        if searchParam in params:
                                self.searches.append(params[searchParam])

                # Use the flag computed when the request was built; calling
                # request.isValidRequest() again repeated the whole check.
                if request.validRequest:
                        self.addValidReq(request)
                else:
                        self.addInvalidReq(request)

        def addValidReq(self,request):
                """ Store a recognized request and track the highest packet number. """
                self.validRequests.append(request)
                if request.pktNum > self.maxValidPkt:
                        self.maxValidPkt = request.pktNum

        def addInvalidReq(self,request):
                """ Store an unrecognized request. """
                self.invalidRequests.append(request)

        def numValid(self):
                """ Return the number of recognized requests. """
                return len(self.validRequests)

        def numInvalid(self):
                """ Return the number of unrecognized requests. """
                return len(self.invalidRequests)

        def log (self,line):
                """ Append line (plus a newline) to this client's report text. """
                self.txtReport += line + '\n'

        def _logSection(self, title):
                """ Log a section heading: the title centred between two rules. """
                self.log('-' * 78)
                self.log(title.center(78))
                self.log('-' * 78)

        def writeReport(self,reportdir):
                """ Build this client's report (user agents, search terms, movies
                viewed, recognized and unrecognized requests, and the matching
                criteria) and write it to Report-<ip>.txt in reportdir, created if
                needed.  Each recognized request's own traffic report is written
                along the way so its filename and MD5 can be listed.  The MD5 of
                the client report is stored in self.reportMD5. """

                # Start from an empty buffer so calling this again does not
                # duplicate the report text.
                self.txtReport = ''

                self.log('=' * 78)
                self.log('Report for potential AppleTV client IP:  %s' % self.ip)
                self.log('Potential AppleTV client MAC:            %s' % self.MAC)
                self.log('Report created at:                       %s' %
                         datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"))
                self.log('=' * 78)
                self.log('This report lists the HTTP requests of the possible AppleTV client using the')
                self.log('IP address listed at the start of this report.')
                self.log('')
                self.log('Recognized requests are those that contain a known URI, user-agent, and')
                self.log('addtional HTTP header information.')
                self.log('')
                self.log('Unrecognized requests are those that contain one or more of the known')
                self.log('features, but do not meet all criteria')
                self.log('=' * 78)
                self.log('')

                self._logSection('Unique Values for user-agent Request Header:')
                for agent in self.userAgentsUsed:
                        self.log(agent)
                self.log('')

                self._logSection('Search Terms Sent by Client:')
                for term in self.searches:
                        self.log(term)
                self.log('')

                self._logSection('Movie Items Viewed by Client:')
                for req in self.validRequests:
                        if req.uriType == 'viewMovie' and req.response:
                                self.log(req.getMovieInfo())
                self.log('')
                self.log('Note:  The copyright value has had the leading two characters removed.')
                self.log('')

                self._logSection('Overview of Recognized Requests:')
                self.log('Pkt No.  Timestamp       Request')
                self.log('-' * 78)
                for req in self.validRequests:
                        self.log(req.infoShort())
                        self.log('')
                        req.writeReport(reportdir, self.maxValidPkt)
                        self.log('         Traffic report file: %s' % (req.reportFile))
                        self.log('         Traffic report MD5:  %s' % (req.reportMD5))
                        self.log('')
                        self.log('')

                self._logSection('Overview of Unrecognized Requests:')
                self.log('Pkt No.  Timestamp       Request')
                self.log('-' * 78)
                for req in self.invalidRequests:
                        self.log(req.infoShort())
                        self.log('')
                self.log('')

                self._logSection('Criteria Used to Identify AppleTV Requests:')
                self.log('Recognized URI patterns (URI must start with one of the following strings):')
                for uri in atvRequest.atvKnownURIs:
                        self.log('\t%s' % (uri))
                self.log('')
                self.log('Recognized values for HTTP request "Host:" header:')
                for host in atvRequest.atvKnownHosts:
                        self.log('\t%s' % (host))
                self.log('')
                self.log('Recognized patterns for HTTP request "User-Agent:" header (Python regular expression syntax):')
                for agent in atvRequest.atvKnownAgents:
                        self.log('\t%s' % (agent))
                self.log('')
                self.log('=' * 78)
                self.log('REPORT COMPLETE'.center(78))
                self.log('=' * 78)

                self.reportFile = os.path.join(reportdir, "Report-%s.txt" % (self.ip))
                if not os.path.exists(reportdir):
                        os.makedirs(reportdir)

                rptFile = open(self.reportFile, "w")
                try:
                        rptFile.write(self.txtReport)
                finally:
                        rptFile.close()

                self.reportMD5 = computeMD5(self.reportFile)

class httpMessage(object):
        """ Class defines basic components of an HTTP Message.  Used as a superclass for
        HTTP Request and HTTP Response objects.  Stores headers in a list instead of a
        dictionary to handle multiple instances of header with same name (allowed by HTTP
        spec).  List storage of headers also allows us to print headers in the same order
        and capitalization as presented in the network capture.

        Subclasses must set self.headers to a list before calling any method here;
        parseHeaders() additionally sets self.data to the message body. """

        # TODO - Provide separate getHeaderName and getHeaderValue functions
        #        Handle chunked transmissions

        def addHeader(self,key,value):
                """ Append a (name, value) pair to self.headers, preserving the order
                of headers in the packet as well as the capitalization of header
                names.  Multiple headers with the same name are allowed. """

                self.headers.append((key,value))


        def hasHeader(self,key):
                """ Return True if a header named key exists.  The comparison is
                case-insensitive, as HTTP header names are. """
                wanted = key.lower()
                return any(name.lower() == wanted for name, value in self.headers)


        def getHeader(self,key):
                """ Return the value of the header named key (case-insensitive) as a
                single string.

                If several headers share that name, their values are joined with ','
                in the order they appeared.  Raises IndexError when no such header
                exists, so callers normally check hasHeader() first. """

                wanted = key.lower()
                values = [value for name, value in self.headers if name.lower() == wanted]

                if not values:
                        raise IndexError("No '%s' header present" % key)

                # Spec states that HTTP messages can have multiple headers with the same
                # name, but only if the meaning of the message is the same when appending
                # each header value, separated by a comma.  Following the spec lets us
                # return a single string instead of a list.  (A lone value is returned
                # unchanged by join.)
                return ','.join(values)


        def parseHeaders(self,data):
                """ Split the data portion of an HTTP (TCP) packet into headers and body.

                The body (everything after the first blank line) is assigned to
                self.data, and each header is appended to self.headers.  Continuation
                lines (starting with a space or tab) are folded into the preceding
                header with a single space.  Blank lines before the start line are
                ignored.

                Returns the start line (the first non-empty line), which is either
                the Request line or the Status line.

                Raises ValueError if the blank line ending the headers is missing or a
                header line contains no ':'. """

                # HTTP message headers and body are separated by two sets
                # of carriage return and linefeed characters (first set for last header,
                # second set represents blank line separating headers and body).
                if '\r\n\r\n' not in data:
                        raise ValueError("Premature end of HTTP headers")

                # Split only at the FIRST blank line: the body (e.g. binary or
                # compressed content) may itself contain CRLFCRLF.
                hdrs, self.data = data.split('\r\n\r\n', 1)

                # First line of headers is actually the HTTP request or status line
                # depending on type of HTTP message
                startline = ''
                # True once this call has added a header that a continuation line
                # may extend.
                parsedAny = False

                # Headers end in CRLF; split the data on it to get individual headers.
                for hdr in hdrs.split('\r\n'):
                        # HTTP spec allows for empty lines before the start line of a Request
                        if not hdr:
                                continue
                        if not startline:
                                startline = hdr
                                continue

                        # Obsolete line folding: a line starting with whitespace continues
                        # the previous header's value.
                        if hdr[0] in ' \t' and parsedAny:
                                name, value = self.headers[-1]
                                self.headers[-1] = (name, (value + ' ' + hdr.strip()).strip())
                                continue

                        if ':' not in hdr:
                                raise ValueError("Invalid HTTP header found")

                        # HTTP header name and value separated by colon and possibly whitespace
                        # Leading and trailing whitespace is not significant in header value
                        name, value = hdr.split(':',1)
                        self.addHeader(name,value.strip())
                        parsedAny = True

                return startline


class httpRequest(httpMessage):
        """ Class to hold an HTTP request.  Currently, this is mainly used so
        we can retain Header order, capitalization, and possible duplicate
        header names.

        After construction the object has method, uri, version (the number
        only, e.g. '1.1'), headers, and data (the request body). """

        def __init__ (self,tcp):
                """ Parse tcp.data (the payload of one TCP segment) as an HTTP request.

                Errors raised by parseHeaders() propagate unchanged.  Raises
                ValueError if the request line does not have exactly three
                whitespace-separated fields (Method SP URI SP HTTP-Version). """

                self.headers = []

                startline = self.parseHeaders(tcp.data)

                # Request startline is Method SP URI SP HTTP-Version CRLF
                fields = startline.split()
                if len(fields) != 3:
                        # A real exception class: raising a plain string is a
                        # TypeError on Python 2.6+.
                        raise ValueError("Improperly formatted HTTP Request start line")

                self.method, self.uri, self.version = fields
                # Just store the version number, not the HTTP string
                # (raw string: '\/' is an invalid escape in a normal literal).
                self.version = re.sub(r'HTTP/', '', self.version)




class httpResponse(httpMessage):
        """ Class to hold an HTTP response.  The constructor searches through the
        pcap and builds up the response based on the value of the Content-Length
        header, the web server port, and the client request port.

        Attributes populated by getData():
                data      - raw response body (possibly compressed)
                response  - decompressed / viewable response body
                status    - HTTP status line of the response
                startPkt  - packet number of first packet in response
                startTime - timestamp of first packet in response

        NOTE(review): packets are matched on TCP ports only; webIP and clientIP
        are stored but not compared, so another host pair reusing the same port
        numbers could be picked up. """

        def __init__ (self,pcapname,searchFrom,webIP,webPort,clientIP,clientPort):
                self.pcapname = pcapname
                self.searchFrom = searchFrom
                self.webIP = webIP
                self.webPort = webPort
                self.clientIP = clientIP
                self.clientPort = clientPort
                self.headers = []
                self.getData()

        def _nextTcp(self, pc):
                """ Read the next packet from pc and return (timestamp, tcp).  tcp is
                the dpkt TCP segment, or None when the frame carries no TCP segment
                (e.g. ARP, UDP) so the caller can simply skip it. """
                ts, buf = pc.next()
                network = dpkt.ethernet.Ethernet(buf).data
                # Non-IP payloads are plain byte strings with no .data attribute.
                segment = getattr(network, 'data', None)
                if isinstance(segment, dpkt.tcp.TCP):
                        return ts, segment
                return ts, None

        def _fromServer(self, tcp):
                """ Return True if tcp was sent from the web server port to the
                client port of this conversation. """
                return (tcp is not None and tcp.sport == self.webPort
                        and tcp.dport == self.clientPort)

        def getData(self):
                """ Open the pcap and pull in all the data from the HTTP response.

                The first searchFrom packets are skipped (the last of them is assumed
                to hold the HTTP request).  The first later server->client TCP segment
                containing 'HTTP/' starts the response.  If a Content-Length header is
                present, later server->client segments are appended until that many
                body bytes are collected.  Finally self.response is set to the
                decoded body. """

                # Use the pcap module so we can step through the pcap by packet
                pc = pcap.pcap(self.pcapname)
                pktnum = 0

                # Skip through pcap until we've reached the packet to begin our search
                while pktnum < self.searchFrom:
                        ts, buf = pc.next()
                        pktnum += 1

                # Search for the response start, based on srcport, dstport, and tcp
                # data containing the string HTTP (found in HTTP status lines)
                while True:
                        ts, tcp = self._nextTcp(pc)
                        pktnum += 1
                        if self._fromServer(tcp) and 'HTTP/' in tcp.data:
                                self.startPkt = pktnum
                                self.startTime = ts
                                break

                self.status = self.parseHeaders(tcp.data)

                # If the response has a content-length header, keep reading
                # packets until the full body has been collected
                if self.hasHeader('content-length'):
                        length = int(self.getHeader('content-length'))
                        while len(self.data) < length:
                                ts, tcp = self._nextTcp(pc)
                                if self._fromServer(tcp):
                                        self.data += tcp.data

                self.response = self._decodeBody()

        def _decodeBody(self):
                """ Return self.data with any Content-Encoding removed.  Supports
                gzip / x-gzip, deflate (zlib-wrapped or raw) and identity.  An
                unknown encoding prints a warning and the raw body is returned, so
                self.response is always set. """
                if not self.hasHeader('content-encoding'):
                        return self.data

                # Content-coding tokens are case-insensitive.
                encoding = self.getHeader('content-encoding').strip().lower()

                if encoding in ('gzip','x-gzip'):
                        # http://diveintopython.org/http_web_services/gzip_compression.html
                        return gzip.GzipFile(fileobj=StringIO.StringIO(self.data)).read()

                if encoding == 'deflate':
                        # zlib is not among the file-level imports; import it here.
                        import zlib
                        try:
                                return zlib.decompress(self.data)
                        except zlib.error:
                                # Some servers send raw deflate data without the
                                # zlib header; a negative wbits accepts that form.
                                return zlib.decompress(self.data, -zlib.MAX_WBITS)

                if encoding == 'identity':
                        return self.data

                print("Unknown encoding.")
                return self.data



def buildOptions(copyright):
        """ Build the command line parser for this script.

        copyright - text displayed by --version.

        Returns an OptionParser defining -p/--pcap (pcapfile), -r/--report
        (reportdir, default ./report) and -f/--force (force, default False). """
        parser = OptionParser(usage="%prog [options] -p pcapfile", version=copyright)

        parser.add_option("-p", "--pcap", dest="pcapfile", help="[Required] Filename of the pcap to process")
        parser.add_option("-r", '--report', dest="reportdir", default="./report", help="[Default: ./report] Directory for reporting and processed output.  Created if needed.")
        parser.add_option("-f", '--force', dest="force", default=False, action="store_true", help="[Default: False] Force overwriting of files and directories")

        return parser

def handleOptions(options, optParser=None):
        """ Function that performs various checks of the command line options
        to verify everything needed is provided.

        options   - the options object returned by OptionParser.parse_args().
        optParser - parser whose error() method reports problems (optparse's
                    error() prints usage and exits with status 2).  When omitted,
                    the module-level 'parser' created in the __main__ section is
                    used, as before.

        With --force, an existing report directory is emptied. """

        def fail(message):
                # Resolve the fallback lazily so that callers which never hit an
                # error do not need a module-level 'parser' to exist.
                if optParser is not None:
                        optParser.error(message)
                else:
                        parser.error(message)

        if not options.pcapfile:
                fail("-p|--pcap option must be specified please see --help for more details")

        if not os.path.isfile(options.pcapfile):
                fail("Could not read pcap file %s "%(options.pcapfile))

        if options.reportdir:
                if os.path.isfile(options.reportdir):
                        fail("-r|--report is a file and cannot be used for report output")
                elif os.path.isdir(options.reportdir):
                        if options.force:
                                print("[!] Emptying report directory %s" % options.reportdir)
                                emptyDirectory(options.reportdir)
                        elif os.listdir(options.reportdir):
                                fail("-r|--report path is not empty! Use -f|--force to remove contents (not advised)")


def emptyDirectory(path):
        """ Function will remove all files, symbolic links and subdirectories in the
        specified path, leaving path itself in place (now empty).
        This can be dangerous as the function performs no sanity checks.

        Symbolic links are removed rather than followed, so the contents of
        directories they point to are left untouched. """

        # Walk bottom-up so every subdirectory is already empty when rmdir runs.
        for root, dirs, files in os.walk(path, topdown=False):
                for name in files:
                        os.remove(os.path.join(root, name))
                for name in dirs:
                        entry = os.path.join(root, name)
                        # os.walk lists symlinks to directories in dirs without
                        # descending into them; os.rmdir cannot remove a link.
                        if os.path.islink(entry):
                                os.remove(entry)
                        else:
                                os.rmdir(entry)



if __name__ == '__main__':
        version = "\n%s  1.1" % (os.path.basename(sys.argv[0]))

        copyright = version + """

Copyright 2009 Matt Sabourin.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
"""
        print
        print version
        print

        parser = buildOptions(copyright)

        (options, args) = parser.parse_args(sys.argv)


        handleOptions(options)

        print "[+] Storing reports in %s" % options.reportdir

        pcapSrch = pcapSearcher(options.pcapfile,options.reportdir)

        print "[+] Using network capture file:\t%s" % (pcapSrch.filename)
        print "[+] MD5 hash of capture file:\t%s" % (pcapSrch.md5)
        print "[+] Starting to parse network capture..."

        pcapSrch.parse()

        print "[+] Overview report saved at: %s" % (pcapSrch.reportFile)
        print "[+] Overview report file MD5: %s" % (pcapSrch.reportMD5)
        print "[+] Overview report listed below."
        print
  
        print pcapSrch.txtReport

