from contextlib import closing
from bitstring import Bits
from django.db import models
from django.db.backends.postgresql_psycopg2.base import DatabaseWrapper as PGDatabaseWrapper
from django.db.backends.signals import connection_created
from psycopg2 import extensions as ext
try:
from rest_framework import serializers,exceptions
rfw = True
except ImportError:
rfw = False
# Public names exported by ``from <this module> import *``.
# NOTE(review): ``BitStringExpression`` is not defined in the part of this
# module under review; if it is not defined further down either, a star-import
# of this module raises AttributeError — confirm.
__all__ = [
    'Bits',
    'BitStringField',
    'BitStringExpression',
]
def adapt_bits(bits):
    """Render a ``bitstring.Bits`` query parameter as an SQL bit-string literal.

    Registered below as the psycopg2 adapter for ``Bits``. The literal is
    wrapped in ``AsIs`` so psycopg2 splices it into the query verbatim.
    Values whose length is a whole number of 4-bit nibbles use the shorter
    hexadecimal form ``X'...'``. All other lengths use the binary form
    ``B'...'``.
    """
    nibble_aligned = not bits.length % 4
    literal = "X'%s'" % (bits.hex,) if nibble_aligned else "B'%s'" % (bits.bin,)
    return ext.AsIs(literal)


ext.register_adapter(Bits, adapt_bits)
def cast_bits(value, cur):
    """psycopg2 typecaster that turns BIT/VARBIT results into ``bitstring.Bits``.

    :param value: the column value as handed over by psycopg2, or ``None``
        for SQL NULL. A non-NULL value is parsed as a string of binary digits.
    :param cur: the cursor; unused here.
    :returns: ``None`` for NULL, otherwise a ``Bits`` instance.
    """
    return None if value is None else Bits(bin=value)
def register_bitstring_types(connection):
    """Install the BIT/VARBIT typecaster (``cast_bits``) on ``connection``.

    The type OIDs are not hard-coded. They are discovered by having the
    server describe a NULL cast to each type. After this call, BIT and VARBIT
    values fetched through ``connection`` come back as ``bitstring.Bits``.
    """
    oids = []
    with closing(connection.cursor()) as cur:
        for probe in ("SELECT NULL::BIT", "SELECT NULL::VARBIT"):
            cur.execute(probe)
            oids.append(cur.description[0].type_code)
    ext.register_type(ext.new_type(tuple(oids), 'BIT', cast_bits), connection)
def register_types_on_connection_creation(connection, sender, *args, **kwargs):
    """``connection_created`` receiver that installs the bit-string casters.

    Only wrappers for the psycopg2 PostgreSQL backend (``PGDatabaseWrapper``
    and subclasses) are handled. Connections from any other backend are
    left untouched.
    """
    if issubclass(sender, PGDatabaseWrapper):
        # ``connection.connection`` is the DB-API connection wrapped by Django.
        register_bitstring_types(connection.connection)


connection_created.connect(register_types_on_connection_creation)
if rfw:
    class SerializerBitStringField(serializers.Field):
        """Django REST framework field that (de)serializes ``Bits`` as hex text.

        :param fix_length: if given (and non-zero), incoming values must
            decode to exactly this many bits.
        """

        def __init__(self, *arg, fix_length=None, **kw):
            self.fix_length = fix_length
            super().__init__(*arg, **kw)

        def to_representation(self, value):
            # NOTE(review): elsewhere in this module ``.hex`` is only used when
            # the length is a multiple of 4, which suggests it fails otherwise.
            # Confirm that fields using this serializer always satisfy that.
            return value.hex

        def to_internal_value(self, data):
            """Parse a hex string into ``Bits``.

            :raises ValidationError: if ``data`` is not a string, is not valid
                hex, or decodes to the wrong number of bits.
            """
            # Without this check, ``len()`` / ``Bits`` on e.g. an int would
            # raise TypeError and surface as a server error, not a 400.
            if not isinstance(data, str):
                raise exceptions.ValidationError("bitstring must be a hex string")
            try:
                bits = Bits(hex=data)
            except ValueError as exc:
                raise exceptions.ValidationError("invalid hex for bitstring") from exc
            # Compare the decoded length rather than ``len(data) * 4``. The raw
            # text may carry a prefix such as ``0x`` that is not data.
            if self.fix_length and bits.length != self.fix_length:
                raise exceptions.ValidationError("invalid size for bitstring")
            return bits
