#!/usr/bin/python
import re, os, sys, pwd, grp, optparse, errno

_totalmem = 0
def totalmem():
    '''Return the amount of physical memory in bytes, computed once.

    Uses --realmem when given (parsed with fromunits()), otherwise the
    MemTotal entry reported by memory().  The result is cached in the
    module-level _totalmem; a falsy cached value is recomputed.
    '''
    global _totalmem
    if _totalmem:
        return _totalmem
    if options.realmem:
        _totalmem = fromunits(options.realmem)
    else:
        _totalmem = memory()['memtotal']
    return _totalmem

def pids():
    '''Return the PIDs of all user-space processes listed in /proc.

    Kernel threads (as judged by iskernel()) are left out.  A process that
    exits between the directory listing and the iskernel() check is
    skipped instead of aborting the whole report with an IOError.
    '''
    result = []
    for entry in os.listdir("/proc"):
        if not entry.isdigit():
            continue
        try:
            if iskernel(entry):
                continue
        except (IOError, OSError):
            # /proc/<pid>/cmdline vanished (process exited) or is unreadable
            continue
        result.append(int(entry))
    return result

def pidmaps(pid):
    '''Parse /proc/<pid>/smaps into a dict keyed by mapping start address.

    Each value holds the header fields of one mapping -- end (int), mode,
    offset (int), device, inode and name ("<anonymous>" when no pathname is
    given) -- plus every following "Field: N kB" line, stored under the
    lower-cased field name with the integer value in kB.

    Field lines whose value is not in kB (for example "VmFlags: rd wr ...")
    are ignored rather than being mistaken for a mapping header.

    Raises IOError/OSError if the smaps file cannot be opened or read.
    '''
    maps = {}
    start = None
    smaps = open('/proc/%s/smaps' % pid)
    try:
        for line in smaps:
            f = line.split()
            if not f:
                continue
            if f[0].endswith(':'):
                # "Field: value [unit]" line describing the current mapping.
                if f[-1] == 'kB' and start is not None:
                    maps[start][f[0][:-1].lower()] = int(f[1])
                continue
            # Header line: "start-end mode offset device inode [name]".
            start, end = f[0].split('-')
            start = int(start, 16)
            name = "<anonymous>"
            if len(f) > 5:
                # NOTE: only the first whitespace-separated token of the
                # pathname is kept.
                name = f[5]
            maps[start] = dict(end=int(end, 16), mode=f[1],
                               offset=int(f[2], 16),
                               device=f[3], inode=f[4], name=name)
    finally:
        smaps.close()
    return maps

def processtotals(pids):
    '''Return {pid: pidtotals(pid)} for every pid that can be summarised.

    Processes for which pidtotals() fails (e.g. the process exited after
    being listed, or its smaps could not be read) are left out of the result
    instead of aborting the report.
    '''
    totals = {}
    for pid in pids:
        try:
            totals[pid] = pidtotals(pid)
        except Exception:
            # Best effort per process; unlike a bare "except:" this still
            # lets KeyboardInterrupt and SystemExit propagate.
            continue
    return totals

def sortmaps(totals, key):
    '''Return the pids of *totals* ordered by totals[pid][key], ascending.

    Ties are broken by the pid itself.
    '''
    return sorted(totals, key=lambda pid: (totals[pid][key], pid))

def iskernel(pid):
    '''Return True when *pid* has an empty command line.

    An empty cmdline is taken to mean a kernel thread.
    NOTE(review): zombie processes may also report an empty cmdline.
    '''
    return not pidcmd(pid)

def pidname(pid):
    '''Return the command name of process *pid* from /proc/<pid>/stat.

    The name is the text between the first '(' and the LAST ')'.  The
    name itself may contain ')', so stopping at the first ')' would
    truncate it.
    '''
    f = open('/proc/%d/stat' % pid)
    try:
        l = f.read()
    finally:
        f.close()
    return l[l.find('(') + 1:l.rfind(')')]

def pidcmd(pid):
    '''Return the command line of *pid* with its arguments joined by spaces.

    /proc/<pid>/cmdline separates arguments with NUL characters and
    normally ends with one.  That terminator is dropped, and the remaining
    NULs become spaces.  An empty cmdline yields "".
    '''
    f = open('/proc/%s/cmdline' % pid)
    try:
        c = f.read()
    finally:
        f.close()
    # Only strip a terminator that is really there; if the cmdline does not
    # end in NUL, chopping the last character would cut off real text.
    if c.endswith('\0'):
        c = c[:-1]
    return c.replace('\0', ' ')

def piduser(pid):
    '''Return the numeric uid that owns /proc/<pid>/cmdline for process *pid*.'''
    info = os.stat('/proc/%d/cmdline' % pid)
    return info.st_uid

