Source code for sqlalchemy_mptt.tests.test_inheritance

import unittest

import sqlalchemy as sa

from sqlalchemy_mptt.mixins import BaseNestedSets
from sqlalchemy_mptt.sqlalchemy_compat import compat_layer
from sqlalchemy_mptt.tests import (DatabaseSetupMixin, TreeTestingMixin,
                                   failures_expected_on)

# Declarative base for GenericTree/SpecializedTree; the TestTree,
# TestGenericTree and TestSpecializedTree suites point at it via `base`.
Base = compat_layer.declarative_base()


class GenericTree(Base, BaseNestedSets):
    """Root model of a polymorphic nested-set hierarchy.

    Rows are stored in the ``generic`` table.  The primary-key attribute is
    called ``ppk`` but maps to the ``idd`` column, and
    ``sqlalchemy_mptt_pk_name`` points the nested-set mixin at that
    attribute.  The integer ``type`` column is the polymorphic
    discriminator; plain ``GenericTree`` rows carry identity 0.
    """

    __tablename__ = "generic"

    ppk = sa.Column('idd', sa.Integer, primary_key=True)
    # Discriminator column referenced by __mapper_args__ below.
    type = sa.Column(sa.Integer, default=0)
    visible = sa.Column(sa.Boolean)

    sqlalchemy_mptt_pk_name = 'ppk'

    __mapper_args__ = {
        'polymorphic_on': type,
        'polymorphic_identity': 0,
    }

    def __repr__(self):
        return "<Node ({})>".format(self.ppk)
class SpecializedTree(GenericTree):
    """Joined-table subclass of :class:`GenericTree` (polymorphic identity 1).

    Each instance is a row in ``generic`` plus a row in ``specialized``.
    The ``specialized.idd`` column is both this table's primary key and a
    foreign key to ``GenericTree.ppk``.
    """

    __tablename__ = "specialized"

    ppk = sa.Column(
        'idd',
        sa.Integer,
        sa.ForeignKey(GenericTree.ppk),
        primary_key=True,
    )

    __mapper_args__ = {'polymorphic_identity': 1}

    # Deliberately empty. This presumably suppresses table arguments (e.g.
    # indexes) that would otherwise be inherited for this subclass table.
    # TODO confirm against BaseNestedSets.
    __table_args__ = ()
class TestTree(DatabaseSetupMixin, unittest.TestCase):
    """Create/delete round-trips on a tree mixing GenericTree and SpecializedTree.

    ``self.session`` comes from ``DatabaseSetupMixin``.  ``base`` presumably
    selects which declarative base's tables that mixin creates.
    """

    base = Base

    def test_create_generic(self):
        # A single GenericTree root is persisted and starts tree 1.
        self.session.add(GenericTree(ppk=1))
        self.session.commit()

        tree = compat_layer.get(self.session, GenericTree, 1)
        self.assertEqual(tree.ppk, 1)
        self.assertEqual(tree.tree_id, 1)

    def test_create_spec(self):
        # Same as above, but with a root of the joined subclass.
        self.session.add(SpecializedTree(ppk=1))
        self.session.commit()

        tree = compat_layer.get(self.session, SpecializedTree, 1)
        self.assertEqual(tree.ppk, 1)
        self.assertEqual(tree.tree_id, 1)

    def test_create_delete(self):
        # Tree under test (class of each node in brackets):
        #   1 [Specialized]
        #   +-- 2 [Specialized]
        #   +-- 3 [Generic]
        #       +-- 4 [Generic]
        #       +-- 5 [Specialized]
        parent = SpecializedTree(ppk=1)
        child1 = SpecializedTree(ppk=2, parent=parent)
        child2 = GenericTree(ppk=3, parent=parent)
        GenericTree(ppk=4, parent=child2)
        SpecializedTree(ppk=5, parent=child2)
        self.session.add(parent)
        self.session.commit()

        tree = compat_layer.get(self.session, SpecializedTree, 1)
        self.assertEqual(tree.ppk, 1)
        self.assertEqual(tree.tree_id, 1)

        # Post-delete lookups go through the polymorphic base class.
        # Looking up a GenericTree id (3, 4) via SpecializedTree returns None
        # even when the row still exists, so those checks could never fail.
        self.session.delete(child1)
        self.session.commit()
        self.assertIsNone(compat_layer.get(self.session, GenericTree, 2))

        # Deleting node 3 must take its descendants (4 and 5) with it.
        self.session.delete(child2)
        self.session.commit()
        for ppk in (3, 4, 5):
            with self.subTest(ppk=ppk):
                self.assertIsNone(
                    compat_layer.get(self.session, GenericTree, ppk)
                )
class TestGenericTree(TreeTestingMixin, unittest.TestCase):
    """Run the shared TreeTestingMixin tests against :class:`GenericTree`."""

    model = GenericTree
    base = Base
class TestSpecializedTree(TreeTestingMixin, unittest.TestCase):
    """Run the shared TreeTestingMixin tests against :class:`SpecializedTree`."""

    model = SpecializedTree
    base = Base

    # Inherited rebuild test. Per the original author it always fails when
    # the model is a specialized (subclass) model, so it is marked as an
    # expected failure.
    @unittest.expectedFailure
    def test_rebuild(self):
        super().test_rebuild()
# Second, separate declarative base, used only by BaseInheritance /
# InheritanceTree and TestInheritanceTree. Presumably it keeps those tables
# out of Base's metadata. TODO confirm compat_layer returns a fresh MetaData.
Base2 = compat_layer.declarative_base()
class BaseInheritance(Base2):
    """Polymorphic base model *without* nested-set behaviour.

    Nested-set support is added only by the :class:`InheritanceTree`
    subclass.  The ``ppk`` attribute maps to the ``idd`` column, and the
    integer ``type`` column is the discriminator (identity 0 for this class).
    """

    __tablename__ = "base_inheritance"

    ppk = sa.Column('idd', sa.Integer, primary_key=True)
    # Discriminator column referenced by __mapper_args__ below.
    type = sa.Column(sa.Integer, default=0)
    visible = sa.Column(sa.Boolean)

    __mapper_args__ = {
        'polymorphic_on': type,
        'polymorphic_identity': 0,
    }

    def __repr__(self):
        return "<Node ({})>".format(self.ppk)
class InheritanceTree(BaseInheritance, BaseNestedSets):
    """Joined-table subclass of :class:`BaseInheritance` that mixes in nested sets.

    Unlike the GenericTree/SpecializedTree pair, ``BaseNestedSets`` is
    introduced on the subclass rather than on the polymorphic base.
    """

    # NOTE(review): "inheriance" looks like a typo for "inheritance". It is
    # left unchanged because the table name is part of the schema.
    __tablename__ = "inheriance_tree"

    ppk = sa.Column(
        'idd',
        sa.Integer,
        sa.ForeignKey(BaseInheritance.ppk),
        primary_key=True,
    )

    sqlalchemy_mptt_pk_name = 'ppk'

    __mapper_args__ = {'polymorphic_identity': 1}
class TestInheritanceTree(TreeTestingMixin, unittest.TestCase):
    """Run the shared TreeTestingMixin tests against :class:`InheritanceTree`."""

    model = InheritanceTree
    base = Base2

    # The rebuild test is expected to fail only on the SQLAlchemy versions
    # listed here. Presumably failures_expected_on applies an expected-failure
    # marker on those versions; verify in sqlalchemy_mptt.tests.
    @failures_expected_on(sqlalchemy_versions=['1.0', '1.1', '1.2', '1.3'])
    def test_rebuild(self):
        super().test_rebuild()