class BitStringField(models.Field):
    """A Postgres bit string: ``BIT(n)``, or ``VARBIT(n)`` with ``varying=True``.

    Values are represented in Python as ``bitstring.Bits``.

    :param max_length: number of bits; defaults to 1. ``None`` omits the
        length from the column type (``BIT`` / ``VARBIT``).
    :param varying: if true, use ``VARBIT`` instead of ``BIT``.
    :param default: anything accepted by :meth:`to_python`. When omitted, the
        default is ``None`` for nullable fields, ``max_length`` zero bits for
        fixed-width fields with a length, and a single zero bit otherwise.
    """

    def __init__(self, *args, **kwargs):
        self.max_length = kwargs.setdefault('max_length', 1)
        self.varying = kwargs.pop('varying', False)
        if 'default' in kwargs:
            default = kwargs.pop('default')
        elif kwargs.get('null', False):
            default = None
        elif self.max_length is not None and not self.varying:
            default = '0' * self.max_length
        else:
            default = '0'
        kwargs['default'] = self.to_python(default)
        super(BitStringField, self).__init__(*args, **kwargs)

    def deconstruct(self):
        """Return the field's constructor signature for migrations.

        ``varying`` is popped in ``__init__``, so the base implementation
        would drop it and migrations would create ``BIT`` instead of
        ``VARBIT``. The stored default is a ``Bits`` instance, which the
        migration writer cannot serialize. It is therefore emitted as its
        binary-digit string, which :meth:`to_python` turns back into the same
        value.
        """
        name, path, args, kwargs = super(BitStringField, self).deconstruct()
        if self.varying:
            kwargs['varying'] = True
        default = kwargs.get('default')
        if isinstance(default, Bits):
            kwargs['default'] = default.bin
        return name, path, args, kwargs

    def db_type(self, connection):
        """Return ``BIT``/``VARBIT``, with ``(max_length)`` when it is set."""
        if self.varying:
            if self.max_length is not None:
                return 'VARBIT(%d)' % (self.max_length,)
            return 'VARBIT'
        elif self.max_length is not None:
            return 'BIT(%d)' % (self.max_length,)
        return 'BIT'

    def to_python(self, value):
        """Coerce ``value`` to ``Bits``.

        ``None`` and ``Bits`` pass through unchanged. Strings starting with
        ``0x`` are parsed as hex; other strings are parsed as binary digits.

        :raises TypeError: for any other type.
        """
        if value is None or isinstance(value, Bits):
            return value
        elif isinstance(value, str):
            if value.startswith('0x'):
                return Bits(hex=value)
            return Bits(bin=value)
        raise TypeError("Cannot coerce into bit string: %r" % (value,))

    def get_prep_value(self, value):
        """Prepare a query value by coercing it with :meth:`to_python`."""
        return self.to_python(value)

    def get_prep_lookup(self, lookup_type, value):
        """Prepare ``value`` for an ``exact`` or ``in`` lookup.

        NOTE(review): newer Django versions no longer call
        ``Field.get_prep_lookup``. This hook is presumably only reached on
        older versions; confirm before relying on it.

        :raises TypeError: for any other lookup type.
        """
        if lookup_type == 'exact':
            return self.get_prep_value(value)
        elif lookup_type == 'in':
            # Materialize: a lazy ``map`` object can only be iterated once.
            return [self.get_prep_value(item) for item in value]
        raise TypeError("Lookup type %r not supported on bit strings" % lookup_type)

    def get_default(self):
        """Return the field default, coerced through :meth:`to_python`."""
        default = super(BitStringField, self).get_default()
        return self.to_python(default)
from django.db.models import Lookup
class _BitStringPositionLookup(Lookup):
    """Base for bit-string lookups that test whether a bitwise expression has a set bit.

    Subclasses define ``lookup_name`` and ``template``. The template is a
    ``str.format`` pattern with one ``{lhs}`` and one ``{rhs}`` placeholder,
    and ``{lhs}`` must come first so the SQL parameters can be concatenated
    in that order. The right-hand side is passed as a query parameter (via
    ``process_rhs``), never pasted into the SQL text. Field/expression
    operands are therefore also supported.
    """

    template = None

    def as_sql(self, compiler, connection):
        lhs, lhs_params = self.process_lhs(compiler, connection)
        rhs, rhs_params = self.process_rhs(compiler, connection)
        # str.format (not %-formatting): rhs contains psycopg2 ``%s`` placeholders.
        sql = self.template.format(lhs=lhs, rhs=rhs)
        return sql, list(lhs_params) + list(rhs_params)


@BitStringField.register_lookup
class BitstringAND(_BitStringPositionLookup):
    """``field__and=bits``: true when ``field & bits`` has any bit set."""

    lookup_name = 'and'
    template = "position(B'1' IN ({lhs}) & ({rhs})) > 0"


@BitStringField.register_lookup
class BitstringOR(_BitStringPositionLookup):
    """``field__or=bits``: true when ``field | bits`` has any bit set."""

    lookup_name = 'or'
    template = "position(B'1' IN ({lhs}) | ({rhs})) > 0"


@BitStringField.register_lookup
class BitstringXOR(_BitStringPositionLookup):
    """``field__xor=bits``: true when ``field # bits`` (XOR) has any bit set."""

    lookup_name = 'xor'
    template = "position(B'1' IN ({lhs}) # ({rhs})) > 0"


@BitStringField.register_lookup
class BitstringSuperset(_BitStringPositionLookup):
    """``field__superset=bits``: every bit set in ``bits`` is also set in ``field``."""

    lookup_name = 'superset'
    # Bits of rhs missing from lhs: ``~lhs & rhs`` must be all zeros.
    template = "position(B'1' IN (~({lhs})) & ({rhs})) = 0"


@BitStringField.register_lookup
class BitstringSubset(_BitStringPositionLookup):
    """``field__subset=bits``: every bit set in ``field`` is also set in ``bits``."""

    lookup_name = 'subset'
    # Bits of lhs missing from rhs: ``lhs & ~rhs`` must be all zeros.
    template = "position(B'1' IN ({lhs}) & (~({rhs}))) = 0"


@BitStringField.register_lookup
class BitstringIntersects(_BitStringPositionLookup):
    """``field__intersects=bits``: ``field`` and ``bits`` share at least one set bit."""

    lookup_name = 'intersects'
    template = "position(B'1' IN ({lhs}) & ({rhs})) > 0"


@BitStringField.register_lookup
class BitstringDisjoint(_BitStringPositionLookup):
    """``field__disjoint=bits``: ``field`` and ``bits`` share no set bit."""

    lookup_name = 'disjoint'
    template = "position(B'1' IN ({lhs}) & ({rhs})) = 0"


# Backward compatibility: the module-level name ``BitstringContains`` was
# previously rebound four times, and the last definition was the
# ``disjoint`` lookup.
BitstringContains = BitstringDisjoint