def pidgroup(pid):
    '''Return the numeric gid that owns /proc/<pid>/cmdline for process *pid*.'''
    info = os.stat('/proc/%d/cmdline' % pid)
    return info.st_gid

def memory():
    '''Parse /proc/meminfo into {lower-cased field name: size in bytes}.

    Only lines of the form "Name:   N kB" are kept; other lines are ignored.
    Each kB value is converted to bytes.
    '''
    t = {}
    pattern = re.compile(r'(\S+):\s+(\d+) kB')
    f = open('/proc/meminfo')
    try:
        for line in f:
            m = pattern.match(line)
            if m:
                t[m.group(1).lower()] = int(m.group(2)) * 1024
    finally:
        # close deterministically instead of relying on garbage collection
        f.close()
    return t

def units(x):
    '''Format a byte count *x* as a short human-readable string.

    The value is divided by 1024 until it is below 1024 or the largest
    suffix is reached, then printed with one decimal place and the suffix:
    500 -> "500.0", 1536 -> "1.5K", 2**40 -> "1.0T".  Zero gives "0".
    '''
    if x == 0:
        return '0'
    suffixes = ('', 'K', 'M', 'G', 'T', 'P', 'E')
    i = 0
    # Every division advances the suffix, so value and suffix always agree
    # even beyond gigabytes.
    while x >= 1024 and i < len(suffixes) - 1:
        x /= 1024.0
        i += 1
    return "%.1f%s" % (x, suffixes[i])

def fromunits(x):
    '''Convert a size string such as "16G", "512MB" or "4096kB" to bytes.

    Suffixes k/K/kB/KB, M/MB, G/GB and T/TB are powers of 1024.  A bare
    number is taken as bytes.  Fractions are accepted ("1.5G") and the
    result is truncated to an int.

    Raises ValueError if the part before the suffix is not a number.
    '''
    s = dict(k=2**10, K=2**10, kB=2**10, KB=2**10,
             M=2**20, MB=2**20, G=2**30, GB=2**30,
             T=2**40, TB=2**40)
    x = x.strip()
    # Check longer suffixes first so a multi-letter suffix always wins.
    for k in sorted(s, key=len, reverse=True):
        if x.endswith(k):
            # Strip the suffix: take everything *before* it, not len(k) chars.
            return int(float(x[:-len(k)]) * s[k])
    return int(float(x))

_ucache = {}
def username(uid):
    '''Return the login name for *uid*, memoised in _ucache.

    If the password database has no entry for *uid* (pwd.getpwuid raises
    KeyError), the decimal uid is returned instead.  This keeps one
    unnamed owner from aborting the whole report.
    '''
    if uid not in _ucache:
        try:
            _ucache[uid] = pwd.getpwuid(uid)[0]
        except KeyError:
            _ucache[uid] = str(uid)
    return _ucache[uid]

_gcache = {}
def groupname(gid):
    '''Return the group name for *gid*, memoised in _gcache.

    If the group database has no entry for *gid* (grp.getgrgid raises
    KeyError), the decimal gid is returned instead of propagating the
    error.
    '''
    if gid not in _gcache:
        try:
            _gcache[gid] = grp.getgrgid(gid)[0]
        except KeyError:
            _gcache[gid] = str(gid)
    return _gcache[gid]

def showamount(a):
    '''Render a size of *a* kB for display.

    With -k the value is passed to units() as bytes.  With -p it is shown
    as a percentage of totalmem(); -k wins if both are set.  Otherwise
    *a* is returned unchanged.
    '''
    if options.abbreviate:
        return units(a * 1024)
    if not options.percent:
        return a
    # a is in kB and totalmem() in bytes: (a * 1024) * 100 / total.
    return "%.2f%%" % (102400.0 * a / totalmem())

def pidtotals(pid):
    '''Summarise the smaps of process *pid* into one dict (values in kB).

    Contains the smaps fields size, rss, pss, shared_clean, shared_dirty,
    private_clean, private_dirty, referenced and swap summed over all
    mappings, plus:
      uss                     private_clean + private_dirty
      avss/arss/apss/auss     size/rss/pss/uss over mappings named
                              <anonymous>, [heap] or [stack] only
      maps                    number of mappings
    '''
    fields = ('size', 'rss', 'pss', 'shared_clean', 'shared_dirty',
              'private_clean', 'private_dirty', 'referenced', 'swap')
    anon_names = ('<anonymous>', '[heap]', '[stack]')

    regions = pidmaps(pid)
    totals = dict.fromkeys(fields, 0)
    totals.update(arss=0, apss=0, auss=0, avss=0)

    for region in regions.values():
        for field in fields:
            totals[field] += region.get(field, 0)
        if region['name'] in anon_names:
            totals['auss'] += (region.get('private_clean', 0)
                               + region.get('private_dirty', 0))
            totals['arss'] += region.get('rss', 0)
            totals['apss'] += region.get('pss', 0)
            totals['avss'] += region.get('size', 0)

    totals['uss'] = totals['private_clean'] + totals['private_dirty']
    totals['maps'] = len(regions)
    return totals

