overload & and | operators to chain filters (refs #1426)

This commit is contained in:
Romain Bignon 2014-07-05 19:26:29 +02:00
commit 8efd37e71d
5 changed files with 66 additions and 47 deletions

View file

@ -66,6 +66,14 @@ class _Filter(object):
self._creation_counter = _Filter._creation_counter
_Filter._creation_counter += 1
def __or__(self, o):
    """
    Implement ``filter | value``: use the right operand as default value.

    :param o: value stored in ``self.default``
    :returns: this same filter object (it is modified in place)
    """
    fallback = o
    self.default = fallback
    return self
def __and__(self, o):
    """
    Implement ``filter & other``: chain this filter into ``other``.

    :param o: object whose ``selector`` attribute is set to this filter
    :returns: ``o`` itself (it is modified in place)
    """
    chained = o
    chained.selector = self
    return chained
def default_or_raise(self, exception):
if self.default is not _NO_DEFAULT:
return self.default
@ -110,6 +118,41 @@ class Filter(_Filter):
raise NotImplementedError()
class _Selector(Filter):
    """
    Filter that hands back the selected value unchanged.

    A missing value (``None``) is delegated to ``default_or_raise``.
    """

    def filter(self, txt):
        """
        Return ``txt`` when something was selected.

        :param txt: the selected value, or None when nothing was found
        :returns: ``txt``, or whatever ``default_or_raise`` returns when
                  ``txt`` is None
        """
        if txt is not None:
            return txt

        # Wrap the selector in a one-element tuple: with a bare operand, a
        # tuple selector would be unpacked as several format arguments and
        # '%' would raise TypeError instead of building the message.
        not_found = ParseError('Element %r not found' % (self.selector,))
        return self.default_or_raise(not_found)
class Dict(_Selector):
    """Selector that follows a '/'-separated key path through nested content."""

    @classmethod
    def select(cls, selector, item):
        """
        Resolve the key path ``selector`` against ``item``.

        :param selector: string of keys separated by '/'
        :param item: a dict, or an object whose ``el`` attribute is used instead
        :returns: the value reached at the end of the path, or None as soon
                  as one key is absent
        """
        node = item if isinstance(item, dict) else item.el

        for key in selector.split('/'):
            if key not in node:
                return None
            node = node.get(key)

        return node
class CSS(_Selector):
    """Selector that delegates to the ``cssselect`` method of the item."""

    @classmethod
    def select(cls, selector, item):
        """
        Run ``selector`` through ``item.cssselect``.

        :param selector: argument forwarded to ``item.cssselect``
        :param item: object exposing a ``cssselect`` method
        :returns: whatever ``item.cssselect(selector)`` returns
        """
        matches = item.cssselect(selector)
        return matches
class XPath(_Selector):
    """Selector subclass that adds nothing of its own to ``_Selector``."""
class Base(Filter):
"""
Change the base element used in filters.
@ -119,7 +162,7 @@ class Base(Filter):
base = self.select(self.base, item)
return self.selector(base)
def __init__(self, base, selector, default=_NO_DEFAULT):
def __init__(self, base, selector=None, default=_NO_DEFAULT):
super(Base, self).__init__(selector, default)
self.base = base
@ -173,34 +216,6 @@ class TableCell(_Filter):
return self.default_or_raise(ColumnNotFound('Unable to find column %s' % ' or '.join(self.names)))
class Dict(Filter):
    """Filter reading a value from nested content through a key path."""

    @classmethod
    def select(cls, selector, item):
        """
        Resolve ``selector`` against ``item``.

        :param selector: a '/'-separated key path (string), a callable that
                         receives ``item``, or any other value returned as is
        :param item: a dict, or an object whose ``el`` attribute is used instead
        :returns: the value found, or None as soon as one key of a string
                  path is absent
        """
        # Non-string selectors are handled first so the path walk is flat.
        if not isinstance(selector, basestring):
            return selector(item) if callable(selector) else selector

        node = item if isinstance(item, dict) else item.el
        for key in selector.split('/'):
            if key not in node:
                return None
            node = node.get(key)
        return node

    def filter(self, txt):
        """
        Return ``txt`` unchanged, delegating to ``default_or_raise`` when
        it is None.
        """
        if txt is None:
            return self.default_or_raise(ParseError())
        return txt
class CleanHTML(Filter):
def filter(self, txt):
if isinstance(txt, (tuple, list)):
@ -234,7 +249,7 @@ class CleanText(Filter):
Second, it replaces all symbols given in second argument.
"""
def __init__(self, selector, symbols='', replace=[], childs=True, **kwargs):
def __init__(self, selector=None, symbols='', replace=[], childs=True, **kwargs):
super(CleanText, self).__init__(selector, **kwargs)
self.symbols = symbols
self.toreplace = replace
@ -283,7 +298,7 @@ class CleanDecimal(CleanText):
Get a cleaned Decimal value from an element.
"""
def __init__(self, selector, replace_dots=True, default=_NO_DEFAULT):
def __init__(self, selector=None, replace_dots=True, default=_NO_DEFAULT):
super(CleanDecimal, self).__init__(selector, default=default)
self.replace_dots = replace_dots
@ -318,7 +333,7 @@ class Link(Attr):
If the <a> tag is not found, an exception IndexError is raised.
"""
def __init__(self, selector, default=_NO_DEFAULT):
def __init__(self, selector=None, default=_NO_DEFAULT):
super(Link, self).__init__(selector, 'href', default=default)
@ -345,8 +360,9 @@ class Regexp(Filter):
u'1988-08-13'
"""
def __init__(self, selector, pattern, template=None, flags=0, default=_NO_DEFAULT):
def __init__(self, selector=None, pattern=None, template=None, flags=0, default=_NO_DEFAULT):
super(Regexp, self).__init__(selector, default=default)
assert pattern is not None
self.pattern = pattern
self.regex = re.compile(pattern, flags)
self.template = template
@ -379,7 +395,7 @@ class Map(Filter):
class DateTime(Filter):
def __init__(self, selector, default=_NO_DEFAULT, dayfirst=False, translations=None):
def __init__(self, selector=None, default=_NO_DEFAULT, dayfirst=False, translations=None):
super(DateTime, self).__init__(selector, default=default)
self.dayfirst = dayfirst
self.translations = translations
@ -397,7 +413,7 @@ class DateTime(Filter):
class Date(DateTime):
def __init__(self, selector, default=_NO_DEFAULT, dayfirst=False, translations=None):
def __init__(self, selector=None, default=_NO_DEFAULT, dayfirst=False, translations=None):
super(Date, self).__init__(selector, default=default, dayfirst=dayfirst, translations=translations)
def filter(self, txt):
@ -435,7 +451,7 @@ class Time(Filter):
regexp = re.compile(r'(?P<hh>\d+):?(?P<mm>\d+)(:(?P<ss>\d+))?')
kwargs = {'hour': 'hh', 'minute': 'mm', 'second': 'ss'}
def __init__(self, selector, default=_NO_DEFAULT):
def __init__(self, selector=None, default=_NO_DEFAULT):
super(Time, self).__init__(selector, default=default)
def filter(self, txt):
@ -486,7 +502,7 @@ class Format(MultiFilter):
class Join(Filter):
def __init__(self, pattern, selector, textCleaner=CleanText):
def __init__(self, pattern, selector=None, textCleaner=CleanText):
super(Join, self).__init__(selector)
self.pattern = pattern
self.textCleaner = textCleaner

View file

@ -642,6 +642,9 @@ class AbstractElement(object):
def parse(self, obj):
    """Do nothing with ``obj``; this implementation has an empty body."""
    return None
def cssselect(self, *args, **kwargs):
    """
    Forward a ``cssselect`` call to the wrapped element ``self.el``.

    All positional and keyword arguments are passed through untouched and
    the result of ``self.el.cssselect`` is returned as is.
    """
    element = self.el
    return element.cssselect(*args, **kwargs)
def xpath(self, *args, **kwargs):
    """
    Forward an ``xpath`` call to the wrapped element ``self.el``.

    All positional and keyword arguments are passed through untouched and
    the result of ``self.el.xpath`` is returned as is.
    """
    element = self.el
    return element.xpath(*args, **kwargs)