overload & and | operators to chain filters (refs #1426)

This commit is contained in:
Romain Bignon 2014-07-05 19:26:29 +02:00
commit 8efd37e71d
5 changed files with 66 additions and 47 deletions

View file

@ -72,12 +72,12 @@ class PastePage(BasePastebinPage):
self.env['header'] = el.find('//div[@id="content_left"]//div[@class="paste_box_info"]') self.env['header'] = el.find('//div[@id="content_left"]//div[@class="paste_box_info"]')
obj_id = Env('id') obj_id = Env('id')
obj_title = Base(Env('header'), CleanText('.//div[@class="paste_box_line1"]//h1')) obj_title = Base(Env('header')) & CleanText('.//div[@class="paste_box_line1"]//h1')
obj_contents = RawText('//textarea[@id="paste_code"]') obj_contents = RawText('//textarea[@id="paste_code"]')
obj_public = Base( obj_public = Base(Env('header')) \
Env('header'), & Attr('.//div[@class="paste_box_line1"]//img', 'title') \
CleanVisibility(Attr('.//div[@class="paste_box_line1"]//img', 'title'))) & CleanVisibility()
obj__date = Base(Env('header'), DateTime(Attr('.//div[@class="paste_box_line2"]/span[1]', 'title'))) obj__date = Base(Env('header')) & Attr('.//div[@class="paste_box_line2"]/span[1]', 'title') & DateTime()
class PostPage(BasePastebinPage): class PostPage(BasePastebinPage):

View file

@ -20,7 +20,7 @@
from weboob.tools.browser2 import HTMLPage from weboob.tools.browser2 import HTMLPage
from weboob.tools.browser2.page import ListElement, method, ItemElement, pagination from weboob.tools.browser2.page import ListElement, method, ItemElement, pagination
from weboob.tools.browser2.filters import Link, CleanText, Duration, Regexp from weboob.tools.browser2.filters import Link, CleanText, Duration, Regexp, CSS
from weboob.capabilities.base import NotAvailable from weboob.capabilities.base import NotAvailable
from weboob.capabilities.image import BaseImage from weboob.capabilities.image import BaseImage
from weboob.capabilities.video import BaseVideo from weboob.capabilities.video import BaseVideo
@ -40,9 +40,9 @@ class IndexPage(HTMLPage):
class item(ItemElement): class item(ItemElement):
klass = BaseVideo klass = BaseVideo
obj_id = Regexp(Link('.//a'), r'/videos/(.+)\.html') obj_id = CSS('a') & Link() & Regexp(pattern=r'/videos/(.+)\.html')
obj_title = CleanText('.//span[@id="title1"]') obj_title = CSS('span#title1') & CleanText()
obj_duration = Duration(CleanText('.//span[@class="thumbtime"]//span'), default=NotAvailable) obj_duration = CSS('span.thumbtime span') & CleanText() & Duration() | NotAvailable
obj_nsfw = True obj_nsfw = True
def obj_thumbnail(self): def obj_thumbnail(self):

View file

@ -39,7 +39,7 @@ class VideoPage(HTMLPage):
obj_title = CleanText('//title') obj_title = CleanText('//title')
obj_nsfw = True obj_nsfw = True
obj_ext = u'flv' obj_ext = u'flv'
obj_duration = Duration(CleanText('//div[@id="video_text"]')) obj_duration = CleanText('//div[@id="video_text"]') & Duration()
def obj_url(self): def obj_url(self):
real_id = int(self.env['id'].split('-')[-1]) real_id = int(self.env['id'].split('-')[-1])

View file

@ -66,6 +66,14 @@ class _Filter(object):
self._creation_counter = _Filter._creation_counter self._creation_counter = _Filter._creation_counter
_Filter._creation_counter += 1 _Filter._creation_counter += 1
def __or__(self, o):
self.default = o
return self
def __and__(self, o):
o.selector = self
return o
def default_or_raise(self, exception): def default_or_raise(self, exception):
if self.default is not _NO_DEFAULT: if self.default is not _NO_DEFAULT:
return self.default return self.default
@ -110,6 +118,41 @@ class Filter(_Filter):
raise NotImplementedError() raise NotImplementedError()
class _Selector(Filter):
def filter(self, txt):
if txt is not None:
return txt
else:
return self.default_or_raise(ParseError('Element %r not found' % self.selector))
class Dict(_Selector):
@classmethod
def select(cls, selector, item):
if isinstance(item, dict):
content = item
else:
content = item.el
for el in selector.split('/'):
if el not in content:
return None
content = content.get(el)
return content
class CSS(_Selector):
@classmethod
def select(cls, selector, item):
return item.cssselect(selector)
class XPath(_Selector):
pass
class Base(Filter): class Base(Filter):
""" """
Change the base element used in filters. Change the base element used in filters.
@ -119,7 +162,7 @@ class Base(Filter):
base = self.select(self.base, item) base = self.select(self.base, item)
return self.selector(base) return self.selector(base)
def __init__(self, base, selector, default=_NO_DEFAULT): def __init__(self, base, selector=None, default=_NO_DEFAULT):
super(Base, self).__init__(selector, default) super(Base, self).__init__(selector, default)
self.base = base self.base = base
@ -173,34 +216,6 @@ class TableCell(_Filter):
return self.default_or_raise(ColumnNotFound('Unable to find column %s' % ' or '.join(self.names))) return self.default_or_raise(ColumnNotFound('Unable to find column %s' % ' or '.join(self.names)))
class Dict(Filter):
@classmethod
def select(cls, selector, item):
if isinstance(selector, basestring):
if isinstance(item, dict):
content = item
else:
content = item.el
for el in selector.split('/'):
if el not in content:
return None
content = content.get(el)
return content
elif callable(selector):
return selector(item)
else:
return selector
def filter(self, txt):
if txt is not None:
return txt
else:
return self.default_or_raise(ParseError())
class CleanHTML(Filter): class CleanHTML(Filter):
def filter(self, txt): def filter(self, txt):
if isinstance(txt, (tuple, list)): if isinstance(txt, (tuple, list)):
@ -234,7 +249,7 @@ class CleanText(Filter):
Second, it replaces all symbols given in second argument. Second, it replaces all symbols given in second argument.
""" """
def __init__(self, selector, symbols='', replace=[], childs=True, **kwargs): def __init__(self, selector=None, symbols='', replace=[], childs=True, **kwargs):
super(CleanText, self).__init__(selector, **kwargs) super(CleanText, self).__init__(selector, **kwargs)
self.symbols = symbols self.symbols = symbols
self.toreplace = replace self.toreplace = replace
@ -283,7 +298,7 @@ class CleanDecimal(CleanText):
Get a cleaned Decimal value from an element. Get a cleaned Decimal value from an element.
""" """
def __init__(self, selector, replace_dots=True, default=_NO_DEFAULT): def __init__(self, selector=None, replace_dots=True, default=_NO_DEFAULT):
super(CleanDecimal, self).__init__(selector, default=default) super(CleanDecimal, self).__init__(selector, default=default)
self.replace_dots = replace_dots self.replace_dots = replace_dots
@ -318,7 +333,7 @@ class Link(Attr):
If the <a> tag is not found, an exception IndexError is raised. If the <a> tag is not found, an exception IndexError is raised.
""" """
def __init__(self, selector, default=_NO_DEFAULT): def __init__(self, selector=None, default=_NO_DEFAULT):
super(Link, self).__init__(selector, 'href', default=default) super(Link, self).__init__(selector, 'href', default=default)
@ -345,8 +360,9 @@ class Regexp(Filter):
u'1988-08-13' u'1988-08-13'
""" """
def __init__(self, selector, pattern, template=None, flags=0, default=_NO_DEFAULT): def __init__(self, selector=None, pattern=None, template=None, flags=0, default=_NO_DEFAULT):
super(Regexp, self).__init__(selector, default=default) super(Regexp, self).__init__(selector, default=default)
assert pattern is not None
self.pattern = pattern self.pattern = pattern
self.regex = re.compile(pattern, flags) self.regex = re.compile(pattern, flags)
self.template = template self.template = template
@ -379,7 +395,7 @@ class Map(Filter):
class DateTime(Filter): class DateTime(Filter):
def __init__(self, selector, default=_NO_DEFAULT, dayfirst=False, translations=None): def __init__(self, selector=None, default=_NO_DEFAULT, dayfirst=False, translations=None):
super(DateTime, self).__init__(selector, default=default) super(DateTime, self).__init__(selector, default=default)
self.dayfirst = dayfirst self.dayfirst = dayfirst
self.translations = translations self.translations = translations
@ -397,7 +413,7 @@ class DateTime(Filter):
class Date(DateTime): class Date(DateTime):
def __init__(self, selector, default=_NO_DEFAULT, dayfirst=False, translations=None): def __init__(self, selector=None, default=_NO_DEFAULT, dayfirst=False, translations=None):
super(Date, self).__init__(selector, default=default, dayfirst=dayfirst, translations=translations) super(Date, self).__init__(selector, default=default, dayfirst=dayfirst, translations=translations)
def filter(self, txt): def filter(self, txt):
@ -435,7 +451,7 @@ class Time(Filter):
regexp = re.compile(r'(?P<hh>\d+):?(?P<mm>\d+)(:(?P<ss>\d+))?') regexp = re.compile(r'(?P<hh>\d+):?(?P<mm>\d+)(:(?P<ss>\d+))?')
kwargs = {'hour': 'hh', 'minute': 'mm', 'second': 'ss'} kwargs = {'hour': 'hh', 'minute': 'mm', 'second': 'ss'}
def __init__(self, selector, default=_NO_DEFAULT): def __init__(self, selector=None, default=_NO_DEFAULT):
super(Time, self).__init__(selector, default=default) super(Time, self).__init__(selector, default=default)
def filter(self, txt): def filter(self, txt):
@ -486,7 +502,7 @@ class Format(MultiFilter):
class Join(Filter): class Join(Filter):
def __init__(self, pattern, selector, textCleaner=CleanText): def __init__(self, pattern, selector=None, textCleaner=CleanText):
super(Join, self).__init__(selector) super(Join, self).__init__(selector)
self.pattern = pattern self.pattern = pattern
self.textCleaner = textCleaner self.textCleaner = textCleaner

View file

@ -642,6 +642,9 @@ class AbstractElement(object):
def parse(self, obj): def parse(self, obj):
pass pass
def cssselect(self, *args, **kwargs):
return self.el.cssselect(*args, **kwargs)
def xpath(self, *args, **kwargs): def xpath(self, *args, **kwargs):
return self.el.xpath(*args, **kwargs) return self.el.xpath(*args, **kwargs)