def showpids():
    '''Print the default per-process report.

    Rows are the processes that processtotals() could summarise.  Columns
    default to "pid user command swap uss pss rss" (override with -c), and
    rows are ordered by pss unless -s names another column.
    '''
    pt = processtotals(pids())
    # Iterate only the pids that have totals: a pid dropped by
    # processtotals() (exited or unreadable) would otherwise raise KeyError
    # in the pt[n] lookups below.
    rows = list(pt.keys())

    def showuser(p):
        if options.numeric:
            return piduser(p)
        return username(piduser(p))

    fields = dict(
        pid=('PID', lambda n: n, '% 5s', len),
        user=('User', showuser, '%-8s', lambda x: len(dict.fromkeys(x))),
        name=('Name', pidname, '%-24.24s', None),
        command=('Command', pidcmd, '%-27.27s', None),
        maps=('Maps', lambda n: pt[n]['maps'], '% 5s', sum),
        swap=('Swap', lambda n: pt[n]['swap'], '% 8a', sum),
        uss=('USS', lambda n: pt[n]['uss'], '% 8a', sum),
        rss=('RSS', lambda n: pt[n]['rss'], '% 8a', sum),
        pss=('PSS', lambda n: pt[n]['pss'], '% 8a', sum),
        vss=('VSS', lambda n: pt[n]['size'], '% 8a', sum),
        auss=('AUSS', lambda n: pt[n]['auss'], '% 8a', sum),
        arss=('ARSS', lambda n: pt[n]['arss'], '% 8a', sum),
        apss=('APSS', lambda n: pt[n]['apss'], '% 8a', sum),
        avss=('AVSS', lambda n: pt[n]['avss'], '% 8a', sum),
        )
    columns = options.columns or 'pid user command swap uss pss rss'

    showtable(rows, fields, columns.split(), options.sort or 'pss')

def maptotals(pids):
    '''Sum smaps fields per mapping name across the given processes.

    Returns {name: dict}.  Each dict holds the summed smaps fields (kB)
    plus 'count', the number of mappings with that name.  Processes whose
    smaps cannot be read (e.g. they exited after being listed) are skipped,
    as processtotals() does, instead of aborting the report.
    '''
    totals = {}
    for pid in pids:
        try:
            maps = pidmaps(pid)
        except (IOError, OSError):
            continue
        for m in maps.values():
            name = m['name']
            t = totals.get(name)
            if t is None:
                t = totals[name] = dict(size=0, rss=0, pss=0, shared_clean=0,
                                        shared_dirty=0, private_clean=0, count=0,
                                        private_dirty=0, referenced=0, swap=0)
            for k in t:
                t[k] += m.get(k, 0)
            t['count'] += 1
    return totals

def showmaps():
    '''Print the per-mapping report (-m).

    There is one row per mapping name, summed over all listed processes.
    Columns default to "map count swap uss pss rss", sorted by pss.
    '''
    totals = maptotals(pids())

    def column(key):
        return lambda n: totals[n][key]

    def uss(n):
        return totals[n]['private_clean'] + totals[n]['private_dirty']

    fields = {
        'map': ('Map', lambda n: n, '%-24.24s', len),
        'count': ('Count', column('count'), '% 5s', sum),
        'swap': ('Swap', column('swap'), '% 8a', sum),
        'uss': ('USS', uss, '% 8a', sum),
        'rss': ('RSS', column('rss'), '% 8a', sum),
        'pss': ('PSS', column('pss'), '% 8a', sum),
    }
    wanted = options.columns or 'map count swap uss pss rss'

    showtable(totals.keys(), fields, wanted.split(), options.sort or 'pss')

def usertotals(pids):
    '''Sum smaps fields of the given processes, grouped by owning uid.

    Returns {uid: dict}.  Each dict holds the summed smaps fields (kB)
    plus 'count', the number of processes counted for that uid.  Processes
    whose smaps or owner cannot be read (e.g. they exited after being
    listed) are skipped, as processtotals() does, instead of aborting the
    report.
    '''
    totals = {}
    for pid in pids:
        try:
            maps = pidmaps(pid)
            user = piduser(pid)
        except (IOError, OSError):
            continue
        t = totals.get(user)
        if t is None:
            t = totals[user] = dict(size=0, rss=0, pss=0, shared_clean=0,
                                    shared_dirty=0, private_clean=0, count=0,
                                    private_dirty=0, referenced=0, swap=0)
        for m in maps.values():
            for k in t:
                t[k] += m.get(k, 0)
        t['count'] += 1
    return totals

def showusers():
    '''Print the per-user report (-u).

    There is one row per owning uid, shown as a name unless -n is given.
    Columns default to "user count swap uss pss rss", sorted by pss.
    '''
    totals = usertotals(pids())

    def showuser(uid):
        if not options.numeric:
            return username(uid)
        return uid

    def column(key):
        return lambda n: totals[n][key]

    def uss(n):
        return totals[n]['private_clean'] + totals[n]['private_dirty']

    fields = {
        'user': ('User', showuser, '%-8s', None),
        'count': ('Count', column('count'), '% 5s', sum),
        'swap': ('Swap', column('swap'), '% 8a', sum),
        'uss': ('USS', uss, '% 8a', sum),
        'rss': ('RSS', column('rss'), '% 8a', sum),
        'pss': ('PSS', column('pss'), '% 8a', sum),
    }
    wanted = options.columns or 'user count swap uss pss rss'

    showtable(totals.keys(), fields, wanted.split(), options.sort or 'pss')

def showtable(rows, fields, columns, sort):
    '''Print *rows* as a whitespace-aligned table on stdout.

    rows    -- row keys (pids, mapping names or uids).  They are iterated
               a second time when totals are requested.
    fields  -- maps a column name to (title, getter, printf format, total
               function or None).  An 'a' in the format marks a size column:
               its values go through showamount() and print with 's'.
    columns -- names of the columns to print, in order.
    sort    -- column name whose getter value orders the rows (ascending,
               or reversed with -r).
    Honours -H (omit header) and -t (append a totals line).
    '''
    identity = lambda v: v
    specs = []
    formatters = []
    for name in columns:
        spec = fields[name][2]
        if 'a' in spec:
            formatters.append(showamount)
            spec = spec.replace('a', 's')
        else:
            formatters.append(identity)
        specs.append(spec)

    rowformat = "".join([spec + " " for spec in specs])
    header = "".join([spec % fields[name][0] + " "
                      for spec, name in zip(specs, columns)])

    if not options.no_header:
        print(header)

    def render(values):
        return rowformat % tuple([fmt(v) for fmt, v in zip(formatters, values)])

    keyed = []
    for row in rows:
        values = [fields[name][1](row) for name in columns]
        keyed.append((fields[sort][1](row), values))
    keyed.sort(reverse=bool(options.reverse))

    for _, values in keyed:
        print(render(values))

    if options.totals:
        summary = []
        for name in columns:
            total = fields[name][3]
            if total:
                summary.append(total([fields[name][1](row) for row in rows]))
            else:
                summary.append("")
        print("-" * len(header))
        print(render(summary))

# Command-line interface.  Options are registered in the order they appear
# in --help.  optparse derives each destination from the long option name
# (e.g. --no-header -> options.no_header).  "flag" entries are store_true
# switches; other kinds are used as the optparse value type.
parser = optparse.OptionParser("%prog [options]")
for _short, _long, _kind, _help in (
        ("-H", "--no-header", "flag", "disable header line"),
        ("-n", "--numeric", "flag", "numeric output"),
        ("-s", "--sort", "str", "field to sort on"),
        ("-t", "--totals", "flag", "show totals"),
        ("-c", "--columns", "str", "columns to show"),
        ("", "--realmem", "str", "amount of physical RAM"),
        ("-m", "--mappings", "flag", "show mappings"),
        ("-u", "--users", "flag", "show users"),
        ("-r", "--reverse", "flag", "reverse sort"),
        ("-p", "--percent", "flag", "show percentage"),
        ("-k", "--abbreviate", "flag", "show unit suffixes"),
        ):
    if _kind == "flag":
        parser.add_option(_short, _long, action="store_true", help=_help)
    else:
        parser.add_option(_short, _long, type=_kind, help=_help)


defaults = {}
parser.set_defaults(**defaults)
(options, args) = parser.parse_args()

# Dispatch to the selected report: -m mappings, -u users, else processes.
try:
    if options.mappings:
        showmaps()
    elif options.users:
        showusers()
    else:
        showpids()
except IOError:
    # A closed output pipe (EPIPE, e.g. when piped into `head`) ends the
    # report quietly; any other I/O error is a real failure and is
    # re-raised instead of being silently discarded.
    if sys.exc_info()[1].errno != errno.EPIPE:
        raise
